ultralytics 8.0.49 task, exports and metadata updates (#1197)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Paul Guerrie <97041392+paulguerrie@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-03-01 21:16:09 -08:00 committed by GitHub
parent 74e4c94806
commit 3861e6c82a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 111 additions and 101 deletions

View file

@ -215,7 +215,7 @@ class Exporter:
self.model = model
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
description = f'Ultralytics {self.pretty_name} model ' + f'trained on {Path(self.args.data).name}' \
if self.args.data else '(untrained)'
self.metadata = {
@ -225,6 +225,8 @@ class Exporter:
'version': __version__,
'stride': int(max(model.stride)),
'task': model.task,
'batch': self.args.batch,
'imgsz': self.imgsz,
'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
@ -283,8 +285,7 @@ class Exporter:
f = self.file.with_suffix('.torchscript')
ts = torch.jit.trace(self.model, self.im, strict=False)
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
LOGGER.info(f'{prefix} optimizing for mobile...')
from torch.utils.mobile_optimizer import optimize_for_mobile
@ -429,16 +430,18 @@ class Exporter:
classifier_config=classifier_config)
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
if bits < 32:
if 'kmeans' in mode:
check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
if self.args.nms and self.model.task == 'detect':
ct_model = self._pipeline_coreml(ct_model)
m = self.metadata # metadata dict
ct_model.short_description = m['description']
ct_model.author = m['author']
ct_model.license = m['license']
ct_model.version = m['version']
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items() if k in ('stride', 'task', 'names')})
ct_model.short_description = m.pop('description')
ct_model.author = m.pop('author')
ct_model.license = m.pop('license')
ct_model.version = m.pop('version')
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
ct_model.save(str(f))
return f, ct_model