ultralytics 8.0.153 YOLO Tasks Cleanup (#4314)
This commit is contained in:
parent
39395aedc8
commit
822608986c
22 changed files with 87 additions and 55 deletions
|
|
@ -51,9 +51,18 @@ class BaseValidator:
|
|||
device (torch.device): Device to use for validation.
|
||||
batch_i (int): Current batch index.
|
||||
training (bool): Whether the model is in training mode.
|
||||
speed (float): Batch processing speed in seconds.
|
||||
jdict (dict): Dictionary to store validation results.
|
||||
names (dict): Class names.
|
||||
seen: Records the number of images seen so far during validation.
|
||||
stats: Placeholder for statistics during validation.
|
||||
confusion_matrix: Placeholder for a confusion matrix.
|
||||
nc: Number of classes.
|
||||
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
||||
jdict (dict): Dictionary to store JSON validation results.
|
||||
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
||||
batch processing times in milliseconds.
|
||||
save_dir (Path): Directory to save results.
|
||||
plots (dict): Dictionary to store plots for visualization.
|
||||
callbacks (dict): Dictionary to store various callback functions.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
|
|
@ -65,6 +74,7 @@ class BaseValidator:
|
|||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
_callbacks (dict): Dictionary to store various callback functions.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
|
|
@ -74,8 +84,14 @@ class BaseValidator:
|
|||
self.device = None
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.names = None
|
||||
self.seen = None
|
||||
self.stats = None
|
||||
self.confusion_matrix = None
|
||||
self.nc = None
|
||||
self.iouv = None
|
||||
self.jdict = None
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
|
|
@ -200,26 +216,26 @@ class BaseValidator:
|
|||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
return stats
|
||||
|
||||
def match_predictions(self, pred_classes: torch.Tensor, true_classes: torch.Tensor,
|
||||
iou: torch.Tensor) -> torch.Tensor:
|
||||
def match_predictions(self, pred_classes, true_classes, iou):
|
||||
"""
|
||||
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
|
||||
|
||||
Args:
|
||||
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
|
||||
true_classes (torch.Tensor): Target class indices of shape(M,).
|
||||
iou (torch.Tensor): IoU thresholds from 0.50 to 0.95 in space of 0.05.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
|
||||
"""
|
||||
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
correct_class = true_classes[:, None] == pred_classes
|
||||
for i in range(len(self.iouv)):
|
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||
if x[0].shape[0]:
|
||||
for i, iouv in enumerate(self.iouv):
|
||||
x = torch.nonzero(iou.ge(iouv) & correct_class) # IoU > threshold and classes match
|
||||
if x.shape[0]:
|
||||
# Concatenate [label, detect, iou]
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||
if x[0].shape[0] > 1:
|
||||
matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy()
|
||||
if x.shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue