Simplify Results() class (#4579)

This commit is contained in:
Glenn Jocher 2023-08-26 17:27:18 +02:00 committed by GitHub
parent e9f596430f
commit 2db35afad5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 39 deletions

View file

@ -101,19 +101,18 @@ class Results(SimpleClass):
self.names = names
self.path = path
self.save_dir = None
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
self._keys = 'boxes', 'masks', 'probs', 'keypoints'
def __getitem__(self, idx):
"""Return a Results object for the specified index."""
r = self.new()
for k in self.keys:
setattr(r, k, getattr(self, k)[idx])
return r
return self._apply('__getitem__', idx)
def __len__(self):
"""Return the number of detections in the Results object."""
for k in self.keys:
return len(getattr(self, k))
for k in self._keys:
v = getattr(self, k)
if v is not None:
return len(v)
def update(self, boxes=None, masks=None, probs=None):
"""Update the boxes, masks, and probs attributes of the Results object."""
@ -125,43 +124,34 @@ class Results(SimpleClass):
if probs is not None:
self.probs = probs
def _apply(self, fn, *args, **kwargs):
r = self.new()
for k in self._keys:
v = getattr(self, k)
if v is not None:
setattr(r, k, getattr(v, fn)(*args, **kwargs))
return r
def cpu(self):
"""Return a copy of the Results object with all tensors on CPU memory."""
r = self.new()
for k in self.keys:
setattr(r, k, getattr(self, k).cpu())
return r
return self._apply('cpu')
def numpy(self):
"""Return a copy of the Results object with all tensors as numpy arrays."""
r = self.new()
for k in self.keys:
setattr(r, k, getattr(self, k).numpy())
return r
return self._apply('numpy')
def cuda(self):
"""Return a copy of the Results object with all tensors on GPU memory."""
r = self.new()
for k in self.keys:
setattr(r, k, getattr(self, k).cuda())
return r
return self._apply('cuda')
def to(self, *args, **kwargs):
"""Return a copy of the Results object with tensors on the specified device and dtype."""
r = self.new()
for k in self.keys:
setattr(r, k, getattr(self, k).to(*args, **kwargs))
return r
return self._apply('to', *args, **kwargs)
def new(self):
"""Return a new Results object with the same image, path, and names."""
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
@property
def keys(self):
"""Return a list of non-empty attribute names."""
return [k for k in self._keys if getattr(self, k) is not None]
def plot(
self,
conf=True,