ultralytics 8.0.89 SAM predict and auto-annotate (#2298)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Snyk bot <snyk-bot@snyk.io> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
3e118f6170
commit
243fc4b1fe
44 changed files with 2915 additions and 440 deletions
42
ultralytics/yolo/data/annotator.py
Normal file
42
ultralytics/yolo/data/annotator.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from pathlib import Path
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.vit.sam import PromptPredictor, build_sam
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
|
||||
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
|
||||
device = select_device(device)
|
||||
det_model = YOLO(det_model)
|
||||
sam_model = build_sam(sam_model)
|
||||
det_model.to(device)
|
||||
sam_model.to(device)
|
||||
|
||||
if not output_dir:
|
||||
output_dir = Path(str(data)).parent / 'labels'
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
prompt_predictor = PromptPredictor(sam_model)
|
||||
det_results = det_model(data, stream=True)
|
||||
|
||||
for result in det_results:
|
||||
boxes = result.boxes.xyxy # Boxes object for bbox outputs
|
||||
class_ids = result.boxes.cls.int().tolist() # noqa
|
||||
prompt_predictor.set_image(result.orig_img)
|
||||
masks, _, _ = prompt_predictor.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]),
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
result.update(masks=masks.squeeze(1))
|
||||
segments = result.masks.xyn # noqa
|
||||
|
||||
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
|
||||
for i in range(len(segments)):
|
||||
s = segments[i]
|
||||
if len(s) == 0:
|
||||
continue
|
||||
segment = map(str, segments[i].reshape(-1).tolist())
|
||||
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
|
||||
Loading…
Add table
Add a link
Reference in a new issue