diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py index c10adef8..3ed6a219 100644 --- a/ultralytics/data/dataset.py +++ b/ultralytics/data/dataset.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr from ultralytics.utils.ops import resample_segments -from ultralytics.utils.torch_utils import TORCH_1_13 +from ultralytics.utils.torch_utils import TORCHVISION_0_18 from .augment import ( Compose, @@ -417,7 +417,7 @@ class ClassificationDataset: import torchvision # scope for faster 'import ultralytics' # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import - if TORCH_1_13: # 'allow_empty' argument first introduced in torch 1.13 + if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18 self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True) else: self.base = torchvision.datasets.ImageFolder(root=root) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index ca814b60..db17813f 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -40,6 +40,7 @@ TORCH_2_0 = check_version(torch.__version__, "2.0.0") TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0") +TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0") @contextmanager