Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -341,10 +341,10 @@ def yaml_load(file='data.yaml', append_filename=False):
|
|||
|
||||
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
||||
"""
|
||||
Pretty prints a yaml file or a yaml-formatted dictionary.
|
||||
Pretty prints a YAML file or a YAML-formatted dictionary.
|
||||
|
||||
Args:
|
||||
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
|
||||
yaml_file: The file path of the YAML file or a YAML-formatted dictionary.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
|
|
|||
|
|
@ -29,8 +29,7 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
|
|||
files (list): A list of file paths in PosixPath format.
|
||||
title (str): A title that groups together images with the same values.
|
||||
"""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
if task := Task.current_task():
|
||||
for f in files:
|
||||
if f.exists():
|
||||
it = re.search(r'_batch(\d+)', f.name)
|
||||
|
|
@ -63,8 +62,7 @@ def _log_plot(title, plot_path) -> None:
|
|||
def on_pretrain_routine_start(trainer):
|
||||
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
||||
try:
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
if task := Task.current_task():
|
||||
# Make sure the automatic pytorch and matplotlib bindings are disabled!
|
||||
# We are logging these plots and model files manually in the integration
|
||||
PatchPyTorchModelIO.update_current_task(None)
|
||||
|
|
@ -86,21 +84,19 @@ def on_pretrain_routine_start(trainer):
|
|||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
task = Task.current_task()
|
||||
|
||||
if task:
|
||||
"""Logs debug samples for the first epoch of YOLO training."""
|
||||
"""Logs debug samples for the first epoch of YOLO training and report current training progress."""
|
||||
if task := Task.current_task():
|
||||
# Log debug samples
|
||||
if trainer.epoch == 1:
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
"""Report the current training progress."""
|
||||
# Report the current training progress
|
||||
for k, v in trainer.validator.metrics.results_dict.items():
|
||||
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Reports model information to logger at the end of an epoch."""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
if task := Task.current_task():
|
||||
# You should have access to the validation bboxes under jdict
|
||||
task.get_logger().report_scalar(title='Epoch Time',
|
||||
series='Epoch Time',
|
||||
|
|
@ -120,8 +116,7 @@ def on_val_end(validator):
|
|||
|
||||
def on_train_end(trainer):
|
||||
"""Logs final model and its name on training completion."""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
if task := Task.current_task():
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def _log_images(path, prefix=''):
|
|||
|
||||
# Group images by batch to enable sliders in UI
|
||||
if m := re.search(r'_batch(\d+)', name):
|
||||
ni = m.group(1)
|
||||
ni = m[1]
|
||||
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
|
||||
name = (Path(new_stem) / ni).with_suffix(path.suffix)
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
|
|||
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
|
||||
|
||||
# Unzip with progress bar
|
||||
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and not any(x in f.name for x in exclude)]
|
||||
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
|
||||
zip_file = directory.with_suffix('.zip')
|
||||
compression = ZIP_DEFLATED if compress else ZIP_STORED
|
||||
with ZipFile(zip_file, 'w', compression) as f:
|
||||
|
|
@ -185,11 +185,9 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
|
|||
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
|
||||
if hard:
|
||||
raise MemoryError(text)
|
||||
else:
|
||||
LOGGER.warning(text)
|
||||
return False
|
||||
LOGGER.warning(text)
|
||||
return False
|
||||
|
||||
# Pass if error
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -332,6 +330,9 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
|
|||
r = requests.get(url) # github api
|
||||
if r.status_code != 200 and retry:
|
||||
r = requests.get(url) # try again
|
||||
if r.status_code != 200:
|
||||
LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
|
||||
return '', []
|
||||
data = r.json()
|
||||
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets
|
||||
|
||||
|
|
|
|||
|
|
@ -382,7 +382,7 @@ def compute_ap(recall, precision):
|
|||
"""
|
||||
Compute the average precision (AP) given the recall and precision curves.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
recall (list): The recall curve.
|
||||
precision (list): The precision curve.
|
||||
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ def non_max_suppression(
|
|||
"""
|
||||
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
|
||||
containing the predicted boxes, classes, and masks. The tensor should be in the format
|
||||
output by a model, such as YOLO.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue