ultralytics 8.0.167 Tuner updates and HUB Pose and Classify fixes (#4656)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-31 01:38:42 +02:00 committed by GitHub
parent 8596ee241f
commit d2cf7acce0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 174 additions and 144 deletions

View file

@ -202,6 +202,28 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
return masks, index
def find_dataset_yaml(path: Path) -> Path:
"""
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
performs a recursive search. It prefers YAML files that have the samestem as the provided path. An AssertionError
is raised if no YAML file is found or if multiple YAML files are found.
Args:
path (Path): The directory path to search for the YAML file.
Returns:
(Path): The path of the found YAML file.
"""
files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml')) # try root level first and then recursive
assert files, f"No YAML file found in '{path.resolve()}'"
if len(files) > 1:
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
return files[0]
def check_det_dataset(dataset, autodownload=True):
"""
Download, verify, and/or unzip a dataset if not found locally.
@ -223,8 +245,8 @@ def check_det_dataset(dataset, autodownload=True):
# Download (optional)
extract_dir = ''
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False)
data = find_dataset_yaml(DATASETS_DIR / new_dir)
extract_dir, autodownload = data.parent, False
# Read YAML (optional)
@ -316,6 +338,10 @@ def check_cls_dataset(dataset, split=''):
- 'names' (dict): A dictionary of class names in the dataset.
"""
# Download (optional if dataset=https://file.zip is passed directly)
if str(dataset).startswith(('http:/', 'https:/')):
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
dataset = Path(dataset)
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
if not data_dir.is_dir():
@ -329,8 +355,8 @@ def check_cls_dataset(dataset, split=''):
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
train_set = data_dir / 'train'
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
data_dir / 'validation').exists() else None # data/test or data/val
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \
(data_dir / 'validation').exists() else None # data/test or data/val
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
if split == 'val' and not val_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
@ -414,16 +440,6 @@ class HUBDatasetStats:
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
self.data = data
@staticmethod
def _find_yaml(dir):
"""Return data.yaml file."""
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
assert files, f"No *.yaml file found in '{dir.resolve()}'"
if len(files) > 1:
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
assert len(files) == 1, f"Expected 1 *.yaml file in '{dir.resolve()}', but found {len(files)}.\n{files}"
return files[0]
def _unzip(self, path):
"""Unzip data.zip."""
if not str(path).endswith('.zip'): # path is data.yaml
@ -431,7 +447,7 @@ class HUBDatasetStats:
unzip_dir = unzip_file(path, path=path.parent)
assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
f'path/to/abc.zip MUST unzip to path/to/abc/'
return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f):
"""Saves a compressed image for HUB previews."""