Model coverage cleanup (#4585)
This commit is contained in:
parent
c635418a27
commit
deac7575b1
12 changed files with 132 additions and 175 deletions
|
|
@ -51,8 +51,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
src = self.norm1(src)
|
||||
src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
return self.norm2(src)
|
||||
|
||||
def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
||||
src2 = self.norm1(src)
|
||||
|
|
@ -61,8 +60,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
return src + self.dropout2(src2)
|
||||
|
||||
def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
||||
"""Forward propagates the input through the encoder module."""
|
||||
|
|
@ -116,8 +114,7 @@ class TransformerLayer(nn.Module):
|
|||
def forward(self, x):
|
||||
"""Apply a transformer block to the input x and return the output."""
|
||||
x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
|
||||
x = self.fc2(self.fc1(x)) + x
|
||||
return x
|
||||
return self.fc2(self.fc1(x)) + x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
|
@ -185,8 +182,7 @@ class LayerNorm2d(nn.Module):
|
|||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
||||
return self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
|
||||
|
||||
class MSDeformAttn(nn.Module):
|
||||
|
|
@ -271,8 +267,7 @@ class MSDeformAttn(nn.Module):
|
|||
else:
|
||||
raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
|
||||
output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
|
||||
output = self.output_proj(output)
|
||||
return output
|
||||
return self.output_proj(output)
|
||||
|
||||
|
||||
class DeformableTransformerDecoderLayer(nn.Module):
|
||||
|
|
@ -309,8 +304,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||
def forward_ffn(self, tgt):
|
||||
tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout4(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
return self.norm3(tgt)
|
||||
|
||||
def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
|
||||
# self attention
|
||||
|
|
@ -327,9 +321,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||
embed = self.norm2(embed)
|
||||
|
||||
# ffn
|
||||
embed = self.forward_ffn(embed)
|
||||
|
||||
return embed
|
||||
return self.forward_ffn(embed)
|
||||
|
||||
|
||||
class DeformableTransformerDecoder(nn.Module):
|
||||
|
|
|
|||
|
|
@ -322,31 +322,10 @@ class PoseModel(DetectionModel):
|
|||
class ClassificationModel(BaseModel):
|
||||
"""YOLOv8 classification model."""
|
||||
|
||||
def __init__(self,
|
||||
cfg='yolov8n-cls.yaml',
|
||||
model=None,
|
||||
ch=3,
|
||||
nc=None,
|
||||
cutoff=10,
|
||||
verbose=True): # YAML, model, channels, number of classes, cutoff index, verbose flag
|
||||
def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
|
||||
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
|
||||
super().__init__()
|
||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
||||
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
if isinstance(model, AutoBackend):
|
||||
model = model.model # unwrap DetectMultiBackend
|
||||
model.model = model.model[:cutoff] # backbone
|
||||
m = model.model[-1] # last layer
|
||||
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
||||
c = Classify(ch, nc) # Classify()
|
||||
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
||||
model.model[-1] = c # replace
|
||||
self.model = model.model
|
||||
self.stride = model.stride
|
||||
self.save = []
|
||||
self.nc = nc
|
||||
self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_yaml(self, cfg, ch, nc, verbose):
|
||||
"""Set YOLOv8 model configurations and define the model architecture."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue