Unified model loading with backwards compatibility (#132)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8996c5c6cf
commit
c3d961fb03
10 changed files with 65 additions and 50 deletions
|
|
@ -164,8 +164,8 @@ class Exporter:
|
|||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
|
||||
# Checks
|
||||
if self.args.batch_size == 16:
|
||||
self.args.batch_size = 1 # TODO: resolve batch_size 16 default in config.yaml
|
||||
# if self.args.batch_size == model.args['batch_size']: # user has not modified training batch_size
|
||||
self.args.batch_size = 1
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
if self.args.optimize:
|
||||
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
||||
|
|
@ -778,7 +778,7 @@ def export(cfg):
|
|||
if Path(cfg.model).suffix == '.yaml':
|
||||
model = DetectionModel(cfg.model)
|
||||
elif Path(cfg.model).suffix == '.pt':
|
||||
model = attempt_load_weights(cfg.model)
|
||||
model = attempt_load_weights(cfg.model, fuse=True)
|
||||
else:
|
||||
TypeError(f'Unsupported model type {cfg.model}')
|
||||
exporter(model=model)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue