ultralytics 8.0.172 faster LetterBox() and Classify Tune fix (#4766)
Co-authored-by: BardJun <70683507+jy1002@users.noreply.github.com>
This commit is contained in:
parent
577d066fb2
commit
8fd9a1a048
8 changed files with 68 additions and 26 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.171'
|
||||
__version__ = '8.0.172'
|
||||
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
|
|
|
|||
|
|
@ -143,8 +143,8 @@ class BasePredictor:
|
|||
(list): A list of transformed images.
|
||||
"""
|
||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||
auto = same_shapes and self.model.pt
|
||||
return [LetterBox(self.imgsz, auto=auto, stride=self.model.stride)(image=x) for x in im]
|
||||
letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
"""Write inference results to a file or directory."""
|
||||
|
|
|
|||
|
|
@ -58,4 +58,5 @@ class RTDETRPredictor(BasePredictor):
|
|||
Returns:
|
||||
(list): A list of transformed imgs.
|
||||
"""
|
||||
return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]
|
||||
letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
|
|
|||
|
|
@ -48,11 +48,11 @@ class Predictor(BasePredictor):
|
|||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = torch.from_numpy(im)
|
||||
|
||||
img = im.to(self.device)
|
||||
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
im = im.to(self.device)
|
||||
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
|
||||
if not_tensor:
|
||||
img = (img - self.mean) / self.std
|
||||
return img
|
||||
im = (im - self.mean) / self.std
|
||||
return im
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
|
|
@ -64,8 +64,9 @@ class Predictor(BasePredictor):
|
|||
Returns:
|
||||
(list): A list of transformed images.
|
||||
"""
|
||||
assert len(im) == 1, 'SAM model has not supported batch inference yet!'
|
||||
return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]
|
||||
assert len(im) == 1, 'SAM model does not currently support batched inference'
|
||||
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ class ClassificationValidator(BaseValidator):
|
|||
on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
self.metrics.save_dir = self.save_dir
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ def add_integration_callbacks(instance):
|
|||
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
|
||||
|
||||
# Load export callbacks (patch to avoid CoreML protobuf error)
|
||||
if 'Exporter' in instance.__class__.__name__:
|
||||
if 'Exporter' in instance.__class__.__name__ and instance.args.format in ('coreml', 'mlmodel'):
|
||||
from .tensorboard import callbacks as tb_cb
|
||||
callbacks_list.append(tb_cb)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue