ultralytics 8.0.167 Tuner updates and HUB Pose and Classify fixes (#4656)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8596ee241f
commit
d2cf7acce0
21 changed files with 174 additions and 144 deletions
|
|
@ -202,6 +202,28 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
|||
return masks, index
|
||||
|
||||
|
||||
def find_dataset_yaml(path: Path) -> Path:
|
||||
"""
|
||||
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
|
||||
|
||||
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
|
||||
performs a recursive search. It prefers YAML files that have the samestem as the provided path. An AssertionError
|
||||
is raised if no YAML file is found or if multiple YAML files are found.
|
||||
|
||||
Args:
|
||||
path (Path): The directory path to search for the YAML file.
|
||||
|
||||
Returns:
|
||||
(Path): The path of the found YAML file.
|
||||
"""
|
||||
files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml')) # try root level first and then recursive
|
||||
assert files, f"No YAML file found in '{path.resolve()}'"
|
||||
if len(files) > 1:
|
||||
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
|
||||
assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
|
||||
return files[0]
|
||||
|
||||
|
||||
def check_det_dataset(dataset, autodownload=True):
|
||||
"""
|
||||
Download, verify, and/or unzip a dataset if not found locally.
|
||||
|
|
@ -223,8 +245,8 @@ def check_det_dataset(dataset, autodownload=True):
|
|||
# Download (optional)
|
||||
extract_dir = ''
|
||||
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
|
||||
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
|
||||
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
|
||||
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
data = find_dataset_yaml(DATASETS_DIR / new_dir)
|
||||
extract_dir, autodownload = data.parent, False
|
||||
|
||||
# Read YAML (optional)
|
||||
|
|
@ -316,6 +338,10 @@ def check_cls_dataset(dataset, split=''):
|
|||
- 'names' (dict): A dictionary of class names in the dataset.
|
||||
"""
|
||||
|
||||
# Download (optional if dataset=https://file.zip is passed directly)
|
||||
if str(dataset).startswith(('http:/', 'https:/')):
|
||||
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
|
||||
dataset = Path(dataset)
|
||||
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
|
||||
if not data_dir.is_dir():
|
||||
|
|
@ -329,8 +355,8 @@ def check_cls_dataset(dataset, split=''):
|
|||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||
LOGGER.info(s)
|
||||
train_set = data_dir / 'train'
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
|
||||
data_dir / 'validation').exists() else None # data/test or data/val
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \
|
||||
(data_dir / 'validation').exists() else None # data/test or data/val
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
||||
if split == 'val' and not val_set:
|
||||
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
||||
|
|
@ -414,16 +440,6 @@ class HUBDatasetStats:
|
|||
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
|
||||
self.data = data
|
||||
|
||||
@staticmethod
|
||||
def _find_yaml(dir):
|
||||
"""Return data.yaml file."""
|
||||
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
|
||||
assert files, f"No *.yaml file found in '{dir.resolve()}'"
|
||||
if len(files) > 1:
|
||||
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
|
||||
assert len(files) == 1, f"Expected 1 *.yaml file in '{dir.resolve()}', but found {len(files)}.\n{files}"
|
||||
return files[0]
|
||||
|
||||
def _unzip(self, path):
|
||||
"""Unzip data.zip."""
|
||||
if not str(path).endswith('.zip'): # path is data.yaml
|
||||
|
|
@ -431,7 +447,7 @@ class HUBDatasetStats:
|
|||
unzip_dir = unzip_file(path, path=path.parent)
|
||||
assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
|
||||
f'path/to/abc.zip MUST unzip to path/to/abc/'
|
||||
return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
|
||||
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
|
||||
|
||||
def _hub_ops(self, f):
|
||||
"""Saves a compressed image for HUB previews."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue