Add utils.ops and nn.modules to tests (#4484)
This commit is contained in:
parent
1cec0185a1
commit
6da8f7f51e
14 changed files with 246 additions and 330 deletions
|
|
@ -9,7 +9,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
||||
__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
||||
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
|
||||
|
||||
|
||||
|
|
@ -54,6 +54,10 @@ class Conv2(Conv):
|
|||
"""Apply convolution, batch normalization and activation to input tensor."""
|
||||
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
||||
|
||||
def forward_fuse(self, x):
|
||||
"""Apply fused convolution, batch normalization and activation to input tensor."""
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
def fuse_convs(self):
|
||||
"""Fuse parallel convolutions."""
|
||||
w = torch.zeros_like(self.conv.weight.data)
|
||||
|
|
@ -61,6 +65,7 @@ class Conv2(Conv):
|
|||
w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
|
||||
self.conv.weight.data += w
|
||||
self.__delattr__('cv2')
|
||||
self.forward = self.forward_fuse
|
||||
|
||||
|
||||
class LightConv(nn.Module):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue