ultralytics 8.2.98 faster fuse() operations (#16375)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-09-19 21:22:55 +02:00 committed by GitHub
parent 641d09164c
commit 07a5ff9ddc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View file

@ -251,7 +251,7 @@ def fuse_conv_and_bn(conv, bn):
)
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_conv = conv.weight.view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
@ -282,7 +282,7 @@ def fuse_deconv_and_bn(deconv, bn):
)
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
w_deconv = deconv.weight.view(deconv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))