Fix Classifier last layer indexing (#13219)

This commit is contained in:
Glenn Jocher 2024-05-30 09:55:43 +02:00 committed by GitHub
parent 51c3169e9f
commit 1a40d30dbf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -425,11 +425,11 @@ class ClassificationModel(BaseModel):
elif isinstance(m, nn.Sequential):
types = [type(x) for x in m]
if nn.Linear in types:
i = types.index(nn.Linear) # nn.Linear index
i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index
if m[i].out_features != nc:
m[i] = nn.Linear(m[i].in_features, nc)
elif nn.Conv2d in types:
i = types.index(nn.Conv2d) # nn.Conv2d index
i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index
if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)