Model coverage cleanup (#4585)
This commit is contained in:
parent
c635418a27
commit
deac7575b1
12 changed files with 132 additions and 175 deletions
|
|
@ -51,32 +51,16 @@ class FastSAMPrompt:
|
|||
n = len(result.masks.data)
|
||||
for i in range(n):
|
||||
mask = result.masks.data[i] == 1.0
|
||||
|
||||
if torch.sum(mask) < filter:
|
||||
continue
|
||||
annotation = {
|
||||
'id': i,
|
||||
'segmentation': mask.cpu().numpy(),
|
||||
'bbox': result.boxes.data[i],
|
||||
'score': result.boxes.conf[i]}
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
if torch.sum(mask) >= filter:
|
||||
annotation = {
|
||||
'id': i,
|
||||
'segmentation': mask.cpu().numpy(),
|
||||
'bbox': result.boxes.data[i],
|
||||
'score': result.boxes.conf[i]}
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
||||
@staticmethod
|
||||
def filter_masks(annotations): # filter the overlap mask
|
||||
annotations.sort(key=lambda x: x['area'], reverse=True)
|
||||
to_remove = set()
|
||||
for i in range(len(annotations)):
|
||||
a = annotations[i]
|
||||
for j in range(i + 1, len(annotations)):
|
||||
b = annotations[j]
|
||||
if i != j and j not in to_remove and b['area'] < a['area'] and \
|
||||
(a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
|
||||
to_remove.add(j)
|
||||
|
||||
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
|
||||
|
||||
@staticmethod
|
||||
def _get_bbox_from_mask(mask):
|
||||
mask = mask.astype(np.uint8)
|
||||
|
|
@ -242,15 +226,12 @@ class FastSAMPrompt:
|
|||
cropped_images = []
|
||||
not_crop = []
|
||||
filter_id = []
|
||||
# annotations, _ = filter_masks(annotations)
|
||||
# filter_id = list(_)
|
||||
for _, mask in enumerate(annotations):
|
||||
if np.sum(mask['segmentation']) <= 100:
|
||||
filter_id.append(_)
|
||||
continue
|
||||
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
|
||||
cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
|
||||
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
|
||||
cropped_images.append(bbox) # 保存裁剪的图片的bbox
|
||||
|
||||
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue