Add Conv2() module (#2820)
This commit is contained in:
parent
d19c5b6ce8
commit
441e67d330
13 changed files with 92 additions and 45 deletions
|
|
@ -8,9 +8,9 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
|
||||
Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
|
||||
GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder,
|
||||
Segment)
|
||||
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
||||
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
||||
RTDETRDecoder, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
|
|
@ -103,7 +103,9 @@ class BaseModel(nn.Module):
|
|||
"""
|
||||
if not self.is_fused():
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
||||
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
|
||||
if isinstance(m, Conv2):
|
||||
m.fuse_convs()
|
||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||
delattr(m, 'bn') # remove batchnorm
|
||||
m.forward = m.forward_fuse # update forward
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue