From c943a3b747cd445a31c116690823d17a5996b0b9 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Thu, 31 Oct 2024 16:52:14 +0500 Subject: [PATCH] Case-insensitive optimizer name (#17287) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/engine/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 35206739..e82aed9e 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -791,6 +791,8 @@ class BaseTrainer: else: # weight (with decay) g[0].append(param) + optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"} + name = {x.lower(): x for x in optimizers}.get(name.lower(), None) if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}: optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) elif name == "RMSProp": @@ -799,9 +801,8 @@ class BaseTrainer: optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) else: raise NotImplementedError( - f"Optimizer '{name}' not found in list of available optimizers " - f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." - "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics." + f"Optimizer '{name}' not found in list of available optimizers {optimizers}. " + "Request support for addition optimizers at https://github.com/ultralytics/ultralytics." ) optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay