Model coverage cleanup (#4585)
This commit is contained in:
parent c635418a27
commit deac7575b1
12 changed files with 132 additions and 175 deletions
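Every hunk below applies the same mechanical cleanup: a trailing temporary assignment followed by a bare `return` is folded into a single `return` expression. A minimal before/after illustration of the pattern (the `Example` class and its attributes are hypothetical, not taken from the diff):

```python
import torch.nn as nn


class Example(nn.Module):
    """Hypothetical module illustrating the cleanup applied throughout this commit."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward_before(self, x):
        # Old style: assign to a temporary, then return it.
        x = self.proj(x)
        return x

    def forward_after(self, x):
        # New style: return the expression directly.
        return self.proj(x)
```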
@@ -51,8 +51,7 @@ class TransformerEncoderLayer(nn.Module):
         src = self.norm1(src)
         src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
         src = src + self.dropout2(src2)
-        src = self.norm2(src)
-        return src
+        return self.norm2(src)

     def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
         src2 = self.norm1(src)
@@ -61,8 +60,7 @@ class TransformerEncoderLayer(nn.Module):
         src = src + self.dropout1(src2)
         src2 = self.norm2(src)
         src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
-        src = src + self.dropout2(src2)
-        return src
+        return src + self.dropout2(src2)

     def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
         """Forward propagates the input through the encoder module."""
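For context, here is a self-contained sketch of the feed-forward path these two hunks touch, written against plain PyTorch. The attribute names (`fc1`, `fc2`, `norm1`, `norm2`, `dropout`, `dropout2`, `act`) mirror the diff, but the constructor arguments and defaults are assumptions, not taken from the repository:

```python
import torch
import torch.nn as nn


class EncoderFFNSketch(nn.Module):
    """Sketch of the post-norm / pre-norm feed-forward paths shown in the diff.

    Constructor defaults (cm=2048, dropout=0.0, GELU) are assumptions for illustration.
    """

    def __init__(self, c1: int = 256, cm: int = 2048, dropout: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(c1, cm)
        self.fc2 = nn.Linear(cm, c1)
        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.act = nn.GELU()

    def ffn_post(self, src):
        # Post-norm variant: normalize, run the MLP, add the residual, normalize again.
        src = self.norm1(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
        src = src + self.dropout2(src2)
        return self.norm2(src)

    def ffn_pre(self, src):
        # Pre-norm variant: normalize first, run the MLP, then add the residual.
        src2 = self.norm2(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
        return src + self.dropout2(src2)


x = torch.randn(4, 10, 256)  # (batch, tokens, channels)
layer = EncoderFFNSketch()
assert layer.ffn_post(x).shape == layer.ffn_pre(x).shape == x.shape
```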
@@ -116,8 +114,7 @@ class TransformerLayer(nn.Module):
     def forward(self, x):
         """Apply a transformer block to the input x and return the output."""
         x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
-        x = self.fc2(self.fc1(x)) + x
-        return x
+        return self.fc2(self.fc1(x)) + x


 class TransformerBlock(nn.Module):
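A runnable approximation of the `TransformerLayer` forward touched above; the attribute names follow the diff, while the constructor (bias-free linear q/k/v projections feeding `nn.MultiheadAttention`) is an assumption for illustration:

```python
import torch
import torch.nn as nn


class TransformerLayerSketch(nn.Module):
    """Sketch of the attention + MLP residual block from the diff (constructor assumed)."""

    def __init__(self, c: int = 256, num_heads: int = 8):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        # Self-attention with a residual connection, then a two-layer MLP with a residual.
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        return self.fc2(self.fc1(x)) + x


x = torch.randn(100, 2, 256)  # (sequence, batch, channels) for the default nn.MultiheadAttention layout
assert TransformerLayerSketch()(x).shape == x.shape
```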
@@ -185,8 +182,7 @@ class LayerNorm2d(nn.Module):
         u = x.mean(1, keepdim=True)
         s = (x - u).pow(2).mean(1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.eps)
-        x = self.weight[:, None, None] * x + self.bias[:, None, None]
-        return x
+        return self.weight[:, None, None] * x + self.bias[:, None, None]


 class MSDeformAttn(nn.Module):
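The `LayerNorm2d` forward in this hunk normalizes over the channel dimension of an NCHW tensor. A self-contained sketch with the same math; the parameter initialization and the `eps` default are assumptions:

```python
import torch
import torch.nn as nn


class LayerNorm2dSketch(nn.Module):
    """Channel-wise LayerNorm for NCHW tensors, mirroring the forward shown in the diff."""

    def __init__(self, num_channels: int, eps: float = 1e-6):  # init and eps assumed
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        # Mean and variance over the channel axis, then a per-channel affine transform.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


x = torch.randn(2, 64, 20, 20)  # (batch, channels, height, width)
assert LayerNorm2dSketch(64)(x).shape == x.shape
```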
@@ -271,8 +267,7 @@ class MSDeformAttn(nn.Module):
         else:
             raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
         output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
-        output = self.output_proj(output)
-        return output
+        return self.output_proj(output)


 class DeformableTransformerDecoderLayer(nn.Module):
@@ -309,8 +304,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
     def forward_ffn(self, tgt):
         tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
         tgt = tgt + self.dropout4(tgt2)
-        tgt = self.norm3(tgt)
-        return tgt
+        return self.norm3(tgt)

     def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
         # self attention
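The decoder layer's `forward_ffn` collapsed here is a standard linear–activation–dropout–linear block with a residual connection and a trailing norm. A minimal sketch; the dimensions, dropout rate, and activation are assumptions:

```python
import torch
import torch.nn as nn


class DecoderFFNSketch(nn.Module):
    """Sketch of the forward_ffn path from the diff (hyperparameters assumed)."""

    def __init__(self, d_model: int = 256, d_ffn: int = 1024, dropout: float = 0.0):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.act = nn.ReLU()

    def forward_ffn(self, tgt):
        # Expand -> activate -> drop -> project back, add the residual, then normalize.
        tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        return self.norm3(tgt)


tgt = torch.randn(2, 300, 256)  # (batch, queries, channels)
assert DecoderFFNSketch().forward_ffn(tgt).shape == tgt.shape
```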
@@ -327,9 +321,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         embed = self.norm2(embed)

         # ffn
-        embed = self.forward_ffn(embed)
-
-        return embed
+        return self.forward_ffn(embed)


 class DeformableTransformerDecoder(nn.Module):