Predictor support (#65)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia 2022-12-07 10:33:10 +05:30 committed by GitHub
parent 479992093c
commit e6737f1207
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 916 additions and 48 deletions

View file

@ -46,8 +46,8 @@ class BaseValidator:
self.args.half &= self.device.type != 'cpu'
model = model.half() if self.args.half else model.float()
self.model = model
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
else: # TODO: handle this when detectMultiBackend is supported
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
else:
assert model is not None, "Either trainer or model is needed for validation"
self.device = select_device(self.args.device, self.args.batch_size)
self.args.half &= self.device.type != 'cpu'
@ -90,13 +90,11 @@ class BaseValidator:
# inference
with dt[1]:
preds = model(batch["img"])
# TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment)
# loss
with dt[2]:
if self.training:
loss += trainer.criterion(preds, batch)[1]
self.loss += trainer.criterion(preds, batch)[1]
# pre-process predictions
with dt[3]:
@ -123,7 +121,7 @@ class BaseValidator:
model.float()
# TODO: implement save json
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
return stats | trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val") \
if self.training else stats
def get_dataloader(self, dataset_path, batch_size):