ultralytics 8.0.44 export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-02-24 03:11:25 +01:00 committed by GitHub
parent fe61018975
commit 3ea659411b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 439 additions and 480 deletions

View file

@ -44,7 +44,6 @@ class BaseValidator:
Attributes:
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
@ -56,7 +55,7 @@ class BaseValidator:
save_dir (Path): Directory to save results.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
"""
Initializes a BaseValidator instance.
@ -69,14 +68,13 @@ class BaseValidator:
"""
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None
self.data = None
self.device = None
self.batch_i = None
self.training = True
self.speed = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.jdict = None
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
@ -123,14 +121,14 @@ class BaseValidator:
self.device = model.device
if not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
@ -179,7 +177,7 @@ class BaseValidator:
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
self.finalize_metrics()
self.run_callbacks('on_val_end')
if self.training:
@ -187,11 +185,11 @@ class BaseValidator:
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
self.speed)
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
tuple(self.speed.values()))
if self.args.save_json and self.jdict:
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
self.logger.info(f'Saving {f.name}...')
LOGGER.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json: