Fix wrong parameter count while profiling (#16790)
This commit is contained in:
parent
6babd8fedc
commit
7680a16257
1 changed files with 2 additions and 2 deletions
|
|
@ -1061,10 +1061,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
|
|
||||||
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
||||||
t = str(m)[8:-2].replace("__main__.", "") # module type
|
t = str(m)[8:-2].replace("__main__.", "") # module type
|
||||||
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
m_.np = sum(x.numel() for x in m_.parameters()) # number params
|
||||||
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
|
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
|
||||||
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
||||||
layers.append(m_)
|
layers.append(m_)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue