ultralytics 8.0.134 add MobileSAM support (#3474)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Chaoning Zhang 2023-07-13 20:25:56 +08:00 committed by GitHub
parent c55a98ab8e
commit 201e69e4e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 1472 additions and 841 deletions

View file

@ -1,8 +1,6 @@
from pathlib import Path
from ultralytics import YOLO
from ultralytics.vit.sam import PromptPredictor, build_sam
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics import SAM, YOLO
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
@ -16,33 +14,21 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
output_dir (str | None | optional): Directory to save the annotated results.
Defaults to a 'labels' folder in the same directory as 'data'.
"""
device = select_device(device)
det_model = YOLO(det_model)
sam_model = build_sam(sam_model)
det_model.to(device)
sam_model.to(device)
sam_model = SAM(sam_model)
if not output_dir:
output_dir = Path(str(data)).parent / 'labels'
Path(output_dir).mkdir(exist_ok=True, parents=True)
prompt_predictor = PromptPredictor(sam_model)
det_results = det_model(data, stream=True)
det_results = det_model(data, stream=True, device=device)
for result in det_results:
boxes = result.boxes.xyxy # Boxes object for bbox outputs
class_ids = result.boxes.cls.int().tolist() # noqa
if len(class_ids):
prompt_predictor.set_image(result.orig_img)
masks, _, _ = prompt_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]),
multimask_output=False,
)
result.update(masks=masks.squeeze(1))
segments = result.masks.xyn # noqa
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn # noqa
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
for i in range(len(segments)):