Tests and docstrings improvements (#4475)

This commit is contained in:
Glenn Jocher 2023-08-21 17:02:14 +02:00 committed by GitHub
parent c659c0fa7b
commit 615ddc9d97
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 107 additions and 186 deletions

View file

@ -341,10 +341,10 @@ def yaml_load(file='data.yaml', append_filename=False):
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
"""
Pretty prints a yaml file or a yaml-formatted dictionary.
Pretty prints a YAML file or a YAML-formatted dictionary.
Args:
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
yaml_file: The file path of the YAML file or a YAML-formatted dictionary.
Returns:
None

View file

@ -29,8 +29,7 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
files (list): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
task = Task.current_task()
if task:
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r'_batch(\d+)', f.name)
@ -63,8 +62,7 @@ def _log_plot(title, plot_path) -> None:
def on_pretrain_routine_start(trainer):
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
try:
task = Task.current_task()
if task:
if task := Task.current_task():
# Make sure the automatic pytorch and matplotlib bindings are disabled!
# We are logging these plots and model files manually in the integration
PatchPyTorchModelIO.update_current_task(None)
@ -86,21 +84,19 @@ def on_pretrain_routine_start(trainer):
def on_train_epoch_end(trainer):
task = Task.current_task()
if task:
"""Logs debug samples for the first epoch of YOLO training."""
"""Logs debug samples for the first epoch of YOLO training and report current training progress."""
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
"""Report the current training progress."""
# Report the current training progress
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
task = Task.current_task()
if task:
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(title='Epoch Time',
series='Epoch Time',
@ -120,8 +116,7 @@ def on_val_end(validator):
def on_train_end(trainer):
"""Logs final model and its name on training completion."""
task = Task.current_task()
if task:
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',

View file

@ -40,7 +40,7 @@ def _log_images(path, prefix=''):
# Group images by batch to enable sliders in UI
if m := re.search(r'_batch(\d+)', name):
ni = m.group(1)
ni = m[1]
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)

View file

@ -93,7 +93,7 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
# Unzip with progress bar
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and not any(x in f.name for x in exclude)]
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix('.zip')
compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, 'w', compression) as f:
@ -185,11 +185,9 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
if hard:
raise MemoryError(text)
else:
LOGGER.warning(text)
return False
LOGGER.warning(text)
return False
# Pass if error
return True
@ -332,6 +330,9 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
r = requests.get(url) # github api
if r.status_code != 200 and retry:
r = requests.get(url) # try again
if r.status_code != 200:
LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
return '', []
data = r.json()
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets

View file

@ -382,7 +382,7 @@ def compute_ap(recall, precision):
"""
Compute the average precision (AP) given the recall and precision curves.
Arguments:
Args:
recall (list): The recall curve.
precision (list): The precision curve.

View file

@ -140,7 +140,7 @@ def non_max_suppression(
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
Arguments:
Args:
prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
containing the predicted boxes, classes, and masks. The tensor should be in the format
output by a model, such as YOLO.