ultralytics 8.2.98 faster fuse() operations (#16375)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
641d09164c
commit
07a5ff9ddc
2 changed files with 3 additions and 3 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.97"
|
__version__ = "8.2.98"
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,7 @@ def fuse_conv_and_bn(conv, bn):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare filters
|
# Prepare filters
|
||||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
w_conv = conv.weight.view(conv.out_channels, -1)
|
||||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||||
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
||||||
|
|
||||||
|
|
@ -282,7 +282,7 @@ def fuse_deconv_and_bn(deconv, bn):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare filters
|
# Prepare filters
|
||||||
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
|
w_deconv = deconv.weight.view(deconv.out_channels, -1)
|
||||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||||
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
|
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue