Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -22,7 +22,7 @@ class FastSAM(Model):
|
|||
"""
|
||||
|
||||
def __init__(self, model='FastSAM-x.pt'):
|
||||
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
|
||||
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
|
||||
if str(model) == 'FastSAM.pt':
|
||||
model = 'FastSAM-x.pt'
|
||||
assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
|
||||
|
|
@ -30,4 +30,5 @@ class FastSAM(Model):
|
|||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
|
||||
return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
|
||||
|
|
|
|||
|
|
@ -11,10 +11,12 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|||
class FastSAMPredictor(DetectionPredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initializes FastSAMPredictor class by inheriting from DetectionPredictor and setting task to 'segment'."""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses the predictions, applies non-max suppression, scales the boxes, and returns the results."""
|
||||
p = ops.non_max_suppression(
|
||||
preds[0],
|
||||
self.args.conf,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from ultralytics.utils import TQDM
|
|||
class FastSAMPrompt:
|
||||
|
||||
def __init__(self, source, results, device='cuda') -> None:
|
||||
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.source = source
|
||||
|
|
@ -30,6 +31,7 @@ class FastSAMPrompt:
|
|||
|
||||
@staticmethod
|
||||
def _segment_image(image, bbox):
|
||||
"""Segments the given image according to the provided bounding box coordinates."""
|
||||
image_array = np.array(image)
|
||||
segmented_image_array = np.zeros_like(image_array)
|
||||
x1, y1, x2, y2 = bbox
|
||||
|
|
@ -45,6 +47,9 @@ class FastSAMPrompt:
|
|||
|
||||
@staticmethod
|
||||
def _format_results(result, filter=0):
|
||||
"""Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
|
||||
area.
|
||||
"""
|
||||
annotations = []
|
||||
n = len(result.masks.data) if result.masks is not None else 0
|
||||
for i in range(n):
|
||||
|
|
@ -61,6 +66,9 @@ class FastSAMPrompt:
|
|||
|
||||
@staticmethod
|
||||
def _get_bbox_from_mask(mask):
|
||||
"""Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
|
||||
contours.
|
||||
"""
|
||||
mask = mask.astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
x1, y1, w, h = cv2.boundingRect(contours[0])
|
||||
|
|
@ -195,6 +203,7 @@ class FastSAMPrompt:
|
|||
|
||||
@torch.no_grad()
|
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
|
||||
"""Processes images and text with a model, calculates similarity, and returns softmax score."""
|
||||
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
||||
tokenized_text = self.clip.tokenize([search_text]).to(device)
|
||||
stacked_images = torch.stack(preprocessed_images)
|
||||
|
|
@ -206,6 +215,7 @@ class FastSAMPrompt:
|
|||
return probs[:, 0].softmax(dim=0)
|
||||
|
||||
def _crop_image(self, format_results):
|
||||
"""Crops an image based on provided annotation format and returns cropped images and related data."""
|
||||
if os.path.isdir(self.source):
|
||||
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
|
||||
image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
|
||||
|
|
@ -229,6 +239,7 @@ class FastSAMPrompt:
|
|||
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
||||
|
||||
def box_prompt(self, bbox):
|
||||
"""Modifies the bounding box properties and calculates IoU between masks and bounding box."""
|
||||
if self.results[0].masks is not None:
|
||||
assert (bbox[2] != 0 and bbox[3] != 0)
|
||||
if os.path.isdir(self.source):
|
||||
|
|
@ -261,7 +272,8 @@ class FastSAMPrompt:
|
|||
self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
|
||||
return self.results
|
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy 处理
|
||||
def point_prompt(self, points, pointlabel): # numpy
|
||||
"""Adjusts points on detected masks based on user input and returns the modified results."""
|
||||
if self.results[0].masks is not None:
|
||||
if os.path.isdir(self.source):
|
||||
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
|
||||
|
|
@ -284,6 +296,7 @@ class FastSAMPrompt:
|
|||
return self.results
|
||||
|
||||
def text_prompt(self, text):
|
||||
"""Processes a text prompt, applies it to existing results and returns the updated results."""
|
||||
if self.results[0].masks is not None:
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
|
||||
|
|
@ -296,4 +309,5 @@ class FastSAMPrompt:
|
|||
return self.results
|
||||
|
||||
def everything_prompt(self):
|
||||
"""Returns the processed results from the previous methods in the class."""
|
||||
return self.results
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue