Model coverage cleanup (#4585)

Glenn Jocher 2023-08-27 04:19:41 +02:00 committed by GitHub
parent c635418a27
commit deac7575b1
12 changed files with 132 additions and 175 deletions


@@ -30,40 +30,6 @@ class Conv2d_BN(torch.nn.Sequential):
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups,
                            w.size(0),
                            w.shape[2:],
                            stride=self.c.stride,
                            padding=self.c.padding,
                            dilation=self.c.dilation,
                            groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
# NOTE: This module and timm package is needed only for training.
# from ultralytics.utils.checks import check_requirements
# check_requirements('timm')
# from timm.models.layers import DropPath as TimmDropPath
# from timm.models.layers import trunc_normal_


# class DropPath(TimmDropPath):
#
#     def __init__(self, drop_prob=None):
#         super().__init__(drop_prob=drop_prob)
#         self.drop_prob = drop_prob
#
#     def __repr__(self):
#         msg = super().__repr__()
#         msg += f'(drop_prob={self.drop_prob})'
#         return msg
class PatchEmbed(nn.Module):
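
For context on the fuse() method removed in this hunk: it folds the BatchNorm running statistics into the preceding convolution so that a single Conv2d reproduces the Conv2d + BatchNorm pair at inference time, via w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps). The standalone sketch below applies the same folding to a plain nn.Conv2d/nn.BatchNorm2d pair and checks eval-mode equivalence; fuse_conv_bn is an illustrative helper name for this sketch, not part of the Ultralytics API.

import torch
import torch.nn as nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold BatchNorm running statistics into the preceding convolution (eval-time equivalence)."""
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()  # gamma / sqrt(var + eps)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    fused.weight.data.copy_(conv.weight * scale[:, None, None, None])
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data.copy_(bn.bias + (conv_bias - bn.running_mean) * scale)
    return fused


# Quick equivalence check in eval mode (so BN uses its running statistics).
conv, bn = nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8)
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
conv.eval(), bn.eval()
x = torch.randn(2, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5)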
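
The commented-out DropPath wrapper removed above only adds drop_prob to timm's repr; the note that timm is needed only for training holds because stochastic depth is an identity at inference. Below is a minimal stochastic-depth sketch for reference, an approximation of what timm's DropPath does rather than its actual implementation.

import torch
import torch.nn as nn


class DropPath(nn.Module):
    """Stochastic depth: randomly drop the residual branch per sample during training."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x  # identity at inference, which is why the dependency is training-only
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli mask per sample, broadcast over the remaining dims.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob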