Add Conv2() module (#2820)

This commit is contained in:
Glenn Jocher 2023-05-25 18:12:50 +02:00 committed by GitHub
parent d19c5b6ce8
commit 441e67d330
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 92 additions and 45 deletions

View file

@ -43,6 +43,27 @@ class Conv(nn.Module):
return self.act(self.conv(x))
class Conv2(Conv):
"""Simplified RepConv module with Conv fusing."""
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x) + self.cv2(x)))
def fuse_convs(self):
"""Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data)
i = [x // 2 for x in w.shape[2:]]
w[:, :, i[0] - 1:i[0], i[1] - 1:i[1]] = self.cv2.weight.data.clone()
self.conv.weight.data += w
self.__delattr__('cv2')
class LightConv(nn.Module):
"""Light convolution with args(ch_in, ch_out, kernel).
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py