build.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # Copyright (c) Meta Platforms, Inc. and affiliates.
  3. # All rights reserved.
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from functools import partial
  7. import torch
  8. from ultralytics.utils.downloads import attempt_download_asset
  9. from .modules.decoders import MaskDecoder
  10. from .modules.encoders import ImageEncoderViT, PromptEncoder
  11. from .modules.sam import Sam
  12. from .modules.tiny_encoder import TinyViT
  13. from .modules.transformer import TwoWayTransformer
  14. def build_sam_vit_h(checkpoint=None):
  15. """Build and return a Segment Anything Model (SAM) h-size model."""
  16. return _build_sam(
  17. encoder_embed_dim=1280,
  18. encoder_depth=32,
  19. encoder_num_heads=16,
  20. encoder_global_attn_indexes=[7, 15, 23, 31],
  21. checkpoint=checkpoint,
  22. )
  23. def build_sam_vit_l(checkpoint=None):
  24. """Build and return a Segment Anything Model (SAM) l-size model."""
  25. return _build_sam(
  26. encoder_embed_dim=1024,
  27. encoder_depth=24,
  28. encoder_num_heads=16,
  29. encoder_global_attn_indexes=[5, 11, 17, 23],
  30. checkpoint=checkpoint,
  31. )
  32. def build_sam_vit_b(checkpoint=None):
  33. """Build and return a Segment Anything Model (SAM) b-size model."""
  34. return _build_sam(
  35. encoder_embed_dim=768,
  36. encoder_depth=12,
  37. encoder_num_heads=12,
  38. encoder_global_attn_indexes=[2, 5, 8, 11],
  39. checkpoint=checkpoint,
  40. )
  41. def build_mobile_sam(checkpoint=None):
  42. """Build and return Mobile Segment Anything Model (Mobile-SAM)."""
  43. return _build_sam(
  44. encoder_embed_dim=[64, 128, 160, 320],
  45. encoder_depth=[2, 2, 6, 2],
  46. encoder_num_heads=[2, 4, 5, 10],
  47. encoder_global_attn_indexes=None,
  48. mobile_sam=True,
  49. checkpoint=checkpoint,
  50. )
  51. def _build_sam(
  52. encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
  53. ):
  54. """Builds the selected SAM model architecture."""
  55. prompt_embed_dim = 256
  56. image_size = 1024
  57. vit_patch_size = 16
  58. image_embedding_size = image_size // vit_patch_size
  59. image_encoder = (
  60. TinyViT(
  61. img_size=1024,
  62. in_chans=3,
  63. num_classes=1000,
  64. embed_dims=encoder_embed_dim,
  65. depths=encoder_depth,
  66. num_heads=encoder_num_heads,
  67. window_sizes=[7, 7, 14, 7],
  68. mlp_ratio=4.0,
  69. drop_rate=0.0,
  70. drop_path_rate=0.0,
  71. use_checkpoint=False,
  72. mbconv_expand_ratio=4.0,
  73. local_conv_size=3,
  74. layer_lr_decay=0.8,
  75. )
  76. if mobile_sam
  77. else ImageEncoderViT(
  78. depth=encoder_depth,
  79. embed_dim=encoder_embed_dim,
  80. img_size=image_size,
  81. mlp_ratio=4,
  82. norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
  83. num_heads=encoder_num_heads,
  84. patch_size=vit_patch_size,
  85. qkv_bias=True,
  86. use_rel_pos=True,
  87. global_attn_indexes=encoder_global_attn_indexes,
  88. window_size=14,
  89. out_chans=prompt_embed_dim,
  90. )
  91. )
  92. sam = Sam(
  93. image_encoder=image_encoder,
  94. prompt_encoder=PromptEncoder(
  95. embed_dim=prompt_embed_dim,
  96. image_embedding_size=(image_embedding_size, image_embedding_size),
  97. input_image_size=(image_size, image_size),
  98. mask_in_chans=16,
  99. ),
  100. mask_decoder=MaskDecoder(
  101. num_multimask_outputs=3,
  102. transformer=TwoWayTransformer(
  103. depth=2,
  104. embedding_dim=prompt_embed_dim,
  105. mlp_dim=2048,
  106. num_heads=8,
  107. ),
  108. transformer_dim=prompt_embed_dim,
  109. iou_head_depth=3,
  110. iou_head_hidden_dim=256,
  111. ),
  112. pixel_mean=[123.675, 116.28, 103.53],
  113. pixel_std=[58.395, 57.12, 57.375],
  114. )
  115. if checkpoint is not None:
  116. checkpoint = attempt_download_asset(checkpoint)
  117. with open(checkpoint, "rb") as f:
  118. state_dict = torch.load(f)
  119. sam.load_state_dict(state_dict)
  120. sam.eval()
  121. # sam.load_state_dict(torch.load(checkpoint), strict=True)
  122. # sam.eval()
  123. return sam
# Maps a checkpoint filename suffix to the builder for that SAM architecture; build_sam selects the builder whose key
# the checkpoint path ends with and passes the checkpoint through to it.
sam_model_map = {
    "sam_h.pt": build_sam_vit_h,  # ViT-H image encoder
    "sam_l.pt": build_sam_vit_l,  # ViT-L image encoder
    "sam_b.pt": build_sam_vit_b,  # ViT-B image encoder
    "mobile_sam.pt": build_mobile_sam,  # TinyViT image encoder (Mobile-SAM)
}
  130. def build_sam(ckpt="sam_b.pt"):
  131. """Build a SAM model specified by ckpt."""
  132. model_builder = None
  133. ckpt = str(ckpt) # to allow Path ckpt types
  134. for k in sam_model_map.keys():
  135. if ckpt.endswith(k):
  136. model_builder = sam_model_map.get(k)
  137. if not model_builder:
  138. raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
  139. return model_builder(ckpt)