Ruff Docstring formatting (#15793)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Glenn Jocher 2024-08-25 04:27:55 +08:00 committed by GitHub
parent d27664216b
commit 776ca86369
60 changed files with 241 additions and 309 deletions

View file

@ -135,9 +135,7 @@ class TQDM(tqdm_original):
class SimpleClass:
"""Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
access methods for easier debugging and usage.
"""
"""A base class providing string representation and attribute access functionality for Ultralytics objects."""
def __str__(self):
"""Return a human-readable string representation of the object."""
@ -164,9 +162,7 @@ class SimpleClass:
class IterableSimpleNamespace(SimpleNamespace):
"""Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
enables usage with dict() and for loops.
"""
"""Iterable SimpleNamespace subclass for key-value attribute iteration and custom error handling."""
def __iter__(self):
"""Return an iterator of key-value pairs from the namespace's attributes."""
@ -209,7 +205,6 @@ def plt_settings(rcparams=None, backend="Agg"):
(Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
"""
if rcparams is None:
rcparams = {"font.size": 11}
@ -240,9 +235,7 @@ def plt_settings(rcparams=None, backend="Agg"):
def set_logging(name="LOGGING_NAME", verbose=True):
"""Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across different
environments.
"""
"""Sets up logging with UTF-8 encoding and configurable verbosity for Ultralytics YOLO."""
level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
# Configure the console (stdout) encoding to UTF-8, with checks for compatibility
@ -702,7 +695,7 @@ SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
def colorstr(*input):
"""
r"""
Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
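As a quick illustration of the ANSI behavior the docstring describes (a usage sketch; style arguments precede the string to be colored):

```python
from ultralytics.utils import colorstr

print(colorstr("blue", "bold", "hello world"))  # -> "\033[34m\033[1mhello world\033[0m"
print(colorstr("red", "underline", "warning"))  # multiple styles compose
```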
@ -946,9 +939,7 @@ class SettingsManager(dict):
"""
def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
file.
"""
"""Initializes the SettingsManager with default settings and loads user settings."""
import copy
import hashlib

View file

@ -16,13 +16,17 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
Args:
model (torch.nn.Module): YOLO model to check batch size for.
imgsz (int): Image size used for training.
amp (bool): If True, use automatic mixed precision (AMP) for training.
imgsz (int, optional): Image size used for training.
amp (bool, optional): Use automatic mixed precision if True.
batch (float, optional): Fraction of GPU memory to use. If -1, use default.
Returns:
(int): Optimal batch size computed using the autobatch() function.
"""
Note:
If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
Otherwise, a default fraction of 0.6 is used.
"""
with autocast(enabled=amp):
return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
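The Note added above encodes a small branch that is easy to misread; restated as standalone illustrative code (names here are hypothetical, mirroring the `fraction=batch if 0.0 < batch < 1.0 else 0.6` expression shown):

```python
def resolve_fraction(batch: float, default: float = 0.60) -> float:
    """Mirror the fraction selection in check_train_batch_size (illustrative only)."""
    return batch if 0.0 < batch < 1.0 else default

assert resolve_fraction(0.8) == 0.8   # fractional batch -> GPU memory fraction
assert resolve_fraction(-1) == 0.60   # -1 (auto) falls back to 60% utilization
assert resolve_fraction(16) == 0.60   # explicit batch sizes also use the default fraction
```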
@ -40,7 +44,6 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
Returns:
(int): The optimal batch size.
"""
# Check device
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")

View file

@ -182,7 +182,6 @@ class RF100Benchmark:
Args:
api_key (str): The API key.
"""
check_requirements("roboflow")
from roboflow import Roboflow
@ -195,7 +194,6 @@ class RF100Benchmark:
Args:
ds_link_txt (str): Path to dataset_links file.
"""
(shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
os.chdir("rf-100")
os.mkdir("ultralytics-benchmarks")
@ -225,7 +223,6 @@ class RF100Benchmark:
Args:
path (str): YAML file path.
"""
with open(path, "r") as file:
yaml_data = yaml.safe_load(file)
yaml_data["train"] = "train/images"
@ -393,9 +390,7 @@ class ProfileModels:
return [Path(file) for file in sorted(files)]
def get_onnx_model_info(self, onnx_file: str):
"""Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
file.
"""
"""Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
@staticmethod
@ -440,9 +435,7 @@ class ProfileModels:
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
"""Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs."""
check_requirements("onnxruntime")
import onnxruntime as ort

View file

@ -192,7 +192,6 @@ def add_integration_callbacks(instance):
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
of callback lists.
"""
# Load HUB callbacks
from .hub import callbacks as hub_cb

View file

@ -114,7 +114,6 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
This function rescales the bounding box labels to the original image shape.
"""
resized_image_height, resized_image_width = resized_image_shape
# Convert normalized xywh format predictions to xyxy in resized scale format

View file

@ -34,7 +34,6 @@ def _log_scalars(scalars, step=0):
def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz

View file

@ -65,7 +65,6 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
parse_requirements(package="ultralytics")
```
"""
if package:
requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
else:
@ -257,7 +256,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
"""
Returns the latest version of a PyPI package without downloading or installing it.
Parameters:
Args:
package_name (str): The name of the package to find the latest version for.
Returns:
@ -362,7 +361,6 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
check_requirements(["numpy", "ultralytics>=8.0.0"])
```
"""
prefix = colorstr("red", "bold", "requirements:")
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
@ -422,7 +420,6 @@ def check_torchvision():
The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
# Compatibility table
compatibility_table = {
"2.3": ["0.18"],
@ -622,9 +619,9 @@ def collect_system_info():
def check_amp(model):
"""
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
be disabled during training.
Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks fail, it means
there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
during training.
Args:
model (nn.Module): A YOLOv8 model instance.

View file

@ -395,7 +395,6 @@ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
```
"""
if version != "latest":
version = f"tags/{version}" # i.e. tags/v6.2
url = f"https://api.github.com/repos/{repo}/releases/{version}"

View file

@ -71,7 +71,6 @@ def spaces_in_path(path):
>>> with spaces_in_path('/path/with spaces') as new_path:
>>> # Your code here
"""
# If path has spaces, replace them with underscores
if " " in str(path):
string = isinstance(path, str) # input type

View file

@ -96,8 +96,11 @@ class Bboxes:
def mul(self, scale):
"""
Multiply bounding box coordinates by scale factor(s).
Args:
scale (tuple | list | int): the scale for four coords.
scale (int | tuple | list): Scale factor(s) for four coordinates.
If int, the same scale is applied to all coordinates.
"""
if isinstance(scale, Number):
scale = to_4tuple(scale)
@ -110,8 +113,11 @@ class Bboxes:
def add(self, offset):
"""
Add offset to bounding box coordinates.
Args:
offset (tuple | list | int): the offset for four coords.
offset (int | tuple | list): Offset(s) for four coordinates.
If int, the same offset is applied to all coordinates.
"""
if isinstance(offset, Number):
offset = to_4tuple(offset)
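To make the int-vs-tuple broadcasting concrete, a small sketch (the `ultralytics.utils.instance` import path and the `.bboxes` attribute are assumptions, not shown in this diff):

```python
import numpy as np
from ultralytics.utils.instance import Bboxes  # assumed module path

boxes = Bboxes(bboxes=np.array([[10.0, 10.0, 50.0, 80.0]]), format="xyxy")
boxes.mul(2)             # int: the same scale applies to all four coordinates
boxes.add((5, 5, 0, 0))  # tuple: one offset per coordinate
print(boxes.bboxes)      # [[ 25.  25. 100. 160.]]
```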
@ -210,10 +216,14 @@ class Instances:
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
"""
Initialize the object with bounding boxes, segments, and keypoints.
Args:
bboxes (ndarray): bboxes with shape [N, 4].
segments (list | ndarray): segments.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
bboxes (np.ndarray): Bounding boxes, shape [N, 4].
segments (list | np.ndarray, optional): Segmentation masks. Defaults to None.
keypoints (np.ndarray, optional): Keypoints, shape [N, 17, 3] and format (x, y, visible). Defaults to None.
bbox_format (str, optional): Format of bboxes. Defaults to "xywh".
normalized (bool, optional): Whether the coordinates are normalized. Defaults to True.
"""
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
self.keypoints = keypoints
@ -230,7 +240,7 @@ class Instances:
return self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False):
"""This might be similar with denormalize func but without normalized sign."""
"""Similar to denormalize func but without normalized sign."""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
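A construction sketch matching the Args listed above (import path assumed; values are arbitrary):

```python
import numpy as np
from ultralytics.utils.instance import Instances  # assumed module path

inst = Instances(
    bboxes=np.array([[0.5, 0.5, 0.2, 0.3]]),  # [N, 4], xywh
    keypoints=np.zeros((1, 17, 3)),           # [N, 17, 3], (x, y, visible)
    bbox_format="xywh",
    normalized=True,
)
inst.scale(640, 640)  # per the docstring above: like denormalize, without flipping the normalized flag
```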

View file

@ -30,7 +30,6 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
Returns:
(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
"""
# Get the coordinates of bounding boxes
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
@ -53,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
"""
Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
@ -63,7 +62,6 @@ def box_iou(box1, box2, eps=1e-7):
Returns:
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
"""
# NOTE: Need .float() to get accurate iou values
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
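A worked example of the pairwise output shape and values (import path assumed):

```python
import torch
from ultralytics.utils.metrics import box_iou  # assumed module path

box1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])  # (N, 4), xyxy
box2 = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                     [5.0, 5.0, 15.0, 15.0]])  # (M, 4), xyxy
print(box_iou(box1, box2))  # (N, M): [[1.0000, 0.1429]] -> 25 / (100 + 100 - 25)
```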
@ -90,7 +88,6 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
Returns:
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
"""
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
@ -195,15 +192,22 @@ def _get_covariance_matrix(boxes):
def probiou(obb1, obb2, CIoU=False, eps=1e-7):
"""
Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
Calculate probabilistic IoU between oriented bounding boxes.
Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
Args:
obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, ) representing obb similarities.
(torch.Tensor): OBB similarities, shape (N,).
Note:
OBB format: [center_x, center_y, width, height, rotation_angle].
If CIoU is True, returns CIoU instead of IoU.
"""
x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = obb2[..., :2].split(1, dim=-1)
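A quick sanity check of the xywhr contract described above (import path assumed):

```python
import torch
from ultralytics.utils.metrics import probiou  # assumed module path

# Identical oriented boxes: center (50, 50), 20x10, rotated 0.3 rad
obb1 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3]])
obb2 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3]])
print(probiou(obb1, obb2))  # shape (N,); ~1.0 for identical boxes
```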
@ -507,7 +511,6 @@ def compute_ap(recall, precision):
(np.ndarray): Precision envelope curve.
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
"""
# Append sentinel values to beginning and end
mrec = np.concatenate(([0.0], recall, [1.0]))
mpre = np.concatenate(([1.0], precision, [0.0]))
@ -560,7 +563,6 @@ def ap_per_class(
x (np.ndarray): X-axis values for the curves. Shape: (1000,).
prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
"""
# Sort by objectness
i = np.argsort(-conf)
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
@ -792,8 +794,8 @@ class Metric(SimpleClass):
class DetMetrics(SimpleClass):
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
object detection model.
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
@ -942,7 +944,6 @@ class SegmentMetrics(SimpleClass):
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
"""
results_mask = ap_per_class(
tp_m,
conf,
@ -1084,7 +1085,6 @@ class PoseMetrics(SegmentMetrics):
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
"""
results_pose = ap_per_class(
tp_p,
conf,

View file

@ -141,14 +141,15 @@ def make_divisible(x, divisor):
def nms_rotated(boxes, scores, threshold=0.45):
"""
NMS for obbs, powered by probiou and fast-nms.
NMS for oriented bounding boxes using probiou and fast-nms.
Args:
boxes (torch.Tensor): (N, 5), xywhr.
scores (torch.Tensor): (N, ).
threshold (float): IoU threshold.
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
scores (torch.Tensor): Confidence scores, shape (N,).
threshold (float, optional): IoU threshold. Defaults to 0.45.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
"""
if len(boxes) == 0:
return np.empty((0,), dtype=np.int8)
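A usage sketch consistent with the updated Args (import path assumed):

```python
import torch
from ultralytics.utils.ops import nms_rotated  # assumed module path

boxes = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0],    # xywhr
                      [51.0, 50.0, 20.0, 10.0, 0.0],    # near-duplicate of box 0
                      [200.0, 200.0, 30.0, 15.0, 0.5]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms_rotated(boxes, scores, threshold=0.45)
print(keep)  # indices of survivors; the lower-scoring duplicate is suppressed
```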
@ -597,7 +598,7 @@ def ltwh2xyxy(x):
def segments2boxes(segments):
"""
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
Args:
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
@ -667,7 +668,6 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
(torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
are the height and width of the input image. The mask is applied to the bounding boxes.
"""
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW
@ -785,7 +785,7 @@ def regularize_rboxes(rboxes):
def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
It takes a list of masks(n,h,w) and returns a list of segments(n,xy).
Args:
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
@ -823,7 +823,7 @@ def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
def clean_str(s):
"""
Cleans a string by replacing special characters with underscore _
Cleans a string by replacing special characters with '_' character.
Args:
s (str): a string needing special characters replaced

View file

@ -204,7 +204,6 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# If label have more than 3 characters, skip other characters, due to circle size
if len(label) > 3:
print(
@ -246,7 +245,6 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# Calculate the center of the bounding box
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the size of the text
@ -284,7 +282,6 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
rotated (bool, optional): Variable used to check if task is OBB
"""
txt_color = self.get_txt_color(color, txt_color)
if isinstance(box, torch.Tensor):
box = box.tolist()
@ -343,7 +340,6 @@ class Annotator:
alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
"""
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
@ -374,17 +370,18 @@ class Annotator:
Plot keypoints on the image.
Args:
kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
radius (int, optional): Radius of the drawn keypoints. Default is 5.
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
for human pose. Default is True.
kpt_color (tuple, optional): The color of the keypoints (B, G, R).
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 5.
kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
Note:
`kpt_line=True` currently only supports human pose plotting.
- `kpt_line=True` currently only supports human pose plotting.
- Modifies self.im in-place.
- If self.pil is True, converts image to numpy array and back to PIL.
"""
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
@ -488,7 +485,6 @@ class Annotator:
Returns:
angle (degree): Degree value of angle between three points
"""
x_min, y_min, x_max, y_max = bbox
width = x_max - x_min
height = y_max - y_min
@ -503,7 +499,6 @@ class Annotator:
color (tuple): Region Color value
thickness (int): Region area thickness value
"""
cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
@ -515,7 +510,6 @@ class Annotator:
color (tuple): tracks line color
track_thickness (int): track line thickness value
"""
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
@ -530,7 +524,6 @@ class Annotator:
region_color (RGB): queue region color
txt_color (RGB): text display color
"""
x_values = [point[0] for point in points]
y_values = [point[1] for point in points]
center_x = sum(x_values) // len(points)
@ -574,7 +567,6 @@ class Annotator:
y_center (float): y position center point for bounding box
margin (int): gap between text and rectangle for better display
"""
text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
text_x = x_center - text_size[0] // 2
text_y = y_center + text_size[1] // 2
@ -597,7 +589,6 @@ class Annotator:
bg_color (bgr color): display color for text background
margin (int): gap between text and rectangle for better display
"""
horizontal_gap = int(im0.shape[1] * 0.02)
vertical_gap = int(im0.shape[0] * 0.01)
text_y_offset = 0
@ -629,7 +620,6 @@ class Annotator:
Returns:
angle (degree): Degree value of angle between three points
"""
a, b, c = np.array(a), np.array(b), np.array(c)
radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
angle = np.abs(radians * 180.0 / np.pi)
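The arctan2 computation shown here is self-contained enough to restate as a worked example; the final reflex-angle clamp is an assumption, since the hunk is truncated:

```python
import numpy as np

def angle_at_vertex(a, b, c):
    """Angle ABC in degrees, mirroring the computation above."""
    a, b, c = np.array(a), np.array(b), np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    return 360.0 - angle if angle > 180.0 else angle  # assumed clamp to [0, 180]

print(angle_at_vertex((0, 0), (1, 0), (1, 1)))  # 90.0
```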
@ -642,12 +632,19 @@ class Annotator:
Draw specific keypoints for gym steps counting.
Args:
keypoints (list): list of keypoints data to be plotted
indices (list): keypoints ids list to be plotted
shape (tuple): imgsz for model inference
radius (int): Keypoint radius value
"""
keypoints (list): Keypoints data to be plotted.
indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
shape (tuple, optional): Image size for model inference. Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 2.
conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
Returns:
(numpy.ndarray): Image with drawn keypoints.
Note:
Keypoint format: [x, y] or [x, y, confidence].
Modifies self.im in-place.
"""
if indices is None:
indices = [2, 5, 7]
for i, k in enumerate(keypoints):
@ -675,7 +672,6 @@ class Annotator:
color (tuple): text background color for workout monitoring
txt_color (tuple): text foreground color for workout monitoring
"""
angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")
# Draw angle
@ -744,7 +740,6 @@ class Annotator:
label (str): Detection label text
txt_color (RGB): text color
"""
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
@ -772,7 +767,6 @@ class Annotator:
line_color (RGB): Distance line color.
centroid_color (RGB): Bounding box centroid color.
"""
(text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, self.sf, self.tf)
cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), line_color, -1)
cv2.putText(
@ -813,7 +807,6 @@ class Annotator:
color (tuple): object centroid and line color value
pin_color (tuple): visioneye point color value
"""
center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
@ -906,7 +899,6 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
```
"""
if not isinstance(xyxy, torch.Tensor): # may be list
xyxy = torch.stack(xyxy)
b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes
@ -1171,7 +1163,6 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
>>> f = np.random.rand(100)
>>> plt_color_scatter(v, f)
"""
# Calculate 2D histogram and corresponding colors
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
colors = [
@ -1197,7 +1188,6 @@ def plot_tune_results(csv_file="tune_results.csv"):
Examples:
>>> plot_tune_results("path/to/tune_results.csv")
"""
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d

View file

@ -140,7 +140,6 @@ class TaskAlignedAssigner(nn.Module):
Returns:
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
# (b, max_num_obj, topk)
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
if topk_mask is None:
@ -184,7 +183,6 @@ class TaskAlignedAssigner(nn.Module):
for positive anchor points, where num_classes is the number
of object classes.
"""
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
@ -212,14 +210,19 @@ class TaskAlignedAssigner(nn.Module):
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
"""
Select the positive anchor center in gt.
Select positive anchor centers within ground truth bounding boxes.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 4)
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
Returns:
(Tensor): shape(b, n_boxes, h*w)
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
Note:
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
Bounding box format: [x_min, y_min, x_max, y_max].
"""
n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape
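To ground the shapes in the rewritten Args, a tiny example with one ground truth box and two anchors (import path assumed):

```python
import torch
from ultralytics.utils.tal import TaskAlignedAssigner  # assumed module path

xy_centers = torch.tensor([[10.0, 10.0], [40.0, 40.0]])  # (h*w, 2)
gt_bboxes = torch.tensor([[[0.0, 0.0, 20.0, 20.0]]])     # (b=1, n_boxes=1, 4), xyxy
mask = TaskAlignedAssigner.select_candidates_in_gts(xy_centers, gt_bboxes)
print(mask.shape)  # (1, 1, 2): anchor (10, 10) is inside the box, (40, 40) is not
```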
@ -231,18 +234,22 @@ class TaskAlignedAssigner(nn.Module):
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
"""
If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
Select anchor boxes with highest IoU when assigned to multiple ground truths.
Args:
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
overlaps (Tensor): shape(b, n_max_boxes, h*w)
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
n_max_boxes (int): Maximum number of ground truth boxes.
Returns:
target_gt_idx (Tensor): shape(b, h*w)
fg_mask (Tensor): shape(b, h*w)
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
Note:
b: batch size, h: height, w: width.
"""
# (b, n_max_boxes, h*w) -> (b, h*w)
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
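A worked example of the de-duplication behavior (import path assumed; anchor 1 is claimed by both ground truths and resolves to the higher IoU):

```python
import torch
from ultralytics.utils.tal import TaskAlignedAssigner  # assumed module path

mask_pos = torch.tensor([[[1.0, 1.0, 0.0],
                          [0.0, 1.0, 1.0]]])  # (b, n_max_boxes, h*w)
overlaps = torch.tensor([[[0.9, 0.4, 0.0],
                          [0.0, 0.7, 0.8]]])  # IoU of each GT with each anchor
target_gt_idx, fg_mask, mask_pos = TaskAlignedAssigner.select_highest_overlaps(
    mask_pos, overlaps, n_max_boxes=2
)
print(target_gt_idx)  # tensor([[0, 1, 1]]): anchor 1 goes to GT 1 (0.7 > 0.4)
```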
@ -328,14 +335,16 @@ def bbox2dist(anchor_points, bbox, reg_max):
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
"""
Decode predicted object bounding box coordinates from anchor points and distribution.
Decode predicted rotated bounding box coordinates from anchor points and distribution.
Args:
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
dim (int, optional): Dimension along which to split. Defaults to -1.
Returns:
(torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
(torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
"""
lt, rb = pred_dist.split(2, dim=dim)
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
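A decode sketch with zero rotation, where the output reduces to center = anchor point and wh = lt + rb (import path assumed):

```python
import torch
from ultralytics.utils.tal import dist2rbox  # assumed module path

anchor_points = torch.tensor([[8.0, 8.0]])          # (h*w, 2)
pred_dist = torch.tensor([[[2.0, 2.0, 2.0, 2.0]]])  # (bs, h*w, 4): lt, rb distances
pred_angle = torch.zeros((1, 1, 1))                 # (bs, h*w, 1): no rotation
print(dist2rbox(pred_dist, pred_angle, anchor_points))  # [[[8., 8., 4., 4.]]]
```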

View file

@ -146,7 +146,6 @@ def select_device(device="", batch=0, newline=False, verbose=True):
Note:
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
"""
if isinstance(device, torch.device):
return device
@ -417,9 +416,7 @@ def initialize_weights(model):
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
"""Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
retaining the original shape.
"""
"""Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
if ratio == 1.0:
return img
h, w = img.shape[2:]
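For the rewritten one-line docstring, a usage sketch (import path assumed):

```python
import torch
from ultralytics.utils.torch_utils import scale_img  # assumed module path

img = torch.rand(1, 3, 640, 640)  # (bs, 3, h, w)
out = scale_img(img, ratio=0.5)   # rescale, then pad up to a gs=32 multiple
print(out.shape)                  # torch.Size([1, 3, 320, 320])
```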
@ -493,7 +490,7 @@ def init_seeds(seed=0, deterministic=False):
class ModelEMA:
"""
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
average of everything in the model state_dict (parameters and buffers)
average of everything in the model state_dict (parameters and buffers).
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

View file

@ -34,7 +34,6 @@ def run_ray_tune(
result_grid = model.tune(data="coco8.yaml", use_ray=True)
```
"""
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}