fix double zero_grad call messing up gradient accumulation (#11217)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent d5458f27cd
commit 2f9a604387
1 changed file with 1 addition and 1 deletion
```diff
@@ -329,6 +329,7 @@ class BaseTrainer:
             base_idx = (self.epochs - self.args.close_mosaic) * nb
             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
         epoch = self.start_epoch
+        self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
         while True:
             self.epoch = epoch
             self.run_callbacks("on_train_epoch_start")
@@ -349,7 +350,6 @@ class BaseTrainer:
                 LOGGER.info(self.progress_string())
                 pbar = TQDM(enumerate(self.train_loader), total=nb)
             self.tloss = None
-            self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.run_callbacks("on_train_batch_start")
                 # Warmup
```
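For context on why the moved call matters: the trainer steps the optimizer only once every `accumulate` batches, so at an epoch boundary there may be gradients from trailing batches that have not yet been applied. The old per-epoch `zero_grad()` silently discarded those pending gradients; the fix zeroes once before the training loop (which also clears any gradients carried over from a resumed checkpoint) and otherwise zeroes only after each optimizer step. Below is a minimal, self-contained sketch of that accumulation pattern; the model, data, and loop are hypothetical simplifications, not the actual Ultralytics trainer code:

```python
import torch

# Sketch of gradient accumulation across epoch boundaries (hypothetical model
# and data; simplified from the BaseTrainer pattern, not the real trainer).
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulate = 4  # step the optimizer once every `accumulate` batches
data = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(6)]  # 6 batches per epoch

optimizer.zero_grad()  # zero once before training (the fix): clears any resumed gradients
step = 0
for epoch in range(2):
    # Note: no zero_grad() here. With 6 batches and accumulate=4, batches 5-6
    # of epoch 0 leave pending gradients that must survive into epoch 1; the
    # removed per-epoch zero_grad() would have silently discarded them.
    for x, y in data:
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()  # gradients accumulate in .grad across calls
        step += 1
        if step % accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()  # zero only after applying an optimizer step
```

The design point: `zero_grad()` should pair with `optimizer.step()`, not with loop structure, so gradients accumulated in the final batches of one epoch correctly carry into the next.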