Ruff Docstring formatting (#15793)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>

parent d27664216b · commit 776ca86369

60 changed files with 241 additions and 309 deletions
.github/workflows/docs.yml (vendored, 18 changes)
@@ -1,5 +1,16 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 # Test and publish docs to https://docs.ultralytics.com
+# Ignores the following Docs rules to match Google-style docstrings:
+# D100: Missing docstring in public module
+# D104: Missing docstring in public package
+# D203: 1 blank line required before class docstring
+# D205: 1 blank line required between summary line and description
+# D212: Multi-line docstring summary should start at the first line
+# D213: Multi-line docstring summary should start at the second line
+# D401: First line of docstring should be in imperative mood
+# D406: Section name should end with a newline
+# D407: Missing dashed underline after section
+# D413: Missing blank line after last section

 name: Publish Docs

@@ -32,20 +43,23 @@ jobs:
           python-version: "3.x"
           cache: "pip" # caching pip dependencies
       - name: Install Dependencies
-        run: pip install black tqdm mkdocs-material "mkdocstrings[python]" mkdocs-jupyter mkdocs-redirects mkdocs-ultralytics-plugin mkdocs-macros-plugin
+        run: pip install ruff black tqdm mkdocs-material "mkdocstrings[python]" mkdocs-jupyter mkdocs-redirects mkdocs-ultralytics-plugin mkdocs-macros-plugin
       - name: Update Docs Reference Section and Push Changes
         if: github.event_name == 'pull_request_target'
         run: |
           python docs/build_reference.py
+          ruff check --fix --fix-unsafe --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 . || true
           git pull origin ${{ github.head_ref || github.ref }}
           git add .
           git reset HEAD -- .github/workflows/  # workflow changes are not permitted with default token
           if ! git diff --staged --quiet; then
-            git commit -m "Auto-update Ultralytics Docs Reference Section by https://ultralytics.com/actions"
+            git commit -m "Auto-update Ultralytics Docs Reference by https://ultralytics.com/actions"
             git push
           else
             echo "No changes to commit"
           fi
+      - name: Ruff checks
+        run: ruff check --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 .
       - name: Build Docs and Check for Warnings
         run: |
           export JUPYTER_PLATFORM_DIRS=1
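For context, `--select D` enables the pydocstyle rules and the `--ignore` list drops the checks that conflict with the Google docstring convention used across the repo; the same `ruff check` command shown above can be run locally before pushing. A minimal sketch of a docstring that satisfies the remaining D rules (the function and its contents are illustrative, not part of this commit):

def scale_boxes(boxes, gain):
    """
    Scale bounding boxes by a gain factor.

    Args:
        boxes (list[list[float]]): Boxes in xyxy format.
        gain (float): Multiplicative scale factor.

    Returns:
        (list[list[float]]): Scaled boxes in xyxy format.
    """
    return [[coord * gain for coord in box] for box in boxes]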
@@ -1,8 +1,8 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-This Python script is designed to automate the building and post-processing of MkDocs documentation, particularly for
-projects with multilingual content. It streamlines the workflow for generating localized versions of the documentation
-and updating HTML links to ensure they are correctly formatted.
+Automates the building and post-processing of MkDocs documentation, particularly for projects with multilingual content.
+It streamlines the workflow for generating localized versions of the documentation and updating HTML links to ensure
+they are correctly formatted.

 Key Features:
 - Automated building of MkDocs documentation: The script compiles both the main documentation and

@@ -64,7 +64,6 @@ def prepare_docs_markdown(clone_repos=True):

def update_page_title(file_path: Path, new_title: str):
    """Update the title of an HTML file."""
-
    # Read the content of the file
    with open(file_path, encoding="utf-8") as file:
        content = file.read()

@@ -153,7 +152,6 @@ def update_markdown_files(md_filepath: Path):

def update_docs_html():
    """Updates titles, edit links, head sections, and converts plaintext links in HTML documentation."""
-
    # Update 404 titles
    update_page_title(SITE / "404.html", new_title="Ultralytics Docs - Not Found")
@@ -203,7 +203,6 @@ class HuggingFaceVideoClassifier:
        Returns:
            torch.Tensor: The model's output.
        """
-
        input_ids = self.processor(text=self.labels, return_tensors="pt", padding=True)["input_ids"].to(self.device)

        inputs = {"pixel_values": sequences, "input_ids": input_ids}

@@ -48,7 +48,6 @@ class YOLOv8:
        Returns:
            None
        """
-
        # Extract the coordinates of the bounding box
        x1, y1, w, h = box

@@ -118,7 +117,6 @@ class YOLOv8:
        Returns:
            numpy.ndarray: The input image with detections drawn on it.
        """
-
        # Transpose and squeeze the output to match the expected shape
        outputs = np.transpose(np.squeeze(output[0]))

@@ -30,7 +30,6 @@ class LetterBox:

    def __call__(self, labels=None, image=None):
        """Return updated labels and image with added border."""
-
        if labels is None:
            labels = {}
        img = labels.get("img") if image is None else image

@@ -79,7 +78,6 @@ class LetterBox:

    def _update_labels(self, labels, ratio, padw, padh):
        """Update labels."""
-
        labels["instances"].convert_bbox(format="xyxy")
        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
        labels["instances"].scale(*ratio)

@@ -100,7 +98,6 @@ class Yolov8TFLite:
            confidence_thres: Confidence threshold for filtering detections.
            iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression.
        """
-
        self.tflite_model = tflite_model
        self.input_image = input_image
        self.confidence_thres = confidence_thres

@@ -125,7 +122,6 @@ class Yolov8TFLite:
        Returns:
            None
        """
-
        # Extract the coordinates of the bounding box
        x1, y1, w, h = box

@@ -164,7 +160,6 @@ class Yolov8TFLite:
        Returns:
            image_data: Preprocessed image data ready for inference.
        """
-
        # Read the input image using OpenCV
        self.img = cv2.imread(self.input_image)

@@ -193,7 +188,6 @@ class Yolov8TFLite:
        Returns:
            numpy.ndarray: The input image with detections drawn on it.
        """
-
        boxes = []
        scores = []
        class_ids = []

@@ -238,7 +232,6 @@ class Yolov8TFLite:
        Returns:
            output_img: The output image with drawn detections.
        """
-
        # Create an interpreter for the TFLite model
        interpreter = tflite.Interpreter(model_path=self.tflite_model)
        self.model = interpreter
@@ -40,7 +40,7 @@ def mouse_callback(event, x, y, flags, param):
    """
    Handles mouse events for region manipulation.

-    Parameters:
+    Args:
        event (int): The mouse event type (e.g., cv2.EVENT_LBUTTONDOWN).
        x (int): The x-coordinate of the mouse pointer.
        y (int): The y-coordinate of the mouse pointer.
@@ -21,7 +21,6 @@ class YOLOv8Seg:
        Args:
            onnx_model (str): Path to the ONNX model.
        """
-
        # Build Ort session
        self.session = ort.InferenceSession(
            onnx_model,

@@ -57,7 +56,6 @@ class YOLOv8Seg:
            segments (List): list of segments.
            masks (np.ndarray): [N, H, W], output masks.
        """
-
        # Pre-process
        im, ratio, (pad_w, pad_h) = self.preprocess(im0)

@@ -90,7 +88,6 @@ class YOLOv8Seg:
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
        """
-
        # Resize and pad input image using letterbox() (Borrowed from Ultralytics)
        shape = img.shape[:2]  # original image shape
        new_shape = (self.model_height, self.model_width)

@@ -130,7 +127,7 @@ class YOLOv8Seg:
        """
        x, protos = preds[0], preds[1]  # Two outputs: predictions and protos

-        # Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
+        # Transpose dim 1: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
        x = np.einsum("bcn->bnc", x)

        # Predictions filtering by conf-threshold

@@ -169,8 +166,8 @@ class YOLOv8Seg:
    @staticmethod
    def masks2segments(masks):
        """
-        It takes a list of masks(n,h,w) and returns a list of segments(n,xy) (Borrowed from
-        https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L750)
+        Takes a list of masks(n,h,w) and returns a list of segments(n,xy), from
+        https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.

        Args:
            masks (numpy.ndarray): the output of the model, which is a tensor of shape (batch_size, 160, 160).

@@ -191,8 +188,8 @@ class YOLOv8Seg:
    @staticmethod
    def crop_mask(masks, boxes):
        """
-        It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box. (Borrowed from
-        https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L599)
+        Takes a mask and a bounding box, and returns a mask that is cropped to the bounding box, from
+        https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.

        Args:
            masks (Numpy.ndarray): [n, h, w] tensor of masks.

@@ -209,8 +206,8 @@ class YOLOv8Seg:

    def process_mask(self, protos, masks_in, bboxes, im0_shape):
        """
-        Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
-        but is slower. (Borrowed from https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L618)
+        Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
+        quality but is slower, from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.

        Args:
            protos (numpy.ndarray): [mask_dim, mask_h, mask_w].

@@ -232,8 +229,8 @@ class YOLOv8Seg:
    @staticmethod
    def scale_mask(masks, im0_shape, ratio_pad=None):
        """
-        Takes a mask, and resizes it to the original image size. (Borrowed from
-        https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L305)
+        Takes a mask, and resizes it to the original image size, from
+        https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.

        Args:
            masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].

@@ -277,7 +274,6 @@ class YOLOv8Seg:
        Returns:
            None
        """
-
        # Draw rectangles and polygons
        im_canvas = im.copy()
        for (*box, conf, cls_), segment in zip(bboxes, segments):
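The crop_mask docstring above describes clipping each mask to its bounding box. A standalone NumPy sketch of that behavior — an illustration only, not the repository's implementation in ultralytics/utils/ops.py:

import numpy as np


def crop_mask_sketch(masks, boxes):
    """Zero out mask pixels outside each box; masks is (n, h, w), boxes is (n, 4) in xyxy pixel coordinates."""
    n, h, w = masks.shape
    x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, axis=1)  # each piece has shape (n, 1, 1)
    cols = np.arange(w)[None, None, :]  # (1, 1, w) column indices
    rows = np.arange(h)[None, :, None]  # (1, h, 1) row indices
    return masks * ((cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2))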
@@ -174,15 +174,12 @@ line-length = 120
 [tool.ruff.format]
 docstring-code-format = true

-[tool.ruff.lint.pydocstyle]
-convention = "google"
-
 [tool.docformatter]
 wrap-summaries = 120
 wrap-descriptions = 120
-in-place = true
 pre-summary-newline = true
 close-quotes-on-newline = true
+in-place = true

 [tool.codespell]
 ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall"
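Note that `docstring-code-format = true` keeps the Ruff formatter responsible for code examples embedded in docstrings, which pairs with the Google-style checks enforced in the docs workflow above. A hedged illustration of the docstring shape this setup expects, including a doctest snippet that the formatter would also normalize (hypothetical helper, not taken from this diff):

def xywh2xyxy_single(box):
    """
    Convert one box from xywh to xyxy format.

    Example:
        >>> xywh2xyxy_single([10, 10, 4, 6])
        [8.0, 7.0, 12.0, 13.0]
    """
    x, y, w, h = box
    return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]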
@@ -13,7 +13,6 @@ WORKOUTS_SOLUTION_DEMO = "https://github.com/ultralytics/assets/releases/downloa
@pytest.mark.slow
def test_major_solutions():
    """Test the object counting, heatmap, speed estimation and queue management solution."""
-
    safe_download(url=MAJOR_SOLUTIONS_DEMO)
    model = YOLO("yolov8n.pt")
    names = model.names

@@ -41,7 +40,6 @@ def test_major_solutions():
@pytest.mark.slow
def test_aigym():
    """Test the workouts monitoring solution."""
-
    safe_download(url=WORKOUTS_SOLUTION_DEMO)
    model = YOLO("yolov8n-pose.pt")
    cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")

@@ -60,7 +58,6 @@ def test_aigym():
@pytest.mark.slow
def test_instance_segmentation():
    """Test the instance segmentation solution."""
-
    from ultralytics.utils.plotting import Annotator, colors

    model = YOLO("yolov8n-seg.pt")

@@ -86,5 +83,4 @@ def test_instance_segmentation():
@pytest.mark.slow
def test_streamlit_predict():
    """Test streamlit predict live inference solution."""
-
    solutions.inference()
@@ -350,7 +350,6 @@ def get_save_dir(args, name=None):
        >>> print(save_dir)
        my_project/detect/train
    """
-
    if getattr(args, "save_dir", None):
        save_dir = args.save_dir
    else:

@@ -381,7 +380,6 @@ def _handle_deprecation(custom):
    equivalents. It also handles value conversions where necessary, such as inverting boolean values for
    'hide_labels' and 'hide_conf'.
    """
-
    for key in custom.copy().keys():
        if key == "boxes":
            deprecation_warn(key, "show_boxes")

@@ -548,9 +546,9 @@ def handle_yolo_settings(args: List[str]) -> None:

def handle_explorer(args: List[str]):
    """
-    This function launches a graphical user interface that provides tools for interacting with and analyzing datasets
-    using the Ultralytics Explorer API. It checks for the required 'streamlit' package and informs the user that the
-    Explorer dashboard is loading.
+    Launches a graphical user interface that provides tools for interacting with and analyzing datasets using the
+    Ultralytics Explorer API. It checks for the required 'streamlit' package and informs the user that the Explorer
+    dashboard is loading.

    Args:
        args (List[str]): A list of optional command line arguments.
@@ -1005,7 +1005,6 @@ class RandomPerspective:
            >>> transform = RandomPerspective(degrees=10.0, translate=0.1, scale=0.5, shear=5.0)
            >>> result = transform(labels)  # Apply random perspective to labels
        """
-
        self.degrees = degrees
        self.translate = translate
        self.scale = scale

@@ -1038,7 +1037,6 @@ class RandomPerspective:
            >>> border = (10, 10)
            >>> transformed_img, matrix, scale = affine_transform(img, border)
        """
-
        # Center
        C = np.eye(3, dtype=np.float32)
@@ -115,7 +115,7 @@ def coco91_to_coco80_class():


def coco80_to_coco91_class():
-    """
+    r"""
    Converts 80-index (val2014) to 91-index (paper).
    For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.

@@ -243,7 +243,6 @@ def convert_coco(
    Output:
        Generates output files in the specified output directory.
    """
-
    # Create dataset directory
    save_dir = increment_path(save_dir)  # increment if save directory already exists
    for p in save_dir / "labels", save_dir / "images":
@@ -226,6 +226,7 @@ class Explorer:
    def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
        """
        Plot the results of a SQL-Like query on the table.
+
        Args:
            query (str): SQL query to run.
            labels (bool): Whether to plot the labels or not.

@@ -457,20 +458,3 @@ class Explorer:
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None
-
-    def visualize(self, result):
-        """
-        Visualize the results of a query. TODO.
-
-        Args:
-            result (pyarrow.Table): Table containing the results of a query.
-        """
-        pass
-
-    def generate_report(self, result):
-        """
-        Generate a report of the dataset.
-
-        TODO
-        """
-        pass
@@ -240,7 +240,7 @@ class LoadScreenshots:
        return self

    def __next__(self):
-        """mss screen capture: get raw pixels from the screen as np array."""
+        """Screen capture with 'mss' to get raw pixels from the screen as np array."""
        im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
@@ -19,11 +19,19 @@ from shapely.geometry import Polygon

def bbox_iof(polygon1, bbox2, eps=1e-6):
    """
-    Calculate iofs between bbox1 and bbox2.
+    Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.

    Args:
-        polygon1 (np.ndarray): Polygon coordinates, (n, 8).
-        bbox2 (np.ndarray): Bounding boxes, (n ,4).
+        polygon1 (np.ndarray): Polygon coordinates, shape (n, 8).
+        bbox2 (np.ndarray): Bounding boxes, shape (n, 4).
+        eps (float, optional): Small value to prevent division by zero. Defaults to 1e-6.
+
+    Returns:
+        (np.ndarray): IoF scores, shape (n, 1) or (n, m) if bbox2 is (m, 4).
+
+    Note:
+        Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
+        Bounding box format: [x_min, y_min, x_max, y_max].
    """
    polygon1 = polygon1.reshape(-1, 4, 2)
    lt_point = np.min(polygon1, axis=-2)  # left-top
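As a rough sketch of the IoF definition documented above (intersection area divided by the polygon's own area), using the same shapely import already present in this module — illustrative only, not the vectorized implementation that follows in the file:

import numpy as np
from shapely.geometry import Polygon, box


def iof_single(polygon_xy8, bbox_xyxy, eps=1e-6):
    """Compute IoF for one polygon [x1, y1, ..., x4, y4] against one box [x_min, y_min, x_max, y_max]."""
    poly = Polygon(np.asarray(polygon_xy8, dtype=float).reshape(4, 2))
    inter = poly.intersection(box(*bbox_xyxy)).area
    return inter / (poly.area + eps)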
@@ -265,7 +265,6 @@ def check_det_dataset(dataset, autodownload=True):
    Returns:
        (dict): Parsed dataset information and paths.
    """
-
    file = check_file(dataset)

    # Download (optional)

@@ -363,7 +362,6 @@ def check_cls_dataset(dataset, split=""):
        - 'nc' (int): The number of classes in the dataset.
        - 'names' (dict): A dictionary of class names in the dataset.
    """
-
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)

@@ -602,7 +600,6 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
        compress_one_image(f)
        ```
    """
-
    try:  # use PIL
        im = Image.open(f)
        r = max_dim / max(im.height, im.width)  # ratio

@@ -635,7 +632,6 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
        autosplit()
        ```
    """
-
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
+Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit.

 Format | `format=argument` | Model
 --- | --- | ---

@@ -533,9 +533,7 @@ class Exporter:

    @try_export
    def export_ncnn(self, prefix=colorstr("NCNN:")):
-        """
-        YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx.
-        """
+        """YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx."""
        check_requirements("ncnn")
        import ncnn  # noqa
@@ -384,7 +384,7 @@ class BasePredictor:
            cv2.imwrite(save_path, im)

    def show(self, p=""):
-        """Display an image in a window using OpenCV imshow()."""
+        """Display an image in a window using the OpenCV imshow function."""
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
@@ -228,7 +228,6 @@ class BaseTrainer:

    def _setup_train(self, world_size):
        """Builds dataloaders and optimizer on correct rank process."""
-
        # Model
        self.run_callbacks("on_pretrain_routine_start")
        ckpt = self.setup_model()

@@ -638,7 +637,7 @@ class BaseTrainer:
        pass

    def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
        path = Path(name)
        self.plots[path] = {"data": data, "timestamp": time.time()}

@@ -737,7 +736,6 @@ class BaseTrainer:
        Returns:
            (torch.optim.Optimizer): The constructed optimizer.
        """
-
        g = [], [], []  # optimizer parameter groups
        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
        if name == "auto":
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
-instance segmentation, image classification, pose estimation, and multi-object tracking.
+Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
+segmentation, image classification, pose estimation, and multi-object tracking.

 Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
 that yield the best model performance. This is particularly crucial in deep learning models like YOLO,

@@ -176,7 +176,6 @@ class Tuner:
        The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
        Ensure this path is set correctly in the Tuner instance.
        """
-
        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
@@ -104,9 +104,7 @@ class BaseValidator:

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
-        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
-        gets priority).
-        """
+        """Executes validation process, running inference on dataloader and computing performance metrics."""
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:

@@ -280,7 +278,7 @@ class BaseValidator:
        return batch

    def postprocess(self, preds):
-        """Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
+        """Preprocesses the predictions."""
        return preds

    def init_metrics(self, model):

@@ -317,7 +315,7 @@ class BaseValidator:
        return []

    def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
@@ -27,10 +27,14 @@ class Auth:

    def __init__(self, api_key="", verbose=False):
        """
-        Initialize the Auth class with an optional API key.
+        Initialize Auth class and authenticate user.
+
+        Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
+        authentication.

        Args:
-            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
+            api_key (str): API key or combined key_id format.
+            verbose (bool): Enable verbose logging.
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
        api_key = api_key.split("_")[0]
@@ -159,7 +159,6 @@ class HUBTrainingSession:
        Raises:
            HUBModelError: If the identifier format is not recognized.
        """
-
        # Initialize variables
        api_key, model_id, filename = None, None, None

@@ -200,7 +199,6 @@ class HUBTrainingSession:
            ValueError: If the model is already trained, if required dataset information is missing, or if there are
                issues with the provided training arguments.
        """
-
        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
@@ -30,18 +30,21 @@ class FastSAM(Model):

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
        """
-        Performs segmentation prediction on the given image or video source.
+        Perform segmentation prediction on image or video source.
+
+        Supports prompted segmentation with bounding boxes, points, labels, and texts.

        Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
-            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
-            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
-            points (list, optional): List of points for prompted segmentation. Defaults to None.
-            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
-            texts (list, optional): List of texts for prompted segmentation. Defaults to None.
+            source (str | PIL.Image | numpy.ndarray): Input source.
+            stream (bool): Enable real-time streaming.
+            bboxes (list): Bounding box coordinates for prompted segmentation.
+            points (list): Points for prompted segmentation.
+            labels (list): Labels for prompted segmentation.
+            texts (list): Texts for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments.

        Returns:
-            (list): The model predictions.
+            (list): Model predictions.
        """
        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
        return super().predict(source, stream, prompts=prompts, **kwargs)
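A hedged usage sketch of the prompted-prediction API documented above (the checkpoint filename and image path are assumptions, not part of this diff):

from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")
results = model.predict("bus.jpg", bboxes=[[100, 120, 500, 600]], conf=0.4)  # prompt with one bounding box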
@@ -13,7 +13,6 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    Returns:
        adjusted_boxes (torch.Tensor): adjusted bounding boxes
    """
-
    # Image dimensions
    h, w = image_shape
@@ -34,7 +34,6 @@ class NASPredictor(BasePredictor):

    def postprocess(self, preds_in, img, orig_imgs):
        """Postprocess predictions and returns a list of Results objects."""
-
        # Cat boxes and class scores
        boxes = ops.xyxy2xywh(preds_in[0][0])
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
@@ -232,7 +232,6 @@ class TwoWayAttentionBlock(nn.Module):

    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Applies two-way attention to process query and key embeddings in a transformer block."""
-
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)

@@ -353,7 +352,6 @@ class Attention(nn.Module):

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
-
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
@@ -34,15 +34,19 @@ class DETRLoss(nn.Module):
        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
    ):
        """
-        DETR loss function.
+        Initialize DETR loss function with customizable components and gains.
+
+        Uses default loss_gain if not provided. Initializes HungarianMatcher with
+        preset cost gains. Supports auxiliary losses and various loss types.

        Args:
-            nc (int): The number of classes.
-            loss_gain (dict): The coefficient of loss.
-            aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
-            use_vfl (bool): Use VarifocalLoss or not.
-            use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
-            uni_match_ind (int): The fixed indices of a layer.
+            nc (int): Number of classes.
+            loss_gain (dict): Coefficients for different loss components.
+            aux_loss (bool): Use auxiliary losses from each decoder layer.
+            use_fl (bool): Use FocalLoss.
+            use_vfl (bool): Use VarifocalLoss.
+            use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
+            uni_match_ind (int): Index of fixed layer for uni_match.
        """
        super().__init__()

@@ -82,9 +86,7 @@ class DETRLoss(nn.Module):
        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
-        """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
-        boxes.
-        """
+        """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = f"loss_bbox{postfix}"
        name_giou = f"loss_giou{postfix}"

@@ -250,14 +252,24 @@ class DETRLoss(nn.Module):

    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
        """
+        Calculate loss for predicted bounding boxes and scores.
+
        Args:
-            pred_bboxes (torch.Tensor): [l, b, query, 4]
-            pred_scores (torch.Tensor): [l, b, query, num_classes]
-            batch (dict): A dict includes:
-                gt_cls (torch.Tensor) with shape [num_gts, ],
-                gt_bboxes (torch.Tensor): [num_gts, 4],
-                gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
-            postfix (str): postfix of loss name.
+            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
+            pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
+            batch (dict): Batch information containing:
+                cls (torch.Tensor): Ground truth classes, shape [num_gts].
+                bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
+                gt_groups (List[int]): Number of ground truths for each image in the batch.
+            postfix (str): Postfix for loss names.
+            **kwargs (Any): Additional arguments, may include 'match_indices'.
+
+        Returns:
+            (dict): Computed losses, including main and auxiliary (if enabled).
+
+        Note:
+            Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
+            self.aux_loss is True.
        """
        self.device = pred_bboxes.device
        match_indices = kwargs.get("match_indices", None)
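The rewritten forward() docstring spells out the expected batch dict; a minimal sketch of constructing one (shapes follow the docstring, the values are invented):

import torch

batch = {
    "cls": torch.tensor([3, 17, 17]),  # ground-truth classes, shape [num_gts]
    "bboxes": torch.rand(3, 4),  # ground-truth boxes, shape [num_gts, 4]
    "gt_groups": [1, 2],  # number of ground truths for each image in the batch
}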
@@ -32,9 +32,7 @@ class HungarianMatcher(nn.Module):
    """

    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
-        """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
-        gamma factors.
-        """
+        """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
        super().__init__()
        if cost_gain is None:
            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}

@@ -70,7 +68,6 @@ class HungarianMatcher(nn.Module):
                For each batch element, it holds:
                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
-
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:

@@ -175,7 +172,6 @@ def get_cdn_group(
            bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
            is less than or equal to 0, the function returns None for all elements in the tuple.
    """
-
    if (not training) or num_dn <= 0:
        return None, None, None, None
    gt_groups = batch["gt_groups"]
@@ -64,10 +64,14 @@ class YOLOWorld(Model):

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
-        Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
+        Initialize YOLOv8-World model with a pre-trained model file.

+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
+        COCO class names.
+
        Args:
-            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            verbose (bool): If True, prints additional information during initialization.
        """
        super().__init__(model=model, task="detect", verbose=verbose)
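A hedged usage sketch of the initializer described above; default COCO class names apply unless custom classes are set (the image path is illustrative):

from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt", verbose=True)
model.set_classes(["person", "bus"])  # optional: override the default COCO class names
results = model.predict("bus.jpg")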
@@ -641,8 +641,8 @@ class AutoBackend(nn.Module):
    @staticmethod
    def _model_type(p="path/to/model.pt"):
        """
-        This function takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml,
-        engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
+        Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml,
+        saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.

        Args:
            p: path to the model file. Defaults to path/to/model.pt
@@ -204,9 +204,7 @@ class C2(nn.Module):
    """CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-        """Initializes the CSP Bottleneck with 2 convolutions module with arguments ch_in, ch_out, number, shortcut,
-        groups, expansion.
-        """
+        """Initializes a CSP Bottleneck with 2 convolutions and optional shortcut connection."""
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)

@@ -224,9 +222,7 @@ class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
+        """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing."""
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)

@@ -335,9 +331,7 @@ class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
-        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
-        expansion.
-        """
+        """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters."""
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)

@@ -345,7 +339,7 @@ class Bottleneck(nn.Module):
        self.add = shortcut and c1 == c2

    def forward(self, x):
-        """'forward()' applies the YOLO FPN to input data."""
+        """Applies the YOLO FPN to input data."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


@@ -449,9 +443,7 @@ class C2fAttn(nn.Module):
    """C2f module with an additional attn module."""

    def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
+        """Initializes C2f module with attention mechanism for enhanced feature extraction and processing."""
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)

@@ -521,9 +513,7 @@ class ImagePoolingAttn(nn.Module):


class ContrastiveHead(nn.Module):
-    """Contrastive Head for YOLO-World compute the region-text scores according to the similarity between image and text
-    features.
-    """
+    """Implements contrastive learning head for region-text similarity in vision-language models."""

    def __init__(self):
        """Initializes ContrastiveHead with specified region-text similarity parameters."""

@@ -569,16 +559,14 @@ class RepBottleneck(Bottleneck):
    """Rep bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
-        """Initializes a RepBottleneck module with customizable in/out channels, shortcut option, groups and expansion
-        ratio.
-        """
+        """Initializes a RepBottleneck module with customizable in/out channels, shortcuts, groups and expansion."""
        super().__init__(c1, c2, shortcut, g, k, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = RepConv(c1, c_, k[0], 1)


class RepCSP(C3):
-    """Rep CSP Bottleneck with 3 convolutions."""
+    """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio."""
@@ -158,9 +158,7 @@ class GhostConv(nn.Module):
    """Ghost Convolution https://github.com/huawei-noah/ghostnet."""

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
-        """Initializes the GhostConv object with input channels, output channels, kernel size, stride, groups and
-        activation.
-        """
+        """Initializes Ghost Convolution module with primary and cheap operations for efficient feature learning."""
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
@@ -266,9 +266,7 @@ class Classify(nn.Module):
    """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
-        """Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride,
-        padding, and groups.
-        """
+        """Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, p, g)

@@ -571,7 +569,7 @@ class RTDETRDecoder(nn.Module):

class v10Detect(Detect):
    """
-    v10 Detection head from https://arxiv.org/pdf/2405.14458
+    v10 Detection head from https://arxiv.org/pdf/2405.14458.

    Args:
        nc (int): Number of classes.
@@ -352,7 +352,6 @@ class DeformableTransformerDecoderLayer(nn.Module):

    def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
        """Perform the forward pass through the entire decoder layer."""
-
        # Self attention
        q = k = self.with_pos_embed(embed, query_pos)
        tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
@@ -50,7 +50,6 @@ def multi_scale_deformable_attn_pytorch(

    https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
    """
-
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
@@ -89,13 +89,17 @@ class BaseModel(nn.Module):

    def forward(self, x, *args, **kwargs):
        """
-        Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
+        Perform forward pass of the model for either training or inference.
+
+        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

        Args:
-            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
+            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
-            (torch.Tensor): The output of the network.
+            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)

@@ -723,7 +727,6 @@ def temporary_modules(modules=None, attributes=None):
    Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
    applications or libraries. Use this function with caution.
    """
-
    if modules is None:
        modules = {}
    if attributes is None:

@@ -752,9 +755,9 @@ def temporary_modules(modules=None, attributes=None):

def torch_safe_load(weight):
    """
-    This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
-    it catches the error, logs a warning message, and attempts to install the missing module via the
-    check_requirements() function. After installation, the function again attempts to load the model using torch.load().
+    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
+    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
+    After installation, the function again attempts to load the model using torch.load().

    Args:
        weight (str): The file path of the PyTorch model.

@@ -813,7 +816,6 @@ def torch_safe_load(weight):

def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
-
    ensemble = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt, w = torch_safe_load(w)  # load ckpt
@@ -29,7 +29,6 @@ class AIGym:
            pose_down_angle (float, optional): Angle threshold for the 'down' pose. Defaults to 90.0.
            pose_type (str, optional): Type of pose to detect ('pullup', 'pushup', 'abworkout'). Defaults to "pullup".
        """
-
        # Image and line thickness
        self.im0 = None
        self.tf = line_thickness

@@ -65,7 +64,6 @@ class AIGym:
            im0 (ndarray): Current frame from the video stream.
            results (list): Pose estimation data.
        """
-
        self.im0 = im0

        if not len(results[0]):
@@ -51,7 +51,6 @@ class Analytics:
            save_img (bool): Whether to save the image.
            max_points (int): Specifies when to remove the oldest points in a graph for multiple lines.
        """
-
        self.bg_color = bg_color
        self.fg_color = fg_color
        self.view_img = view_img

@@ -115,7 +114,6 @@ class Analytics:
            frame_number (int): The current frame number.
            counts_dict (dict): Dictionary with class names as keys and counts as values.
        """
-
        x_data = np.array([])
        y_data_dict = {key: np.array([]) for key in counts_dict.keys()}

@@ -177,7 +175,6 @@ class Analytics:
            frame_number (int): The current frame number.
            total_counts (int): The total counts to plot.
        """
-
        # Update line graph data
        x_data = self.line.get_xdata()
        y_data = self.line.get_ydata()

@@ -230,7 +227,7 @@ class Analytics:
        """
        Write and display the line graph
        Args:
-            im0 (ndarray): Image for processing
+            im0 (ndarray): Image for processing.
        """
        im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
        cv2.imshow(self.title, im0) if self.view_img else None

@@ -243,7 +240,6 @@ class Analytics:
        Args:
            count_dict (dict): Dictionary containing the count data to plot.
        """
-
        # Update bar graph data
        self.ax.clear()
        self.ax.set_facecolor(self.bg_color)

@@ -282,7 +278,6 @@ class Analytics:
        Args:
            classes_dict (dict): Dictionary containing the class data to plot.
        """
-
        # Update pie chart data
        labels = list(classes_dict.keys())
        sizes = list(classes_dict.values())
@@ -37,7 +37,6 @@ class Heatmap:
        shape="circle",
    ):
        """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters."""
-
        # Visual information
        self.annotator = None
        self.view_img = view_img
@@ -53,7 +53,6 @@ class ObjectCounter:
            line_dist_thresh (int): Euclidean distance threshold for line counter.
            cls_txtdisplay_gap (int): Display gap between each class count.
        """
-
        # Mouse events
        self.is_drawing = False
        self.selected_point = None

@@ -141,7 +140,6 @@ class ObjectCounter:

    def extract_and_process_tracks(self, tracks):
        """Extracts and processes tracks for object counting in a video stream."""
-
        # Annotator Init and region drawing
        self.annotator = Annotator(self.im0, self.tf, self.names)
@@ -49,7 +49,6 @@ class QueueManager:
            region_thickness (int, optional): Thickness of the counting region lines. Defaults to 5.
            fontsize (float, optional): Font size for the text annotations. Defaults to 0.7.
        """
-
        # Mouse events state
        self.is_drawing = False
        self.selected_point = None

@@ -88,7 +87,6 @@ class QueueManager:

    def extract_and_process_tracks(self, tracks):
        """Extracts and processes tracks for queue management in a video stream."""
-
        # Initialize annotator and draw the queue region
        self.annotator = Annotator(self.im0, self.tf, self.names)
@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""This module defines the base classes and structures for object tracking in YOLO."""
+"""Module defines the base classes and structures for object tracking in YOLO."""

 from collections import OrderedDict
@@ -37,7 +37,6 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
        >>> thresh = 5.0
        >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
    """
-
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

@@ -80,7 +79,6 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
        >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
        >>> cost_matrix = iou_distance(atracks, btracks)
    """
-
    if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
        atlbrs = atracks
        btlbrs = btracks

@@ -123,7 +121,6 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
        >>> detections = [BaseTrack(...), BaseTrack(...)]  # List of detection objects with embedding features
        >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
    """
-
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix

@@ -152,7 +149,6 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
        >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
        >>> fused_matrix = fuse_score(cost_matrix, detections)
    """
-
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
@@ -135,9 +135,7 @@ class TQDM(tqdm_original):


class SimpleClass:
-    """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
-    access methods for easier debugging and usage.
-    """
+    """A base class providing string representation and attribute access functionality for Ultralytics objects."""

    def __str__(self):
        """Return a human-readable string representation of the object."""

@@ -164,9 +162,7 @@ class SimpleClass:


class IterableSimpleNamespace(SimpleNamespace):
-    """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
-    enables usage with dict() and for loops.
-    """
+    """Iterable SimpleNamespace subclass for key-value attribute iteration and custom error handling."""

    def __iter__(self):
        """Return an iterator of key-value pairs from the namespace's attributes."""

@@ -209,7 +205,6 @@ def plt_settings(rcparams=None, backend="Agg"):
        (Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
            applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
    """
-
    if rcparams is None:
        rcparams = {"font.size": 11}

@@ -240,9 +235,7 @@ def plt_settings(rcparams=None, backend="Agg"):


def set_logging(name="LOGGING_NAME", verbose=True):
-    """Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across different
-    environments.
-    """
+    """Sets up logging with UTF-8 encoding and configurable verbosity for Ultralytics YOLO."""
    level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR  # rank in world for Multi-GPU trainings

    # Configure the console (stdout) encoding to UTF-8, with checks for compatibility

@@ -702,7 +695,7 @@ SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"


def colorstr(*input):
-    """
+    r"""
    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.

@@ -946,9 +939,7 @@ class SettingsManager(dict):
    """

    def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
-        """Initialize the SettingsManager with default settings, load and validate current settings from the YAML
-        file.
-        """
+        """Initializes the SettingsManager with default settings and loads user settings."""
        import copy
        import hashlib
@@ -16,13 +16,17 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):

    Args:
        model (torch.nn.Module): YOLO model to check batch size for.
-        imgsz (int): Image size used for training.
-        amp (bool): If True, use automatic mixed precision (AMP) for training.
+        imgsz (int, optional): Image size used for training.
+        amp (bool, optional): Use automatic mixed precision if True.
+        batch (float, optional): Fraction of GPU memory to use. If -1, use default.

    Returns:
        (int): Optimal batch size computed using the autobatch() function.
-    """
+
+    Note:
+        If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
+        Otherwise, a default fraction of 0.6 is used.
+    """
    with autocast(enabled=amp):
        return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)

@@ -40,7 +44,6 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
    Returns:
        (int): The optimal batch size.
    """
-
    # Check device
    prefix = colorstr("AutoBatch: ")
    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
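A small sketch of the batch-fraction behavior the new Note documents (assumes a CUDA device and a locally available yolov8n.pt checkpoint):

from ultralytics import YOLO
from ultralytics.utils.autobatch import check_train_batch_size

model = YOLO("yolov8n.pt").model
optimal = check_train_batch_size(model, imgsz=640, amp=True, batch=0.8)  # target ~80% of CUDA memory
default = check_train_batch_size(model, imgsz=640, amp=True)  # batch=-1 falls back to the 0.6 fraction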
@@ -182,7 +182,6 @@ class RF100Benchmark:
        Args:
            api_key (str): The API key.
        """
-
        check_requirements("roboflow")
        from roboflow import Roboflow

@@ -195,7 +194,6 @@ class RF100Benchmark:
        Args:
            ds_link_txt (str): Path to dataset_links file.
        """
-
        (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
        os.chdir("rf-100")
        os.mkdir("ultralytics-benchmarks")

@@ -225,7 +223,6 @@ class RF100Benchmark:
        Args:
            path (str): YAML file path.
        """
-
        with open(path, "r") as file:
            yaml_data = yaml.safe_load(file)
        yaml_data["train"] = "train/images"

@@ -393,9 +390,7 @@ class ProfileModels:
        return [Path(file) for file in sorted(files)]

    def get_onnx_model_info(self, onnx_file: str):
-        """Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
-        file.
-        """
+        """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
        return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)

    @staticmethod

@@ -440,9 +435,7 @@ class ProfileModels:
        return np.mean(run_times), np.std(run_times)

    def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
-        """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
-        times.
-        """
+        """Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs."""
        check_requirements("onnxruntime")
        import onnxruntime as ort
@@ -192,7 +192,6 @@ def add_integration_callbacks(instance):
        instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
            of callback lists.
    """
-
    # Load HUB callbacks
    from .hub import callbacks as hub_cb
@@ -114,7 +114,6 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin

    This function rescales the bounding box labels to the original image shape.
    """
-
    resized_image_height, resized_image_width = resized_image_shape

    # Convert normalized xywh format predictions to xyxy in resized scale format
@@ -34,7 +34,6 @@ def _log_scalars(scalars, step=0):

def _log_tensorboard_graph(trainer):
    """Log model graph to TensorBoard."""
-
    # Input image
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
@@ -65,7 +65,6 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
        parse_requirements(package="ultralytics")
        ```
    """
-
    if package:
        requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
    else:

@@ -257,7 +256,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
    """
    Returns the latest version of a PyPI package without downloading or installing it.

-    Parameters:
+    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:

@@ -362,7 +361,6 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
        check_requirements(["numpy", "ultralytics>=8.0.0"])
        ```
    """
-
    prefix = colorstr("red", "bold", "requirements:")
    check_python()  # check python version
    check_torchvision()  # check torch-torchvision compatibility

@@ -422,7 +420,6 @@ def check_torchvision():
    The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
    Torchvision versions.
    """
-
    # Compatibility table
    compatibility_table = {
        "2.3": ["0.18"],

@@ -622,9 +619,9 @@ def collect_system_info():

def check_amp(model):
    """
-    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
-    fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
-    be disabled during training.
+    Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks fail, it means
+    there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
+    during training.

    Args:
        model (nn.Module): A YOLOv8 model instance.
@@ -395,7 +395,6 @@ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
```
"""
if version != "latest":
version = f"tags/{version}" # i.e. tags/v6.2
url = f"https://api.github.com/repos/{repo}/releases/{version}"

@@ -71,7 +71,6 @@ def spaces_in_path(path):
>>> with spaces_in_path('/path/with spaces') as new_path:
>>> # Your code here
"""
# If path has spaces, replace them with underscores
if " " in str(path):
string = isinstance(path, str) # input type
@@ -96,8 +96,11 @@ class Bboxes:
def mul(self, scale):
"""
Multiply bounding box coordinates by scale factor(s).
Args:
scale (tuple | list | int): the scale for four coords.
scale (int | tuple | list): Scale factor(s) for four coordinates.
If int, the same scale is applied to all coordinates.
"""
if isinstance(scale, Number):
scale = to_4tuple(scale)

@@ -110,8 +113,11 @@ class Bboxes:
def add(self, offset):
"""
Add offset to bounding box coordinates.
Args:
offset (tuple | list | int): the offset for four coords.
offset (int | tuple | list): Offset(s) for four coordinates.
If int, the same offset is applied to all coordinates.
"""
if isinstance(offset, Number):
offset = to_4tuple(offset)
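The two hunks above only reword the `mul` and `add` docstrings; the int-versus-tuple behavior they describe is unchanged. A minimal usage sketch of that behavior, assuming `Bboxes` is importable from `ultralytics.utils.instance` (the import path is not shown in this diff):

```python
import numpy as np

from ultralytics.utils.instance import Bboxes  # assumed import path

boxes = Bboxes(bboxes=np.array([[10.0, 20.0, 50.0, 80.0]]), format="xyxy")
boxes.mul(2)  # int: the same scale factor is applied to all four coordinates
boxes.add((5, 5, 0, 0))  # tuple: one offset per coordinate (x1, y1, x2, y2)
```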
@@ -210,10 +216,14 @@ class Instances:
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
"""
Initialize the object with bounding boxes, segments, and keypoints.
Args:
bboxes (ndarray): bboxes with shape [N, 4].
segments (list | ndarray): segments.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
bboxes (np.ndarray): Bounding boxes, shape [N, 4].
segments (list | np.ndarray, optional): Segmentation masks. Defaults to None.
keypoints (np.ndarray, optional): Keypoints, shape [N, 17, 3] and format (x, y, visible). Defaults to None.
bbox_format (str, optional): Format of bboxes. Defaults to "xywh".
normalized (bool, optional): Whether the coordinates are normalized. Defaults to True.
"""
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
self.keypoints = keypoints

@@ -230,7 +240,7 @@ class Instances:
return self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False):
"""This might be similar with denormalize func but without normalized sign."""
"""Similar to denormalize func but without normalized sign."""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
@@ -30,7 +30,6 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
Returns:
(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
"""
# Get the coordinates of bounding boxes
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

@@ -53,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
"""
Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.

@@ -63,7 +62,6 @@ def box_iou(box1, box2, eps=1e-7):
Returns:
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
"""
# NOTE: Need .float() to get accurate iou values
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)

@@ -90,7 +88,6 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
Returns:
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
"""
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)

@@ -195,15 +192,22 @@ def _get_covariance_matrix(boxes):
def probiou(obb1, obb2, CIoU=False, eps=1e-7):
"""
Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
Calculate probabilistic IoU between oriented bounding boxes.
Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
Args:
obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, ) representing obb similarities.
(torch.Tensor): OBB similarities, shape (N,).
Note:
OBB format: [center_x, center_y, width, height, rotation_angle].
If CIoU is True, returns CIoU instead of IoU.
"""
x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = obb2[..., :2].split(1, dim=-1)
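The rewritten `probiou` docstring above documents the xywhr input format and the (N,) output. A toy call consistent with that signature, assuming the function lives in `ultralytics.utils.metrics` (path not shown in this hunk):

```python
import torch

from ultralytics.utils.metrics import probiou  # assumed import path

obb1 = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.1]])  # ground-truth OBBs, shape (N, 5), xywhr
obb2 = torch.tensor([[51.0, 49.0, 38.0, 22.0, 0.1]])  # predicted OBBs, shape (N, 5), xywhr
iou = probiou(obb1, obb2)  # (N,) probabilistic IoU values
```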
@@ -507,7 +511,6 @@ def compute_ap(recall, precision):
(np.ndarray): Precision envelope curve.
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
"""
# Append sentinel values to beginning and end
mrec = np.concatenate(([0.0], recall, [1.0]))
mpre = np.concatenate(([1.0], precision, [0.0]))

@@ -560,7 +563,6 @@ def ap_per_class(
x (np.ndarray): X-axis values for the curves. Shape: (1000,).
prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
"""
# Sort by objectness
i = np.argsort(-conf)
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

@@ -792,8 +794,8 @@ class Metric(SimpleClass):
class DetMetrics(SimpleClass):
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
object detection model.
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.

@@ -942,7 +944,6 @@ class SegmentMetrics(SimpleClass):
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
"""
results_mask = ap_per_class(
tp_m,
conf,

@@ -1084,7 +1085,6 @@ class PoseMetrics(SegmentMetrics):
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
"""
results_pose = ap_per_class(
tp_p,
conf,
@@ -141,14 +141,15 @@ def make_divisible(x, divisor):
def nms_rotated(boxes, scores, threshold=0.45):
"""
NMS for obbs, powered by probiou and fast-nms.
NMS for oriented bounding boxes using probiou and fast-nms.
Args:
boxes (torch.Tensor): (N, 5), xywhr.
scores (torch.Tensor): (N, ).
threshold (float): IoU threshold.
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
scores (torch.Tensor): Confidence scores, shape (N,).
threshold (float, optional): IoU threshold. Defaults to 0.45.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
"""
if len(boxes) == 0:
return np.empty((0,), dtype=np.int8)
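A minimal sketch of calling `nms_rotated` with the interface documented above, assuming it is importable from `ultralytics.utils.ops` (this diff does not show the module path):

```python
import torch

from ultralytics.utils.ops import nms_rotated  # assumed import path

boxes = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.3],
                      [52.0, 51.0, 42.0, 22.0, 0.3]])  # (N, 5) rotated boxes in xywhr format
scores = torch.tensor([0.9, 0.4])                      # (N,) confidence scores
keep = nms_rotated(boxes, scores, threshold=0.45)      # indices of boxes kept after NMS
```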
@@ -597,7 +598,7 @@ def ltwh2xyxy(x):
def segments2boxes(segments):
"""
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
Args:
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates

@@ -667,7 +668,6 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
(torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
are the height and width of the input image. The mask is applied to the bounding boxes.
"""
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW

@@ -785,7 +785,7 @@ def regularize_rboxes(rboxes):
def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
It takes a list of masks(n,h,w) and returns a list of segments(n,xy).
Args:
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)

@@ -823,7 +823,7 @@ def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
def clean_str(s):
"""
Cleans a string by replacing special characters with underscore _
Cleans a string by replacing special characters with '_' character.
Args:
s (str): a string needing special characters replaced
@@ -204,7 +204,6 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# If label have more than 3 characters, skip other characters, due to circle size
if len(label) > 3:
print(

@@ -246,7 +245,6 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# Calculate the center of the bounding box
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the size of the text

@@ -284,7 +282,6 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
rotated (bool, optional): Variable used to check if task is OBB
"""
txt_color = self.get_txt_color(color, txt_color)
if isinstance(box, torch.Tensor):
box = box.tolist()

@@ -343,7 +340,6 @@ class Annotator:
alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
"""
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()

@@ -374,17 +370,18 @@ class Annotator:
Plot keypoints on the image.
Args:
kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
radius (int, optional): Radius of the drawn keypoints. Default is 5.
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
for human pose. Default is True.
kpt_color (tuple, optional): The color of the keypoints (B, G, R).
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 5.
kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
Note:
`kpt_line=True` currently only supports human pose plotting.
- `kpt_line=True` currently only supports human pose plotting.
- Modifies self.im in-place.
- If self.pil is True, converts image to numpy array and back to PIL.
"""
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
@@ -488,7 +485,6 @@ class Annotator:
Returns:
angle (degree): Degree value of angle between three points
"""
x_min, y_min, x_max, y_max = bbox
width = x_max - x_min
height = y_max - y_min

@@ -503,7 +499,6 @@ class Annotator:
color (tuple): Region Color value
thickness (int): Region area thickness value
"""
cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):

@@ -515,7 +510,6 @@ class Annotator:
color (tuple): tracks line color
track_thickness (int): track line thickness value
"""
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)

@@ -530,7 +524,6 @@ class Annotator:
region_color (RGB): queue region color
txt_color (RGB): text display color
"""
x_values = [point[0] for point in points]
y_values = [point[1] for point in points]
center_x = sum(x_values) // len(points)

@@ -574,7 +567,6 @@ class Annotator:
y_center (float): y position center point for bounding box
margin (int): gap between text and rectangle for better display
"""
text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
text_x = x_center - text_size[0] // 2
text_y = y_center + text_size[1] // 2

@@ -597,7 +589,6 @@ class Annotator:
bg_color (bgr color): display color for text background
margin (int): gap between text and rectangle for better display
"""
horizontal_gap = int(im0.shape[1] * 0.02)
vertical_gap = int(im0.shape[0] * 0.01)
text_y_offset = 0

@@ -629,7 +620,6 @@ class Annotator:
Returns:
angle (degree): Degree value of angle between three points
"""
a, b, c = np.array(a), np.array(b), np.array(c)
radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
angle = np.abs(radians * 180.0 / np.pi)

@@ -642,12 +632,19 @@ class Annotator:
Draw specific keypoints for gym steps counting.
Args:
keypoints (list): list of keypoints data to be plotted
indices (list): keypoints ids list to be plotted
shape (tuple): imgsz for model inference
radius (int): Keypoint radius value
"""
keypoints (list): Keypoints data to be plotted.
indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
shape (tuple, optional): Image size for model inference. Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 2.
conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
Returns:
(numpy.ndarray): Image with drawn keypoints.
Note:
Keypoint format: [x, y] or [x, y, confidence].
Modifies self.im in-place.
"""
if indices is None:
indices = [2, 5, 7]
for i, k in enumerate(keypoints):
@@ -675,7 +672,6 @@ class Annotator:
color (tuple): text background color for workout monitoring
txt_color (tuple): text foreground color for workout monitoring
"""
angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")
# Draw angle

@@ -744,7 +740,6 @@ class Annotator:
label (str): Detection label text
txt_color (RGB): text color
"""
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)

@@ -772,7 +767,6 @@ class Annotator:
line_color (RGB): Distance line color.
centroid_color (RGB): Bounding box centroid color.
"""
(text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, self.sf, self.tf)
cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), line_color, -1)
cv2.putText(

@@ -813,7 +807,6 @@ class Annotator:
color (tuple): object centroid and line color value
pin_color (tuple): visioneye point color value
"""
center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)

@@ -906,7 +899,6 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
```
"""
if not isinstance(xyxy, torch.Tensor): # may be list
xyxy = torch.stack(xyxy)
b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes

@@ -1171,7 +1163,6 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
>>> f = np.random.rand(100)
>>> plt_color_scatter(v, f)
"""
# Calculate 2D histogram and corresponding colors
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
colors = [

@@ -1197,7 +1188,6 @@ def plot_tune_results(csv_file="tune_results.csv"):
Examples:
>>> plot_tune_results("path/to/tune_results.csv")
"""
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d
@@ -140,7 +140,6 @@ class TaskAlignedAssigner(nn.Module):
Returns:
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
# (b, max_num_obj, topk)
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
if topk_mask is None:

@@ -184,7 +183,6 @@ class TaskAlignedAssigner(nn.Module):
for positive anchor points, where num_classes is the number
of object classes.
"""
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)

@@ -212,14 +210,19 @@ class TaskAlignedAssigner(nn.Module):
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
"""
Select the positive anchor center in gt.
Select positive anchor centers within ground truth bounding boxes.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 4)
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
Returns:
(Tensor): shape(b, n_boxes, h*w)
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
Note:
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
Bounding box format: [x_min, y_min, x_max, y_max].
"""
n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape

@@ -231,18 +234,22 @@ class TaskAlignedAssigner(nn.Module):
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
"""
If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
Select anchor boxes with highest IoU when assigned to multiple ground truths.
Args:
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
overlaps (Tensor): shape(b, n_max_boxes, h*w)
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
n_max_boxes (int): Maximum number of ground truth boxes.
Returns:
target_gt_idx (Tensor): shape(b, h*w)
fg_mask (Tensor): shape(b, h*w)
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
Note:
b: batch size, h: height, w: width.
"""
# (b, n_max_boxes, h*w) -> (b, h*w)
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
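The shapes documented in the rewritten `select_highest_overlaps` docstring can be checked with a toy tensor; this is only an illustration of the `fg_mask = mask_pos.sum(-2)` step visible above, not the full assigner logic:

```python
import torch

# Toy shapes: b=1 batch, n_max_boxes=2 ground truths, h*w=4 anchors.
mask_pos = torch.tensor([[[1.0, 0.0, 1.0, 0.0],
                          [0.0, 0.0, 1.0, 0.0]]])  # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)                          # (b, h*w)
print(fg_mask)            # tensor([[1., 0., 2., 0.]])
print(fg_mask.max() > 1)  # True -> anchor 2 is claimed by both GTs, so only the highest-IoU GT is kept
```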
@@ -328,14 +335,16 @@ def bbox2dist(anchor_points, bbox, reg_max):
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
"""
Decode predicted object bounding box coordinates from anchor points and distribution.
Decode predicted rotated bounding box coordinates from anchor points and distribution.
Args:
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
dim (int, optional): Dimension along which to split. Defaults to -1.
Returns:
(torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
(torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
"""
lt, rb = pred_dist.split(2, dim=dim)
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
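A toy call matching the `dist2rbox` signature shown above, assuming the function is importable from `ultralytics.utils.tal` (the module path is not part of this hunk):

```python
import torch

from ultralytics.utils.tal import dist2rbox  # assumed import path

bs, hw = 1, 3
pred_dist = torch.rand(bs, hw, 4)     # predicted rotated distances
pred_angle = torch.zeros(bs, hw, 1)   # predicted rotation angles
anchor_points = torch.rand(hw, 2)     # anchor center coordinates
rboxes = dist2rbox(pred_dist, pred_angle, anchor_points)  # (bs, h*w, 4) decoded rotated boxes
```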
@@ -146,7 +146,6 @@ def select_device(device="", batch=0, newline=False, verbose=True):
Note:
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
"""
if isinstance(device, torch.device):
return device

@@ -417,9 +416,7 @@ def initialize_weights(model):
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
"""Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
retaining the original shape.
"""
"""Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
if ratio == 1.0:
return img
h, w = img.shape[2:]

@@ -493,7 +490,7 @@ def init_seeds(seed=0, deterministic=False):
class ModelEMA:
"""
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
average of everything in the model state_dict (parameters and buffers)
average of everything in the model state_dict (parameters and buffers).
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

@@ -34,7 +34,6 @@ def run_ray_tune(
result_grid = model.tune(data="coco8.yaml", use_ray=True)
```
"""
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}