Ruff format docstring Python code (#15792)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
c1882a4327
commit
d27664216b
63 changed files with 370 additions and 374 deletions
|
|
@ -41,8 +41,8 @@ class SAM(Model):
|
|||
info: Logs information about the SAM model.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam.predict('image.jpg', points=[[500, 375]])
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam.predict("image.jpg", points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
>>> print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
|
|
@ -58,7 +58,7 @@ class SAM(Model):
|
|||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> print(sam.is_sam2)
|
||||
"""
|
||||
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||
|
|
@ -78,8 +78,8 @@ class SAM(Model):
|
|||
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam._load('path/to/custom_weights.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> sam._load("path/to/custom_weights.pt")
|
||||
"""
|
||||
self.model = build_sam(weights)
|
||||
|
||||
|
|
@ -100,8 +100,8 @@ class SAM(Model):
|
|||
(List): The model predictions.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam.predict('image.jpg', points=[[500, 375]])
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam.predict("image.jpg", points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
... print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
|
|
@ -130,8 +130,8 @@ class SAM(Model):
|
|||
(List): The model predictions, typically containing segmentation masks and other relevant information.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam('image.jpg', points=[[500, 375]])
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam("image.jpg", points=[[500, 375]])
|
||||
>>> print(f"Detected {len(results[0].masks)} masks")
|
||||
"""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
|
@ -151,7 +151,7 @@ class SAM(Model):
|
|||
(Tuple): A tuple containing the model's information (string representations of the model).
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> info = sam.info()
|
||||
>>> print(info[0]) # Print summary information
|
||||
"""
|
||||
|
|
@ -167,7 +167,7 @@ class SAM(Model):
|
|||
class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> task_map = sam.task_map
|
||||
>>> print(task_map)
|
||||
{'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
|
||||
|
|
|
|||
|
|
@ -32,8 +32,9 @@ class MaskDecoder(nn.Module):
|
|||
|
||||
Examples:
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
|
||||
>>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
|
||||
... dense_prompt_embeddings, multimask_output=True)
|
||||
>>> masks, iou_pred = decoder(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
|
||||
... )
|
||||
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
|
||||
"""
|
||||
|
||||
|
|
@ -213,7 +214,8 @@ class SAM2MaskDecoder(nn.Module):
|
|||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -345,7 +347,8 @@ class SAM2MaskDecoder(nn.Module):
|
|||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
|
||||
... )
|
||||
"""
|
||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
|
|
|
|||
|
|
@ -417,7 +417,15 @@ class SAM2Model(torch.nn.Module):
|
|||
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
|
||||
>>> mask_inputs = torch.rand(1, 1, 512, 512)
|
||||
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
|
||||
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
|
||||
>>> (
|
||||
... low_res_multimasks,
|
||||
... high_res_multimasks,
|
||||
... ious,
|
||||
... low_res_masks,
|
||||
... high_res_masks,
|
||||
... obj_ptr,
|
||||
... object_score_logits,
|
||||
... ) = results
|
||||
"""
|
||||
B = backbone_features.size(0)
|
||||
device = backbone_features.device
|
||||
|
|
|
|||
|
|
@ -716,7 +716,7 @@ class BasicLayer(nn.Module):
|
|||
|
||||
Examples:
|
||||
>>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
|
||||
>>> x = torch.randn(1, 56*56, 96)
|
||||
>>> x = torch.randn(1, 56 * 56, 96)
|
||||
>>> output = layer(x)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
|
|||
|
||||
Examples:
|
||||
>>> frame_idx = 5
|
||||
>>> cond_frame_outputs = {1: 'a', 3: 'b', 7: 'c', 9: 'd'}
|
||||
>>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
|
||||
>>> max_cond_frame_num = 2
|
||||
>>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
|
||||
>>> print(selected)
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model_path='sam_model.pt')
|
||||
>>> predictor.set_image('image.jpg')
|
||||
>>> predictor.setup_model(model_path="sam_model.pt")
|
||||
>>> predictor.set_image("image.jpg")
|
||||
>>> masks, scores, boxes = predictor.generate()
|
||||
>>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
|
||||
"""
|
||||
|
|
@ -90,8 +90,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor(cfg=DEFAULT_CFG)
|
||||
>>> predictor = Predictor(overrides={'imgsz': 640})
|
||||
>>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
|
||||
>>> predictor = Predictor(overrides={"imgsz": 640})
|
||||
>>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback})
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
|
|
@ -188,8 +188,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model_path='sam_model.pt')
|
||||
>>> predictor.set_image('image.jpg')
|
||||
>>> predictor.setup_model(model_path="sam_model.pt")
|
||||
>>> predictor.set_image("image.jpg")
|
||||
>>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
|
||||
"""
|
||||
# Override prompts if any stored in self.prompts
|
||||
|
|
@ -475,8 +475,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_source('path/to/images')
|
||||
>>> predictor.setup_source('video.mp4')
|
||||
>>> predictor.setup_source("path/to/images")
|
||||
>>> predictor.setup_source("video.mp4")
|
||||
>>> predictor.setup_source(None) # Uses default source if available
|
||||
|
||||
Notes:
|
||||
|
|
@ -504,8 +504,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.set_image('path/to/image.jpg')
|
||||
>>> predictor.set_image(cv2.imread('path/to/image.jpg'))
|
||||
>>> predictor.set_image("path/to/image.jpg")
|
||||
>>> predictor.set_image(cv2.imread("path/to/image.jpg"))
|
||||
|
||||
Notes:
|
||||
- This method should be called before performing inference on a new image.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue