ultralytics 8.1.43 40% faster ultralytics imports (#9547)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
99c61d6f7b
commit
a2628657a1
21 changed files with 240 additions and 225 deletions
|
|
@ -10,7 +10,6 @@ import pytest
|
|||
import torch
|
||||
import yaml
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from ultralytics import RTDETR, YOLO
|
||||
from ultralytics.cfg import TASK2DATA
|
||||
|
|
@ -108,20 +107,17 @@ def test_predict_img():
|
|||
assert len(model(batch, imgsz=32)) == len(batch) # multiple sources in a batch
|
||||
|
||||
# Test tensor inference
|
||||
im = cv2.imread(str(SOURCE)) # OpenCV
|
||||
t = cv2.resize(im, (32, 32))
|
||||
t = ToTensor()(t)
|
||||
t = torch.stack([t, t, t, t])
|
||||
results = model(t, imgsz=32)
|
||||
assert len(results) == t.shape[0]
|
||||
results = seg_model(t, imgsz=32)
|
||||
assert len(results) == t.shape[0]
|
||||
results = cls_model(t, imgsz=32)
|
||||
assert len(results) == t.shape[0]
|
||||
results = pose_model(t, imgsz=32)
|
||||
assert len(results) == t.shape[0]
|
||||
results = obb_model(t, imgsz=32)
|
||||
assert len(results) == t.shape[0]
|
||||
im = torch.rand((4, 3, 32, 32)) # batch-size 4, FP32 0.0-1.0 RGB order
|
||||
results = model(im, imgsz=32)
|
||||
assert len(results) == im.shape[0]
|
||||
results = seg_model(im, imgsz=32)
|
||||
assert len(results) == im.shape[0]
|
||||
results = cls_model(im, imgsz=32)
|
||||
assert len(results) == im.shape[0]
|
||||
results = pose_model(im, imgsz=32)
|
||||
assert len(results) == im.shape[0]
|
||||
results = obb_model(im, imgsz=32)
|
||||
assert len(results) == im.shape[0]
|
||||
|
||||
|
||||
def test_predict_grey_and_4ch():
|
||||
|
|
@ -592,8 +588,6 @@ def image():
|
|||
)
|
||||
def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):
|
||||
"""Tests classification transforms during training with various augmentation settings."""
|
||||
import torchvision.transforms as T
|
||||
|
||||
from ultralytics.data.augment import classify_augmentations
|
||||
|
||||
transform = classify_augmentations(
|
||||
|
|
@ -610,7 +604,6 @@ def test_classify_transforms_train(image, auto_augment, erasing, force_color_jit
|
|||
hsv_v=0.4,
|
||||
force_color_jitter=force_color_jitter,
|
||||
erasing=erasing,
|
||||
interpolation=T.InterpolationMode.BILINEAR,
|
||||
)
|
||||
|
||||
transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.1.42"
|
||||
__version__ = "8.1.43"
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Tuple, Union
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.utils import LOGGER, colorstr
|
||||
from ultralytics.utils.checks import check_version
|
||||
|
|
@ -167,8 +167,8 @@ class BaseMixTransform:
|
|||
text2id = {text: i for i, text in enumerate(mix_texts)}
|
||||
|
||||
for label in [labels] + labels["mix_labels"]:
|
||||
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
|
||||
text = label["texts"][int(l)]
|
||||
for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
|
||||
text = label["texts"][int(cls)]
|
||||
label["cls"][i] = text2id[tuple(text)]
|
||||
label["texts"] = mix_texts
|
||||
return labels
|
||||
|
|
@ -1133,7 +1133,7 @@ def classify_transforms(
|
|||
size=224,
|
||||
mean=DEFAULT_MEAN,
|
||||
std=DEFAULT_STD,
|
||||
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
|
||||
interpolation=Image.BILINEAR,
|
||||
crop_fraction: float = DEFAULT_CROP_FTACTION,
|
||||
):
|
||||
"""
|
||||
|
|
@ -1149,6 +1149,7 @@ def classify_transforms(
|
|||
Returns:
|
||||
(T.Compose): torchvision transforms
|
||||
"""
|
||||
import torchvision.transforms as T # scope for faster 'import ultralytics'
|
||||
|
||||
if isinstance(size, (tuple, list)):
|
||||
assert len(size) == 2
|
||||
|
|
@ -1157,12 +1158,12 @@ def classify_transforms(
|
|||
scale_size = math.floor(size / crop_fraction)
|
||||
scale_size = (scale_size, scale_size)
|
||||
|
||||
# aspect ratio is preserved, crops center within image, no borders are added, image is lost
|
||||
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost
|
||||
if scale_size[0] == scale_size[1]:
|
||||
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
|
||||
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
|
||||
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
|
||||
else:
|
||||
# resize shortest edge to matching target dim for non-square target
|
||||
# Resize the shortest edge to matching target dim for non-square target
|
||||
tfl = [T.Resize(scale_size)]
|
||||
tfl += [T.CenterCrop(size)]
|
||||
|
||||
|
|
@ -1192,7 +1193,7 @@ def classify_augmentations(
|
|||
hsv_v=0.4, # image HSV-Value augmentation (fraction)
|
||||
force_color_jitter=False,
|
||||
erasing=0.0,
|
||||
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
|
||||
interpolation=Image.BILINEAR,
|
||||
):
|
||||
"""
|
||||
Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
|
||||
|
|
@ -1216,7 +1217,9 @@ def classify_augmentations(
|
|||
Returns:
|
||||
(T.Compose): torchvision transforms
|
||||
"""
|
||||
# Transforms to apply if albumentations not installed
|
||||
# Transforms to apply if Albumentations not installed
|
||||
import torchvision.transforms as T # scope for faster 'import ultralytics'
|
||||
|
||||
if not isinstance(size, int):
|
||||
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
||||
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import contextlib
|
||||
from itertools import repeat
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
||||
from ultralytics.utils.ops import resample_segments
|
||||
from .augment import (
|
||||
|
|
@ -103,16 +102,16 @@ class YOLODataset(BaseDataset):
|
|||
nc += nc_f
|
||||
if im_file:
|
||||
x["labels"].append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape,
|
||||
cls=lb[:, 0:1], # n, 1
|
||||
bboxes=lb[:, 1:], # n, 4
|
||||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format="xywh",
|
||||
)
|
||||
{
|
||||
"im_file": im_file,
|
||||
"shape": shape,
|
||||
"cls": lb[:, 0:1], # n, 1
|
||||
"bboxes": lb[:, 1:], # n, 4
|
||||
"segments": segments,
|
||||
"keypoints": keypoint,
|
||||
"normalized": True,
|
||||
"bbox_format": "xywh",
|
||||
}
|
||||
)
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
|
|
@ -245,125 +244,6 @@ class YOLODataset(BaseDataset):
|
|||
return new_batch
|
||||
|
||||
|
||||
# Classification dataloaders -------------------------------------------------------------------------------------------
|
||||
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
"""
|
||||
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
|
||||
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
|
||||
learning models, with optional image transformations and caching mechanisms to speed up training.
|
||||
|
||||
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
|
||||
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
|
||||
to ensure data integrity and consistency.
|
||||
|
||||
Attributes:
|
||||
cache_ram (bool): Indicates if caching in RAM is enabled.
|
||||
cache_disk (bool): Indicates if caching on disk is enabled.
|
||||
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
|
||||
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
|
||||
torch_transforms (callable): PyTorch transforms to be applied to the images.
|
||||
"""
|
||||
|
||||
def __init__(self, root, args, augment=False, prefix=""):
|
||||
"""
|
||||
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
||||
|
||||
Args:
|
||||
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
|
||||
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
|
||||
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
|
||||
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
|
||||
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
|
||||
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
|
||||
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
|
||||
debugging. Default is an empty string.
|
||||
"""
|
||||
super().__init__(root=root)
|
||||
if augment and args.fraction < 1.0: # reduce training fraction
|
||||
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
|
||||
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
|
||||
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
|
||||
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
|
||||
self.samples = self.verify_images() # filter out bad images
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
||||
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
|
||||
self.torch_transforms = (
|
||||
classify_augmentations(
|
||||
size=args.imgsz,
|
||||
scale=scale,
|
||||
hflip=args.fliplr,
|
||||
vflip=args.flipud,
|
||||
erasing=args.erasing,
|
||||
auto_augment=args.auto_augment,
|
||||
hsv_h=args.hsv_h,
|
||||
hsv_s=args.hsv_s,
|
||||
hsv_v=args.hsv_v,
|
||||
)
|
||||
if augment
|
||||
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
|
||||
)
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Returns subset of data and targets corresponding to given indices."""
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram:
|
||||
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
elif self.cache_disk:
|
||||
if not fn.exists(): # load npy
|
||||
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
# Convert NumPy array to PIL image
|
||||
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
|
||||
sample = self.torch_transforms(im)
|
||||
return {"img": sample, "cls": j}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the total number of samples in the dataset."""
|
||||
return len(self.samples)
|
||||
|
||||
def verify_images(self):
|
||||
"""Verify all images in dataset."""
|
||||
desc = f"{self.prefix}Scanning {self.root}..."
|
||||
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
||||
|
||||
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
|
||||
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if LOCAL_RANK in {-1, 0}:
|
||||
d = f"{desc} {nf} images, {nc} corrupt"
|
||||
TQDM(None, desc=d, total=n, initial=n)
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
return samples
|
||||
|
||||
# Run scan if *.cache retrieval failed
|
||||
nf, nc, msgs, samples, x = 0, 0, [], [], {}
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
|
||||
pbar = TQDM(results, desc=desc, total=len(self.samples))
|
||||
for sample, nf_f, nc_f, msg in pbar:
|
||||
if nf_f:
|
||||
samples.append(sample)
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
nf += nf_f
|
||||
nc += nc_f
|
||||
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
|
||||
pbar.close()
|
||||
if msgs:
|
||||
LOGGER.info("\n".join(msgs))
|
||||
x["hash"] = get_hash([x[0] for x in self.samples])
|
||||
x["results"] = nf, nc, len(samples), samples
|
||||
x["msgs"] = msgs # warnings
|
||||
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
||||
return samples
|
||||
|
||||
|
||||
class YOLOMultiModalDataset(YOLODataset):
|
||||
"""
|
||||
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
||||
|
|
@ -447,15 +327,15 @@ class GroundingDataset(YOLODataset):
|
|||
bboxes.append(box)
|
||||
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
|
||||
labels.append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=(h, w),
|
||||
cls=lb[:, 0:1], # n, 1
|
||||
bboxes=lb[:, 1:], # n, 4
|
||||
normalized=True,
|
||||
bbox_format="xywh",
|
||||
texts=texts,
|
||||
)
|
||||
{
|
||||
"im_file": im_file,
|
||||
"shape": (h, w),
|
||||
"cls": lb[:, 0:1], # n, 1
|
||||
"bboxes": lb[:, 1:], # n, 4
|
||||
"normalized": True,
|
||||
"bbox_format": "xywh",
|
||||
"texts": texts,
|
||||
}
|
||||
)
|
||||
return labels
|
||||
|
||||
|
|
@ -497,3 +377,128 @@ class SemanticDataset(BaseDataset):
|
|||
def __init__(self):
|
||||
"""Initialize a SemanticDataset object."""
|
||||
super().__init__()
|
||||
|
||||
|
||||
class ClassificationDataset:
|
||||
"""
|
||||
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
|
||||
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
|
||||
learning models, with optional image transformations and caching mechanisms to speed up training.
|
||||
|
||||
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
|
||||
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
|
||||
to ensure data integrity and consistency.
|
||||
|
||||
Attributes:
|
||||
cache_ram (bool): Indicates if caching in RAM is enabled.
|
||||
cache_disk (bool): Indicates if caching on disk is enabled.
|
||||
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
|
||||
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
|
||||
torch_transforms (callable): PyTorch transforms to be applied to the images.
|
||||
"""
|
||||
|
||||
def __init__(self, root, args, augment=False, prefix=""):
|
||||
"""
|
||||
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
||||
|
||||
Args:
|
||||
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
|
||||
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
|
||||
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
|
||||
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
|
||||
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
|
||||
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
|
||||
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
|
||||
debugging. Default is an empty string.
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
|
||||
self.base = torchvision.datasets.ImageFolder(root=root)
|
||||
self.samples = self.base.samples
|
||||
self.root = self.base.root
|
||||
|
||||
# Initialize attributes
|
||||
if augment and args.fraction < 1.0: # reduce training fraction
|
||||
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
|
||||
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
|
||||
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
|
||||
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
|
||||
self.samples = self.verify_images() # filter out bad images
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
||||
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
|
||||
self.torch_transforms = (
|
||||
classify_augmentations(
|
||||
size=args.imgsz,
|
||||
scale=scale,
|
||||
hflip=args.fliplr,
|
||||
vflip=args.flipud,
|
||||
erasing=args.erasing,
|
||||
auto_augment=args.auto_augment,
|
||||
hsv_h=args.hsv_h,
|
||||
hsv_s=args.hsv_s,
|
||||
hsv_v=args.hsv_v,
|
||||
)
|
||||
if augment
|
||||
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
|
||||
)
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Returns subset of data and targets corresponding to given indices."""
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram:
|
||||
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
elif self.cache_disk:
|
||||
if not fn.exists(): # load npy
|
||||
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
# Convert NumPy array to PIL image
|
||||
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
|
||||
sample = self.torch_transforms(im)
|
||||
return {"img": sample, "cls": j}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the total number of samples in the dataset."""
|
||||
return len(self.samples)
|
||||
|
||||
def verify_images(self):
|
||||
"""Verify all images in dataset."""
|
||||
desc = f"{self.prefix}Scanning {self.root}..."
|
||||
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
||||
|
||||
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
|
||||
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if LOCAL_RANK in {-1, 0}:
|
||||
d = f"{desc} {nf} images, {nc} corrupt"
|
||||
TQDM(None, desc=d, total=n, initial=n)
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
return samples
|
||||
|
||||
# Run scan if *.cache retrieval failed
|
||||
nf, nc, msgs, samples, x = 0, 0, [], [], {}
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
|
||||
pbar = TQDM(results, desc=desc, total=len(self.samples))
|
||||
for sample, nf_f, nc_f, msg in pbar:
|
||||
if nf_f:
|
||||
samples.append(sample)
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
nf += nf_f
|
||||
nc += nc_f
|
||||
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
|
||||
pbar.close()
|
||||
if msgs:
|
||||
LOGGER.info("\n".join(msgs))
|
||||
x["hash"] = get_hash([x[0] for x in self.samples])
|
||||
x["results"] = nf, nc, len(samples), samples
|
||||
x["msgs"] = msgs # warnings
|
||||
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
||||
return samples
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import numpy as np
|
|||
import torch
|
||||
from PIL import Image
|
||||
from matplotlib import pyplot as plt
|
||||
from pandas import DataFrame
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.data.augment import Format
|
||||
|
|
@ -172,7 +171,7 @@ class Explorer:
|
|||
|
||||
def sql_query(
|
||||
self, query: str, return_type: str = "pandas"
|
||||
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
|
||||
) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table
|
||||
"""
|
||||
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
||||
|
||||
|
|
@ -247,7 +246,7 @@ class Explorer:
|
|||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
return_type: str = "pandas",
|
||||
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
|
||||
) -> Any: # pandas.DataFrame or pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
|
|
@ -312,7 +311,7 @@ class Explorer:
|
|||
img = plot_query_result(similar, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any: # pd.DataFrame
|
||||
"""
|
||||
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
|
||||
are max_dist or closer to the image in the embedding space at a given index.
|
||||
|
|
@ -447,12 +446,11 @@ class Explorer:
|
|||
"""
|
||||
result = prompt_sql_query(query)
|
||||
try:
|
||||
df = self.sql_query(result)
|
||||
return self.sql_query(result)
|
||||
except Exception as e:
|
||||
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
||||
LOGGER.error(e)
|
||||
return None
|
||||
return df
|
||||
|
||||
def visualize(self, result):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ultralytics import Explorer
|
||||
from ultralytics.utils import ROOT, SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
|
@ -148,12 +146,14 @@ def run_ai_query():
|
|||
'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
||||
)
|
||||
return
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
st.session_state["error"] = None
|
||||
query = st.session_state.get("ai_query")
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state["explorer"]
|
||||
res = exp.ask_ai(query)
|
||||
if not isinstance(res, pd.DataFrame) or res.empty:
|
||||
if not isinstance(res, pandas.DataFrame) or res.empty:
|
||||
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
|
||||
return
|
||||
st.session_state["imgs"] = res["im_file"].to_list()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from typing import List
|
|||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.utils import LOGGER as logger
|
||||
|
|
@ -64,8 +63,10 @@ def plot_query_result(similar_set, plot_labels=True):
|
|||
similar_set (list): Pyarrow or pandas object containing the similar data points
|
||||
plot_labels (bool): Whether to plot labels or not
|
||||
"""
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
similar_set = (
|
||||
similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
|
||||
similar_set.to_dict(orient="list") if isinstance(similar_set, pandas.DataFrame) else similar_set.to_pydict()
|
||||
)
|
||||
empty_masks = [[[]]]
|
||||
empty_boxes = [[]]
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ from ultralytics.utils import (
|
|||
LINUX,
|
||||
LOGGER,
|
||||
MACOS,
|
||||
PYTHON_VERSION,
|
||||
ROOT,
|
||||
WINDOWS,
|
||||
__version__,
|
||||
|
|
@ -83,7 +84,7 @@ from ultralytics.utils import (
|
|||
get_default_args,
|
||||
yaml_save,
|
||||
)
|
||||
from ultralytics.utils.checks import PYTHON_VERSION, check_imgsz, check_is_path_safe, check_requirements, check_version
|
||||
from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
|
||||
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
|
||||
from ultralytics.utils.files import file_size, spaces_in_path
|
||||
from ultralytics.utils.ops import Profile
|
||||
|
|
@ -92,7 +93,7 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d
|
|||
|
||||
def export_formats():
|
||||
"""YOLOv8 export formats."""
|
||||
import pandas
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
x = [
|
||||
["PyTorch", "-", ".pt", True, True],
|
||||
|
|
|
|||
|
|
@ -464,7 +464,7 @@ class BaseTrainer:
|
|||
def save_model(self):
|
||||
"""Save model training checkpoints with additional metadata."""
|
||||
import io
|
||||
import pandas as pd # scope for faster startup
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
|
||||
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
||||
buffer = io.BytesIO()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
|
@ -118,6 +117,8 @@ class FastSAMPrompt:
|
|||
retina (bool, optional): Whether to use retina mask. Defaults to False.
|
||||
with_contours (bool, optional): Whether to plot contours. Defaults to True.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
pbar = TQDM(annotations, total=len(annotations))
|
||||
for ann in pbar:
|
||||
result_name = os.path.basename(ann.path)
|
||||
|
|
@ -202,6 +203,8 @@ class FastSAMPrompt:
|
|||
target_height (int, optional): Target height for resizing. Defaults to 960.
|
||||
target_width (int, optional): Target width for resizing. Defaults to 960.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
n, h, w = annotation.shape # batch, height, width
|
||||
|
||||
areas = np.sum(annotation, axis=(1, 2))
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ segmentation tasks.
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
|
|
@ -246,6 +245,8 @@ class Predictor(BasePredictor):
|
|||
Returns:
|
||||
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
self.segment_all = True
|
||||
ih, iw = im.shape[2:]
|
||||
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
|
||||
|
|
@ -449,6 +450,8 @@ class Predictor(BasePredictor):
|
|||
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
|
||||
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
if len(masks) == 0:
|
||||
return masks
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
|
|
@ -59,6 +58,8 @@ class ClassificationTrainer(BaseTrainer):
|
|||
|
||||
def setup_model(self):
|
||||
"""Load, create or download model for any task."""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import logging.config
|
||||
import os
|
||||
|
|
@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
|
|||
LOGGING_NAME = "ultralytics"
|
||||
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
||||
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
||||
PYTHON_VERSION = platform.python_version()
|
||||
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
||||
HELP_MSG = """
|
||||
Usage examples for running YOLOv8:
|
||||
|
||||
|
|
@ -476,7 +479,7 @@ def is_online() -> bool:
|
|||
|
||||
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
||||
try:
|
||||
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
||||
test_connection = socket.create_connection(address=(host, 80), timeout=2)
|
||||
except (socket.timeout, socket.gaierror, OSError):
|
||||
continue
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -69,8 +69,7 @@ def benchmark(
|
|||
benchmark(model='yolov8n.pt', imgsz=640)
|
||||
```
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ try:
|
|||
assert SETTINGS["clearml"] is True # verify integration is enabled
|
||||
import clearml
|
||||
from clearml import Task
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
assert hasattr(clearml, "__version__") # verify package is not directory
|
||||
|
||||
|
|
@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
|
|||
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
||||
try:
|
||||
if task := Task.current_task():
|
||||
# Make sure the automatic pytorch and matplotlib bindings are disabled!
|
||||
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
|
||||
# We are logging these plots and model files manually in the integration
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
PatchPyTorchModelIO.update_current_task(None)
|
||||
PatchedMatplotlib.update_current_task(None)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -9,10 +9,6 @@ try:
|
|||
import wandb as wb
|
||||
|
||||
assert hasattr(wb, "__version__") # verify package is not directory
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
_processed_plots = {}
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
|
|
@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
|
|||
Returns:
|
||||
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
|
||||
"""
|
||||
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
||||
fields = {"x": "x", "y": "y", "class": "class"}
|
||||
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
|
||||
return wb.plot_table(
|
||||
|
|
@ -77,6 +75,8 @@ def _plot_curve(
|
|||
Note:
|
||||
The function leverages the '_custom_table' function to generate the actual visualization.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Create new x
|
||||
if names is None:
|
||||
names = []
|
||||
|
|
|
|||
|
|
@ -18,15 +18,16 @@ import cv2
|
|||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from matplotlib import font_manager
|
||||
|
||||
from ultralytics.utils import (
|
||||
ASSETS,
|
||||
AUTOINSTALL,
|
||||
LINUX,
|
||||
LOGGER,
|
||||
PYTHON_VERSION,
|
||||
ONLINE,
|
||||
ROOT,
|
||||
TORCHVISION_VERSION,
|
||||
USER_CONFIG_DIR,
|
||||
Retry,
|
||||
SimpleNamespace,
|
||||
|
|
@ -41,13 +42,10 @@ from ultralytics.utils import (
|
|||
is_github_action_running,
|
||||
is_jupyter,
|
||||
is_kaggle,
|
||||
is_online,
|
||||
is_pip_package,
|
||||
url2file,
|
||||
)
|
||||
|
||||
PYTHON_VERSION = platform.python_version()
|
||||
|
||||
|
||||
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
||||
"""
|
||||
|
|
@ -304,9 +302,10 @@ def check_font(font="Arial.ttf"):
|
|||
Returns:
|
||||
file (Path): Resolved font file path.
|
||||
"""
|
||||
name = Path(font).name
|
||||
from matplotlib import font_manager
|
||||
|
||||
# Check USER_CONFIG_DIR
|
||||
name = Path(font).name
|
||||
file = USER_CONFIG_DIR / name
|
||||
if file.exists():
|
||||
return file
|
||||
|
|
@ -390,7 +389,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
||||
try:
|
||||
t = time.time()
|
||||
assert is_online(), "AutoUpdate skipped (offline)"
|
||||
assert ONLINE, "AutoUpdate skipped (offline)"
|
||||
with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
|
||||
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
|
||||
dt = time.time() - t
|
||||
|
|
@ -419,14 +418,12 @@ def check_torchvision():
|
|||
Torchvision versions.
|
||||
"""
|
||||
|
||||
import torchvision
|
||||
|
||||
# Compatibility table
|
||||
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
|
||||
|
||||
# Extract only the major and minor versions
|
||||
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
|
||||
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
|
||||
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
|
||||
|
||||
if v_torch in compatibility_table:
|
||||
compatible_versions = compatibility_table[v_torch]
|
||||
|
|
|
|||
|
|
@ -395,19 +395,19 @@ class ConfusionMatrix:
|
|||
names (tuple): Names of classes, used as labels on the plot.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
"""
|
||||
import seaborn as sn
|
||||
import seaborn # scope for faster 'import ultralytics'
|
||||
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
|
||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||
nc, nn = self.nc, len(names) # number of classes, names
|
||||
sn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||
ticklabels = (list(names) + ["background"]) if labels else "auto"
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(
|
||||
seaborn.heatmap(
|
||||
array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import cv2
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.metrics import batch_probiou
|
||||
|
|
@ -206,6 +205,7 @@ def non_max_suppression(
|
|||
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
||||
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
# Checks
|
||||
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
||||
|
|
|
|||
|
|
@ -671,8 +671,8 @@ class Annotator:
|
|||
@plt_settings()
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
||||
"""Plot training labels including class histograms and box statistics."""
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
import seaborn # scope for faster 'import ultralytics'
|
||||
|
||||
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
|
||||
|
|
@ -682,10 +682,10 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||
nc = int(cls.max() + 1) # number of classes
|
||||
boxes = boxes[:1000000] # limit to 1M boxes
|
||||
x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
|
||||
x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
|
||||
|
||||
# Seaborn correlogram
|
||||
sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
|
||||
plt.close()
|
||||
|
||||
|
|
@ -700,8 +700,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|||
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
||||
else:
|
||||
ax[0].set_xlabel("classes")
|
||||
sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
|
||||
sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
|
||||
seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
|
||||
seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
|
||||
|
||||
# Rectangles
|
||||
boxes[:, 0:2] = 0.5 # center
|
||||
|
|
@ -933,7 +933,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|||
plot_results('path/to/results.csv', segment=True)
|
||||
```
|
||||
"""
|
||||
import pandas as pd
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
save_dir = Path(file).parent if file else Path(dir)
|
||||
|
|
@ -1019,7 +1019,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|||
>>> plot_tune_results('path/to/tune_results.csv')
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pandas as pd # scope for faster 'import ultralytics'
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
# Scatter plots for each hyperparameter
|
||||
|
|
|
|||
|
|
@ -14,10 +14,17 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, __version__
|
||||
from ultralytics.utils.checks import PYTHON_VERSION, check_version
|
||||
from ultralytics.utils import (
|
||||
DEFAULT_CFG_DICT,
|
||||
DEFAULT_CFG_KEYS,
|
||||
LOGGER,
|
||||
PYTHON_VERSION,
|
||||
TORCHVISION_VERSION,
|
||||
colorstr,
|
||||
__version__,
|
||||
)
|
||||
from ultralytics.utils.checks import check_version
|
||||
|
||||
try:
|
||||
import thop
|
||||
|
|
@ -28,9 +35,9 @@ except ImportError:
|
|||
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
||||
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
|
||||
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
||||
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
|
||||
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
|
||||
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
|
||||
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
|
||||
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
|
||||
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue