Start export implementation (#110)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c1b38428bc
commit
92dad1c1b5
32 changed files with 827 additions and 222 deletions
|
|
@ -66,7 +66,7 @@ class BaseValidator:
|
|||
self.args.batch_size = model.batch_size
|
||||
else:
|
||||
self.device = model.device
|
||||
if not (pt or jit):
|
||||
if not pt and not jit:
|
||||
self.args.batch_size = 1 # export.py models default to batch-size 1
|
||||
self.logger.info(
|
||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
|
@ -75,8 +75,8 @@ class BaseValidator:
|
|||
data = check_dataset_yaml(self.args.data)
|
||||
else:
|
||||
data = check_dataset(self.args.data)
|
||||
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"),
|
||||
self.args.batch_size) if not self.dataloader else self.dataloader
|
||||
self.dataloader = self.dataloader or \
|
||||
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
||||
|
||||
model.eval()
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ class BaseValidator:
|
|||
def postprocess(self, preds):
|
||||
return preds
|
||||
|
||||
def init_metrics(self):
|
||||
def init_metrics(self, model):
|
||||
pass
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue