ultralytics 8.0.49 task, exports and metadata updates (#1197)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Paul Guerrie <97041392+paulguerrie@users.noreply.github.com>
This commit is contained in:
parent
74e4c94806
commit
3861e6c82a
20 changed files with 111 additions and 101 deletions
|
|
@ -215,7 +215,7 @@ class Exporter:
|
|||
self.model = model
|
||||
self.file = file
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
||||
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
|
||||
description = f'Ultralytics {self.pretty_name} model ' + f'trained on {Path(self.args.data).name}' \
|
||||
if self.args.data else '(untrained)'
|
||||
self.metadata = {
|
||||
|
|
@ -225,6 +225,8 @@ class Exporter:
|
|||
'version': __version__,
|
||||
'stride': int(max(model.stride)),
|
||||
'task': model.task,
|
||||
'batch': self.args.batch,
|
||||
'imgsz': self.imgsz,
|
||||
'names': model.names} # model metadata
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||
|
|
@ -283,8 +285,7 @@ class Exporter:
|
|||
f = self.file.with_suffix('.torchscript')
|
||||
|
||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
|
||||
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
||||
extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||
LOGGER.info(f'{prefix} optimizing for mobile...')
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
|
|
@ -429,16 +430,18 @@ class Exporter:
|
|||
classifier_config=classifier_config)
|
||||
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
|
||||
if bits < 32:
|
||||
if 'kmeans' in mode:
|
||||
check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
|
||||
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
||||
if self.args.nms and self.model.task == 'detect':
|
||||
ct_model = self._pipeline_coreml(ct_model)
|
||||
|
||||
m = self.metadata # metadata dict
|
||||
ct_model.short_description = m['description']
|
||||
ct_model.author = m['author']
|
||||
ct_model.license = m['license']
|
||||
ct_model.version = m['version']
|
||||
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items() if k in ('stride', 'task', 'names')})
|
||||
ct_model.short_description = m.pop('description')
|
||||
ct_model.author = m.pop('author')
|
||||
ct_model.license = m.pop('license')
|
||||
ct_model.version = m.pop('version')
|
||||
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
|
||||
ct_model.save(str(f))
|
||||
return f, ct_model
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue