Model coverage cleanup (#4585)
This commit is contained in:
parent
c635418a27
commit
deac7575b1
12 changed files with 132 additions and 175 deletions
|
|
@ -30,40 +30,6 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
torch.nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
|
||||
w.size(0),
|
||||
w.shape[2:],
|
||||
stride=self.c.stride,
|
||||
padding=self.c.padding,
|
||||
dilation=self.c.dilation,
|
||||
groups=self.c.groups)
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
# NOTE: This module and timm package is needed only for training.
|
||||
# from ultralytics.utils.checks import check_requirements
|
||||
# check_requirements('timm')
|
||||
# from timm.models.layers import DropPath as TimmDropPath
|
||||
# from timm.models.layers import trunc_normal_
|
||||
# class DropPath(TimmDropPath):
|
||||
#
|
||||
# def __init__(self, drop_prob=None):
|
||||
# super().__init__(drop_prob=drop_prob)
|
||||
# self.drop_prob = drop_prob
|
||||
#
|
||||
# def __repr__(self):
|
||||
# msg = super().__repr__()
|
||||
# msg += f'(drop_prob={self.drop_prob})'
|
||||
# return msg
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue