From 6e7d888703367764a60ea9791ede82e7954b3113 Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Mon, 17 Feb 2025 15:20:04 +0800 Subject: [PATCH] Fix layer count; show layers with no params in detailed log (#19202) Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/utils/torch_utils.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 00aa8df5..a589b7d6 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -302,15 +302,22 @@ def model_info(model, detailed=False, verbose=True, imgsz=640): return n_p = get_num_params(model) # number of parameters n_g = get_num_gradients(model) # number of gradients - n_l = len(list(model.modules())) # number of layers + layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0) + n_l = len(layers) # number of layers if detailed: - LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}") - for i, (name, p) in enumerate(model.named_parameters()): - name = name.replace("module_list.", "") - LOGGER.info( - f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}" - f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}" - ) + h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}" + LOGGER.info(h) + for i, (mn, m) in enumerate(layers.items()): + mn = mn.replace("module_list.", "") + mt = m.__class__.__name__ + if len(m._parameters): + for pn, p in m.named_parameters(): + LOGGER.info( + f"{i:>5g}{mn + '.' + pn:>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}" + f"{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}" + ) + else: # layers with no learnable params + LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}") flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320] fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""