diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index cf3816fa..52e999c9 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -425,11 +425,11 @@ class ClassificationModel(BaseModel): elif isinstance(m, nn.Sequential): types = [type(x) for x in m] if nn.Linear in types: - i = types.index(nn.Linear) # nn.Linear index + i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index if m[i].out_features != nc: m[i] = nn.Linear(m[i].in_features, nc) elif nn.Conv2d in types: - i = types.index(nn.Conv2d) # nn.Conv2d index + i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index if m[i].out_channels != nc: m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)