Ruff format docstring Python code (#15792)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
c1882a4327
commit
d27664216b
63 changed files with 370 additions and 374 deletions
|
|
@ -16,8 +16,8 @@ class FastSAM(Model):
|
|||
```python
|
||||
from ultralytics import FastSAM
|
||||
|
||||
model = FastSAM('last.pt')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
model = FastSAM("last.pt")
|
||||
results = model.predict("ultralytics/assets/bus.jpg")
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -92,8 +92,8 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
if labels.sum() == 0 # all negative points
|
||||
else torch.zeros(len(result), dtype=torch.bool, device=self.device)
|
||||
)
|
||||
for p, l in zip(points, labels):
|
||||
point_idx[torch.nonzero(masks[:, p[1], p[0]], as_tuple=True)[0]] = True if l else False
|
||||
for point, label in zip(points, labels):
|
||||
point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = True if label else False
|
||||
idx |= point_idx
|
||||
if texts is not None:
|
||||
if isinstance(texts, str):
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ Example:
|
|||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
model = NAS("yolo_nas_s")
|
||||
results = model.predict("ultralytics/assets/bus.jpg")
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
@ -34,8 +34,8 @@ class NAS(Model):
|
|||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
model = NAS("yolo_nas_s")
|
||||
results = model.predict("ultralytics/assets/bus.jpg")
|
||||
```
|
||||
|
||||
Attributes:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class NASPredictor(BasePredictor):
|
|||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
model = NAS("yolo_nas_s")
|
||||
predictor = model.predictor
|
||||
# Assumes that raw_preds, img, orig_imgs are available
|
||||
results = predictor.postprocess(raw_preds, img, orig_imgs)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class NASValidator(DetectionValidator):
|
|||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
model = NAS("yolo_nas_s")
|
||||
validator = model.validator
|
||||
# Assumes that raw_preds are available
|
||||
final_preds = validator.postprocess(raw_preds)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class RTDETRPredictor(BasePredictor):
|
|||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.rtdetr import RTDETRPredictor
|
||||
|
||||
args = dict(model='rtdetr-l.pt', source=ASSETS)
|
||||
args = dict(model="rtdetr-l.pt", source=ASSETS)
|
||||
predictor = RTDETRPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class RTDETRTrainer(DetectionTrainer):
|
|||
```python
|
||||
from ultralytics.models.rtdetr.train import RTDETRTrainer
|
||||
|
||||
args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
|
||||
args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
|
||||
trainer = RTDETRTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class RTDETRValidator(DetectionValidator):
|
|||
```python
|
||||
from ultralytics.models.rtdetr import RTDETRValidator
|
||||
|
||||
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
|
||||
args = dict(model="rtdetr-l.pt", data="coco8.yaml")
|
||||
validator = RTDETRValidator(args=args)
|
||||
validator()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ class SAM(Model):
|
|||
info: Logs information about the SAM model.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam.predict('image.jpg', points=[[500, 375]])
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam.predict("image.jpg", points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
>>> print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
|
|
@ -58,7 +58,7 @@ class SAM(Model):
|
|||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> print(sam.is_sam2)
|
||||
"""
|
||||
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||
|
|
@ -78,8 +78,8 @@ class SAM(Model):
|
|||
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam._load('path/to/custom_weights.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> sam._load("path/to/custom_weights.pt")
|
||||
"""
|
||||
self.model = build_sam(weights)
|
||||
|
||||
|
|
@ -100,8 +100,8 @@ class SAM(Model):
|
|||
(List): The model predictions.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam.predict('image.jpg', points=[[500, 375]])
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam.predict("image.jpg", points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
... print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
|
|
@ -130,8 +130,8 @@ class SAM(Model):
|
|||
(List): The model predictions, typically containing segmentation masks and other relevant information.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam('image.jpg', points=[[500, 375]])
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam("image.jpg", points=[[500, 375]])
|
||||
>>> print(f"Detected {len(results[0].masks)} masks")
|
||||
"""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
|
@ -151,7 +151,7 @@ class SAM(Model):
|
|||
(Tuple): A tuple containing the model's information (string representations of the model).
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> info = sam.info()
|
||||
>>> print(info[0]) # Print summary information
|
||||
"""
|
||||
|
|
@ -167,7 +167,7 @@ class SAM(Model):
|
|||
class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> task_map = sam.task_map
|
||||
>>> print(task_map)
|
||||
{'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
|
||||
|
|
|
|||
|
|
@ -32,8 +32,9 @@ class MaskDecoder(nn.Module):
|
|||
|
||||
Examples:
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
|
||||
>>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
|
||||
... dense_prompt_embeddings, multimask_output=True)
|
||||
>>> masks, iou_pred = decoder(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
|
||||
... )
|
||||
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
|
||||
"""
|
||||
|
||||
|
|
@ -213,7 +214,8 @@ class SAM2MaskDecoder(nn.Module):
|
|||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -345,7 +347,8 @@ class SAM2MaskDecoder(nn.Module):
|
|||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
|
||||
... )
|
||||
"""
|
||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
|
|
|
|||
|
|
@ -417,7 +417,15 @@ class SAM2Model(torch.nn.Module):
|
|||
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
|
||||
>>> mask_inputs = torch.rand(1, 1, 512, 512)
|
||||
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
|
||||
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
|
||||
>>> (
|
||||
... low_res_multimasks,
|
||||
... high_res_multimasks,
|
||||
... ious,
|
||||
... low_res_masks,
|
||||
... high_res_masks,
|
||||
... obj_ptr,
|
||||
... object_score_logits,
|
||||
... ) = results
|
||||
"""
|
||||
B = backbone_features.size(0)
|
||||
device = backbone_features.device
|
||||
|
|
|
|||
|
|
@ -716,7 +716,7 @@ class BasicLayer(nn.Module):
|
|||
|
||||
Examples:
|
||||
>>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
|
||||
>>> x = torch.randn(1, 56*56, 96)
|
||||
>>> x = torch.randn(1, 56 * 56, 96)
|
||||
>>> output = layer(x)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
|
|||
|
||||
Examples:
|
||||
>>> frame_idx = 5
|
||||
>>> cond_frame_outputs = {1: 'a', 3: 'b', 7: 'c', 9: 'd'}
|
||||
>>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
|
||||
>>> max_cond_frame_num = 2
|
||||
>>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
|
||||
>>> print(selected)
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model_path='sam_model.pt')
|
||||
>>> predictor.set_image('image.jpg')
|
||||
>>> predictor.setup_model(model_path="sam_model.pt")
|
||||
>>> predictor.set_image("image.jpg")
|
||||
>>> masks, scores, boxes = predictor.generate()
|
||||
>>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
|
||||
"""
|
||||
|
|
@ -90,8 +90,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor(cfg=DEFAULT_CFG)
|
||||
>>> predictor = Predictor(overrides={'imgsz': 640})
|
||||
>>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
|
||||
>>> predictor = Predictor(overrides={"imgsz": 640})
|
||||
>>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback})
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
|
|
@ -188,8 +188,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model_path='sam_model.pt')
|
||||
>>> predictor.set_image('image.jpg')
|
||||
>>> predictor.setup_model(model_path="sam_model.pt")
|
||||
>>> predictor.set_image("image.jpg")
|
||||
>>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
|
||||
"""
|
||||
# Override prompts if any stored in self.prompts
|
||||
|
|
@ -475,8 +475,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_source('path/to/images')
|
||||
>>> predictor.setup_source('video.mp4')
|
||||
>>> predictor.setup_source("path/to/images")
|
||||
>>> predictor.setup_source("video.mp4")
|
||||
>>> predictor.setup_source(None) # Uses default source if available
|
||||
|
||||
Notes:
|
||||
|
|
@ -504,8 +504,8 @@ class Predictor(BasePredictor):
|
|||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.set_image('path/to/image.jpg')
|
||||
>>> predictor.set_image(cv2.imread('path/to/image.jpg'))
|
||||
>>> predictor.set_image("path/to/image.jpg")
|
||||
>>> predictor.set_image(cv2.imread("path/to/image.jpg"))
|
||||
|
||||
Notes:
|
||||
- This method should be called before performing inference on a new image.
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
|
|||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.yolo.classify import ClassificationPredictor
|
||||
|
||||
args = dict(model='yolov8n-cls.pt', source=ASSETS)
|
||||
args = dict(model="yolov8n-cls.pt", source=ASSETS)
|
||||
predictor = ClassificationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
```python
|
||||
from ultralytics.models.yolo.classify import ClassificationTrainer
|
||||
|
||||
args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
|
||||
args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
|
||||
trainer = ClassificationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class ClassificationValidator(BaseValidator):
|
|||
```python
|
||||
from ultralytics.models.yolo.classify import ClassificationValidator
|
||||
|
||||
args = dict(model='yolov8n-cls.pt', data='imagenet10')
|
||||
args = dict(model="yolov8n-cls.pt", data="imagenet10")
|
||||
validator = ClassificationValidator(args=args)
|
||||
validator()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class DetectionPredictor(BasePredictor):
|
|||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.yolo.detect import DetectionPredictor
|
||||
|
||||
args = dict(model='yolov8n.pt', source=ASSETS)
|
||||
args = dict(model="yolov8n.pt", source=ASSETS)
|
||||
predictor = DetectionPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
```python
|
||||
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||
|
||||
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
|
||||
args = dict(model="yolov8n.pt", data="coco8.yaml", epochs=3)
|
||||
trainer = DetectionTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class DetectionValidator(BaseValidator):
|
|||
```python
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
|
||||
args = dict(model='yolov8n.pt', data='coco8.yaml')
|
||||
args = dict(model="yolov8n.pt", data="coco8.yaml")
|
||||
validator = DetectionValidator(args=args)
|
||||
validator()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class OBBPredictor(DetectionPredictor):
|
|||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.yolo.obb import OBBPredictor
|
||||
|
||||
args = dict(model='yolov8n-obb.pt', source=ASSETS)
|
||||
args = dict(model="yolov8n-obb.pt", source=ASSETS)
|
||||
predictor = OBBPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|||
```python
|
||||
from ultralytics.models.yolo.obb import OBBTrainer
|
||||
|
||||
args = dict(model='yolov8n-obb.pt', data='dota8.yaml', epochs=3)
|
||||
args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
|
||||
trainer = OBBTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@ class OBBValidator(DetectionValidator):
|
|||
```python
|
||||
from ultralytics.models.yolo.obb import OBBValidator
|
||||
|
||||
args = dict(model='yolov8n-obb.pt', data='dota8.yaml')
|
||||
args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
|
||||
validator = OBBValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
validator(model=args["model"])
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class PosePredictor(DetectionPredictor):
|
|||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.yolo.pose import PosePredictor
|
||||
|
||||
args = dict(model='yolov8n-pose.pt', source=ASSETS)
|
||||
args = dict(model="yolov8n-pose.pt", source=ASSETS)
|
||||
predictor = PosePredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|||
```python
|
||||
from ultralytics.models.yolo.pose import PoseTrainer
|
||||
|
||||
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
|
||||
args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
||||
trainer = PoseTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
|
|||
```python
|
||||
from ultralytics.models.yolo.pose import PoseValidator
|
||||
|
||||
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
|
||||
args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
|
||||
validator = PoseValidator(args=args)
|
||||
validator()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.yolo.segment import SegmentationPredictor
|
||||
|
||||
args = dict(model='yolov8n-seg.pt', source=ASSETS)
|
||||
args = dict(model="yolov8n-seg.pt", source=ASSETS)
|
||||
predictor = SegmentationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|||
```python
|
||||
from ultralytics.models.yolo.segment import SegmentationTrainer
|
||||
|
||||
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
|
||||
args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
|
||||
trainer = SegmentationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
```python
|
||||
from ultralytics.models.yolo.segment import SegmentationValidator
|
||||
|
||||
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
|
||||
args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml")
|
||||
validator = SegmentationValidator(args=args)
|
||||
validator()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
|
|||
```python
|
||||
from ultralytics.models.yolo.world import WorldModel
|
||||
|
||||
args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
|
||||
args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
|
||||
trainer = WorldTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue