Simplify Results() class (#4579)
This commit is contained in:
parent
e9f596430f
commit
2db35afad5
4 changed files with 30 additions and 39 deletions
|
|
@ -101,19 +101,18 @@ class Results(SimpleClass):
|
|||
self.names = names
|
||||
self.path = path
|
||||
self.save_dir = None
|
||||
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
||||
self._keys = 'boxes', 'masks', 'probs', 'keypoints'
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return a Results object for the specified index."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k)[idx])
|
||||
return r
|
||||
return self._apply('__getitem__', idx)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of detections in the Results object."""
|
||||
for k in self.keys:
|
||||
return len(getattr(self, k))
|
||||
for k in self._keys:
|
||||
v = getattr(self, k)
|
||||
if v is not None:
|
||||
return len(v)
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
"""Update the boxes, masks, and probs attributes of the Results object."""
|
||||
|
|
@ -125,43 +124,34 @@ class Results(SimpleClass):
|
|||
if probs is not None:
|
||||
self.probs = probs
|
||||
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
r = self.new()
|
||||
for k in self._keys:
|
||||
v = getattr(self, k)
|
||||
if v is not None:
|
||||
setattr(r, k, getattr(v, fn)(*args, **kwargs))
|
||||
return r
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the Results object with all tensors on CPU memory."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cpu())
|
||||
return r
|
||||
return self._apply('cpu')
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the Results object with all tensors as numpy arrays."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).numpy())
|
||||
return r
|
||||
return self._apply('numpy')
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the Results object with all tensors on GPU memory."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cuda())
|
||||
return r
|
||||
return self._apply('cuda')
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the Results object with tensors on the specified device and dtype."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
||||
return r
|
||||
return self._apply('to', *args, **kwargs)
|
||||
|
||||
def new(self):
|
||||
"""Return a new Results object with the same image, path, and names."""
|
||||
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Return a list of non-empty attribute names."""
|
||||
return [k for k in self._keys if getattr(self, k) is not None]
|
||||
|
||||
def plot(
|
||||
self,
|
||||
conf=True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue