ultralytics 8.0.159 add Classify training resume feature (#4482)
This commit is contained in:
parent
b2f279ffdd
commit
c0a9660310
3 changed files with 5 additions and 9 deletions
|
|
@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer):
|
|||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model = str(self.model)
|
||||
model, ckpt = str(self.model), None
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith('.pt'):
|
||||
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
||||
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.split('.')[-1] in ('yaml', 'yml'):
|
||||
|
|
@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||
|
||||
return # do not return ckpt. Classification doesn't support resume
|
||||
return ckpt
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
|
||||
|
|
@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer):
|
|||
loss_items = [round(float(loss_items), 5)]
|
||||
return dict(zip(keys, loss_items))
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resumes training from a given checkpoint."""
|
||||
pass
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots metrics from a CSV file."""
|
||||
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue