Fix layer count; show layers with no params in detailed log (#19202)
Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
0ae4670da6
commit
6e7d888703
1 changed files with 15 additions and 8 deletions
|
|
@ -302,15 +302,22 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
|
||||||
return
|
return
|
||||||
n_p = get_num_params(model) # number of parameters
|
n_p = get_num_params(model) # number of parameters
|
||||||
n_g = get_num_gradients(model) # number of gradients
|
n_g = get_num_gradients(model) # number of gradients
|
||||||
n_l = len(list(model.modules())) # number of layers
|
layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)
|
||||||
|
n_l = len(layers) # number of layers
|
||||||
if detailed:
|
if detailed:
|
||||||
LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}")
|
h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
|
||||||
for i, (name, p) in enumerate(model.named_parameters()):
|
LOGGER.info(h)
|
||||||
name = name.replace("module_list.", "")
|
for i, (mn, m) in enumerate(layers.items()):
|
||||||
LOGGER.info(
|
mn = mn.replace("module_list.", "")
|
||||||
f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
|
mt = m.__class__.__name__
|
||||||
f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
|
if len(m._parameters):
|
||||||
)
|
for pn, p in m.named_parameters():
|
||||||
|
LOGGER.info(
|
||||||
|
f"{i:>5g}{mn + '.' + pn:>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}"
|
||||||
|
f"{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
|
||||||
|
)
|
||||||
|
else: # layers with no learnable params
|
||||||
|
LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
|
||||||
|
|
||||||
flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
|
flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
|
||||||
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
|
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue