Predictor support (#65)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
479992093c
commit
e6737f1207
22 changed files with 916 additions and 48 deletions
|
|
@ -46,8 +46,8 @@ class BaseValidator:
|
|||
self.args.half &= self.device.type != 'cpu'
|
||||
model = model.half() if self.args.half else model.float()
|
||||
self.model = model
|
||||
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
else: # TODO: handle this when detectMultiBackend is supported
|
||||
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
else:
|
||||
assert model is not None, "Either trainer or model is needed for validation"
|
||||
self.device = select_device(self.args.device, self.args.batch_size)
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
|
|
@ -90,13 +90,11 @@ class BaseValidator:
|
|||
# inference
|
||||
with dt[1]:
|
||||
preds = model(batch["img"])
|
||||
# TODO: remember to add native augmentation support when implementing model, like:
|
||||
# preds, train_out = model(im, augment=augment)
|
||||
|
||||
# loss
|
||||
with dt[2]:
|
||||
if self.training:
|
||||
loss += trainer.criterion(preds, batch)[1]
|
||||
self.loss += trainer.criterion(preds, batch)[1]
|
||||
|
||||
# pre-process predictions
|
||||
with dt[3]:
|
||||
|
|
@ -123,7 +121,7 @@ class BaseValidator:
|
|||
model.float()
|
||||
# TODO: implement save json
|
||||
|
||||
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
|
||||
return stats | trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val") \
|
||||
if self.training else stats
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue