Ruff format docstring Python code (#15792)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Authored by Glenn Jocher on 2024-08-25 01:08:07 +08:00, committed by GitHub
parent c1882a4327
commit d27664216b
63 changed files with 370 additions and 374 deletions
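
For context: Ruff's formatter leaves Python examples inside docstrings untouched unless docstring code formatting is enabled in its settings. A minimal pyproject.toml sketch of the kind of configuration involved (the repository's actual Ruff config may differ):

```toml
[tool.ruff.format]
# Opt in to formatting Python code examples embedded in docstrings
docstring-code-format = true
# Size embedded examples to the width left over after docstring indentation
docstring-code-line-length = "dynamic"
```

With this enabled, running `ruff format .` rewrites the docstring snippets in the hunks below: single-quoted strings become double-quoted and over-long doctest calls are re-wrapped.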


@@ -16,8 +16,8 @@ class FastSAM(Model):
 ```python
 from ultralytics import FastSAM
-model = FastSAM('last.pt')
-results = model.predict('ultralytics/assets/bus.jpg')
+model = FastSAM("last.pt")
+results = model.predict("ultralytics/assets/bus.jpg")
 ```
 """


@@ -92,8 +92,8 @@ class FastSAMPredictor(SegmentationPredictor):
     if labels.sum() == 0  # all negative points
     else torch.zeros(len(result), dtype=torch.bool, device=self.device)
 )
-for p, l in zip(points, labels):
-    point_idx[torch.nonzero(masks[:, p[1], p[0]], as_tuple=True)[0]] = True if l else False
+for point, label in zip(points, labels):
+    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = True if label else False
 idx |= point_idx
 if texts is not None:
     if isinstance(texts, str):


@@ -6,8 +6,8 @@ Example:
 ```python
 from ultralytics import NAS
-model = NAS('yolo_nas_s')
-results = model.predict('ultralytics/assets/bus.jpg')
+model = NAS("yolo_nas_s")
+results = model.predict("ultralytics/assets/bus.jpg")
 ```
 """

@@ -34,8 +34,8 @@ class NAS(Model):
 ```python
 from ultralytics import NAS
-model = NAS('yolo_nas_s')
-results = model.predict('ultralytics/assets/bus.jpg')
+model = NAS("yolo_nas_s")
+results = model.predict("ultralytics/assets/bus.jpg")
 ```
 Attributes:


@@ -22,7 +22,7 @@ class NASPredictor(BasePredictor):
 ```python
 from ultralytics import NAS
-model = NAS('yolo_nas_s')
+model = NAS("yolo_nas_s")
 predictor = model.predictor
 # Assumes that raw_preds, img, orig_imgs are available
 results = predictor.postprocess(raw_preds, img, orig_imgs)


@@ -24,7 +24,7 @@ class NASValidator(DetectionValidator):
 ```python
 from ultralytics import NAS
-model = NAS('yolo_nas_s')
+model = NAS("yolo_nas_s")
 validator = model.validator
 # Assumes that raw_preds are available
 final_preds = validator.postprocess(raw_preds)


@@ -21,7 +21,7 @@ class RTDETRPredictor(BasePredictor):
 from ultralytics.utils import ASSETS
 from ultralytics.models.rtdetr import RTDETRPredictor
-args = dict(model='rtdetr-l.pt', source=ASSETS)
+args = dict(model="rtdetr-l.pt", source=ASSETS)
 predictor = RTDETRPredictor(overrides=args)
 predictor.predict_cli()
 ```


@@ -25,7 +25,7 @@ class RTDETRTrainer(DetectionTrainer):
 ```python
 from ultralytics.models.rtdetr.train import RTDETRTrainer
-args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
+args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
 trainer = RTDETRTrainer(overrides=args)
 trainer.train()
 ```


@@ -62,7 +62,7 @@ class RTDETRValidator(DetectionValidator):
 ```python
 from ultralytics.models.rtdetr import RTDETRValidator
-args = dict(model='rtdetr-l.pt', data='coco8.yaml')
+args = dict(model="rtdetr-l.pt", data="coco8.yaml")
 validator = RTDETRValidator(args=args)
 validator()
 ```


@@ -41,8 +41,8 @@ class SAM(Model):
 info: Logs information about the SAM model.
 Examples:
->>> sam = SAM('sam_b.pt')
->>> results = sam.predict('image.jpg', points=[[500, 375]])
+>>> sam = SAM("sam_b.pt")
+>>> results = sam.predict("image.jpg", points=[[500, 375]])
 >>> for r in results:
 >>> print(f"Detected {len(r.masks)} masks")
 """

@@ -58,7 +58,7 @@ class SAM(Model):
 NotImplementedError: If the model file extension is not .pt or .pth.
 Examples:
->>> sam = SAM('sam_b.pt')
+>>> sam = SAM("sam_b.pt")
 >>> print(sam.is_sam2)
 """
 if model and Path(model).suffix not in {".pt", ".pth"}:

@@ -78,8 +78,8 @@ class SAM(Model):
 task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
 Examples:
->>> sam = SAM('sam_b.pt')
->>> sam._load('path/to/custom_weights.pt')
+>>> sam = SAM("sam_b.pt")
+>>> sam._load("path/to/custom_weights.pt")
 """
 self.model = build_sam(weights)

@@ -100,8 +100,8 @@ class SAM(Model):
 (List): The model predictions.
 Examples:
->>> sam = SAM('sam_b.pt')
->>> results = sam.predict('image.jpg', points=[[500, 375]])
+>>> sam = SAM("sam_b.pt")
+>>> results = sam.predict("image.jpg", points=[[500, 375]])
 >>> for r in results:
 ... print(f"Detected {len(r.masks)} masks")
 """

@@ -130,8 +130,8 @@ class SAM(Model):
 (List): The model predictions, typically containing segmentation masks and other relevant information.
 Examples:
->>> sam = SAM('sam_b.pt')
->>> results = sam('image.jpg', points=[[500, 375]])
+>>> sam = SAM("sam_b.pt")
+>>> results = sam("image.jpg", points=[[500, 375]])
 >>> print(f"Detected {len(results[0].masks)} masks")
 """
 return self.predict(source, stream, bboxes, points, labels, **kwargs)

@@ -151,7 +151,7 @@ class SAM(Model):
 (Tuple): A tuple containing the model's information (string representations of the model).
 Examples:
->>> sam = SAM('sam_b.pt')
+>>> sam = SAM("sam_b.pt")
 >>> info = sam.info()
 >>> print(info[0])  # Print summary information
 """

@@ -167,7 +167,7 @@ class SAM(Model):
 class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
 Examples:
->>> sam = SAM('sam_b.pt')
+>>> sam = SAM("sam_b.pt")
 >>> task_map = sam.task_map
 >>> print(task_map)
 {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}


@@ -32,8 +32,9 @@ class MaskDecoder(nn.Module):
 Examples:
 >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
->>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
-... dense_prompt_embeddings, multimask_output=True)
+>>> masks, iou_pred = decoder(
+... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
+... )
 >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
 """

@@ -213,7 +214,8 @@ class SAM2MaskDecoder(nn.Module):
 >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
 >>> decoder = SAM2MaskDecoder(256, transformer)
 >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
-... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+... )
 """
 def __init__(

@@ -345,7 +347,8 @@ class SAM2MaskDecoder(nn.Module):
 >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
 >>> decoder = SAM2MaskDecoder(256, transformer)
 >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
-... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+... )
 """
 masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
     image_embeddings=image_embeddings,


@@ -417,7 +417,15 @@ class SAM2Model(torch.nn.Module):
 >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
 >>> mask_inputs = torch.rand(1, 1, 512, 512)
 >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
->>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
+>>> (
+... low_res_multimasks,
+... high_res_multimasks,
+... ious,
+... low_res_masks,
+... high_res_masks,
+... obj_ptr,
+... object_score_logits,
+... ) = results
 """
 B = backbone_features.size(0)
 device = backbone_features.device


@@ -716,7 +716,7 @@ class BasicLayer(nn.Module):
 Examples:
 >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
->>> x = torch.randn(1, 56*56, 96)
+>>> x = torch.randn(1, 56 * 56, 96)
 >>> output = layer(x)
 >>> print(output.shape)
 """


@@ -22,7 +22,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
 Examples:
 >>> frame_idx = 5
->>> cond_frame_outputs = {1: 'a', 3: 'b', 7: 'c', 9: 'd'}
+>>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
 >>> max_cond_frame_num = 2
 >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
 >>> print(selected)


@@ -69,8 +69,8 @@ class Predictor(BasePredictor):
 Examples:
 >>> predictor = Predictor()
->>> predictor.setup_model(model_path='sam_model.pt')
->>> predictor.set_image('image.jpg')
+>>> predictor.setup_model(model_path="sam_model.pt")
+>>> predictor.set_image("image.jpg")
 >>> masks, scores, boxes = predictor.generate()
 >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
 """

@@ -90,8 +90,8 @@ class Predictor(BasePredictor):
 Examples:
 >>> predictor = Predictor(cfg=DEFAULT_CFG)
->>> predictor = Predictor(overrides={'imgsz': 640})
->>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
+>>> predictor = Predictor(overrides={"imgsz": 640})
+>>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback})
 """
 if overrides is None:
     overrides = {}

@@ -188,8 +188,8 @@ class Predictor(BasePredictor):
 Examples:
 >>> predictor = Predictor()
->>> predictor.setup_model(model_path='sam_model.pt')
->>> predictor.set_image('image.jpg')
+>>> predictor.setup_model(model_path="sam_model.pt")
+>>> predictor.set_image("image.jpg")
 >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
 """
 # Override prompts if any stored in self.prompts

@@ -475,8 +475,8 @@ class Predictor(BasePredictor):
 Examples:
 >>> predictor = Predictor()
->>> predictor.setup_source('path/to/images')
->>> predictor.setup_source('video.mp4')
+>>> predictor.setup_source("path/to/images")
+>>> predictor.setup_source("video.mp4")
 >>> predictor.setup_source(None)  # Uses default source if available
 Notes:

@@ -504,8 +504,8 @@ class Predictor(BasePredictor):
 Examples:
 >>> predictor = Predictor()
->>> predictor.set_image('path/to/image.jpg')
->>> predictor.set_image(cv2.imread('path/to/image.jpg'))
+>>> predictor.set_image("path/to/image.jpg")
+>>> predictor.set_image(cv2.imread("path/to/image.jpg"))
 Notes:
 - This method should be called before performing inference on a new image.


@@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
 from ultralytics.utils import ASSETS
 from ultralytics.models.yolo.classify import ClassificationPredictor
-args = dict(model='yolov8n-cls.pt', source=ASSETS)
+args = dict(model="yolov8n-cls.pt", source=ASSETS)
 predictor = ClassificationPredictor(overrides=args)
 predictor.predict_cli()
 ```


@@ -22,7 +22,7 @@ class ClassificationTrainer(BaseTrainer):
 ```python
 from ultralytics.models.yolo.classify import ClassificationTrainer
-args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
+args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
 trainer = ClassificationTrainer(overrides=args)
 trainer.train()
 ```


@@ -20,7 +20,7 @@ class ClassificationValidator(BaseValidator):
 ```python
 from ultralytics.models.yolo.classify import ClassificationValidator
-args = dict(model='yolov8n-cls.pt', data='imagenet10')
+args = dict(model="yolov8n-cls.pt", data="imagenet10")
 validator = ClassificationValidator(args=args)
 validator()
 ```


@@ -14,7 +14,7 @@ class DetectionPredictor(BasePredictor):
 from ultralytics.utils import ASSETS
 from ultralytics.models.yolo.detect import DetectionPredictor
-args = dict(model='yolov8n.pt', source=ASSETS)
+args = dict(model="yolov8n.pt", source=ASSETS)
 predictor = DetectionPredictor(overrides=args)
 predictor.predict_cli()
 ```


@@ -24,7 +24,7 @@ class DetectionTrainer(BaseTrainer):
 ```python
 from ultralytics.models.yolo.detect import DetectionTrainer
-args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
+args = dict(model="yolov8n.pt", data="coco8.yaml", epochs=3)
 trainer = DetectionTrainer(overrides=args)
 trainer.train()
 ```


@@ -22,7 +22,7 @@ class DetectionValidator(BaseValidator):
 ```python
 from ultralytics.models.yolo.detect import DetectionValidator
-args = dict(model='yolov8n.pt', data='coco8.yaml')
+args = dict(model="yolov8n.pt", data="coco8.yaml")
 validator = DetectionValidator(args=args)
 validator()
 ```


@@ -16,7 +16,7 @@ class OBBPredictor(DetectionPredictor):
 from ultralytics.utils import ASSETS
 from ultralytics.models.yolo.obb import OBBPredictor
-args = dict(model='yolov8n-obb.pt', source=ASSETS)
+args = dict(model="yolov8n-obb.pt", source=ASSETS)
 predictor = OBBPredictor(overrides=args)
 predictor.predict_cli()
 ```


@@ -15,7 +15,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
 ```python
 from ultralytics.models.yolo.obb import OBBTrainer
-args = dict(model='yolov8n-obb.pt', data='dota8.yaml', epochs=3)
+args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
 trainer = OBBTrainer(overrides=args)
 trainer.train()
 ```


@@ -18,9 +18,9 @@ class OBBValidator(DetectionValidator):
 ```python
 from ultralytics.models.yolo.obb import OBBValidator
-args = dict(model='yolov8n-obb.pt', data='dota8.yaml')
+args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
 validator = OBBValidator(args=args)
-validator(model=args['model'])
+validator(model=args["model"])
 ```
 """


@@ -14,7 +14,7 @@ class PosePredictor(DetectionPredictor):
 from ultralytics.utils import ASSETS
 from ultralytics.models.yolo.pose import PosePredictor
-args = dict(model='yolov8n-pose.pt', source=ASSETS)
+args = dict(model="yolov8n-pose.pt", source=ASSETS)
 predictor = PosePredictor(overrides=args)
 predictor.predict_cli()
 ```


@@ -16,7 +16,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
 ```python
 from ultralytics.models.yolo.pose import PoseTrainer
-args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
+args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
 trainer = PoseTrainer(overrides=args)
 trainer.train()
 ```


@@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
 ```python
 from ultralytics.models.yolo.pose import PoseValidator
-args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
+args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
 validator = PoseValidator(args=args)
 validator()
 ```


@@ -14,7 +14,7 @@ class SegmentationPredictor(DetectionPredictor):
 from ultralytics.utils import ASSETS
 from ultralytics.models.yolo.segment import SegmentationPredictor
-args = dict(model='yolov8n-seg.pt', source=ASSETS)
+args = dict(model="yolov8n-seg.pt", source=ASSETS)
 predictor = SegmentationPredictor(overrides=args)
 predictor.predict_cli()
 ```


@@ -16,7 +16,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
 ```python
 from ultralytics.models.yolo.segment import SegmentationTrainer
-args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
+args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
 trainer = SegmentationTrainer(overrides=args)
 trainer.train()
 ```


@@ -22,7 +22,7 @@ class SegmentationValidator(DetectionValidator):
 ```python
 from ultralytics.models.yolo.segment import SegmentationValidator
-args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
+args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml")
 validator = SegmentationValidator(args=args)
 validator()
 ```


@@ -29,7 +29,7 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
 ```python
 from ultralytics.models.yolo.world import WorldModel
-args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
+args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
 trainer = WorldTrainer(overrides=args)
 trainer.train()
 ```