Model coverage cleanup (#4585)

Glenn Jocher 2023-08-27 04:19:41 +02:00 committed by GitHub
parent c635418a27
commit deac7575b1
12 changed files with 132 additions and 175 deletions


@@ -51,8 +51,7 @@ class TransformerEncoderLayer(nn.Module):
         src = self.norm1(src)
         src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
         src = src + self.dropout2(src2)
-        src = self.norm2(src)
-        return src
+        return self.norm2(src)
 
     def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
         src2 = self.norm1(src)
@@ -61,8 +60,7 @@ class TransformerEncoderLayer(nn.Module):
         src = src + self.dropout1(src2)
         src2 = self.norm2(src)
         src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
-        src = src + self.dropout2(src2)
-        return src
+        return src + self.dropout2(src2)
 
     def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
         """Forward propagates the input through the encoder module."""
@@ -116,8 +114,7 @@ class TransformerLayer(nn.Module):
     def forward(self, x):
         """Apply a transformer block to the input x and return the output."""
         x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
-        x = self.fc2(self.fc1(x)) + x
-        return x
+        return self.fc2(self.fc1(x)) + x
 
 
 class TransformerBlock(nn.Module):
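
For readers without the file open, the TransformerLayer forward in this hunk is a residual attention step followed by a residual two-layer MLP. A hedged reconstruction is below; the __init__ attributes are assumptions consistent with the forward shown, not copied from the source file:

    import torch
    import torch.nn as nn

    class MiniTransformerLayer(nn.Module):
        """Sketch of a LayerNorm-free transformer layer matching the forward() in the diff."""

        def __init__(self, c=256, num_heads=8):
            super().__init__()
            # Assumed: separate q/k/v projections feeding a standard multi-head attention module.
            self.q = nn.Linear(c, c, bias=False)
            self.k = nn.Linear(c, c, bias=False)
            self.v = nn.Linear(c, c, bias=False)
            self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
            self.fc1 = nn.Linear(c, c, bias=False)
            self.fc2 = nn.Linear(c, c, bias=False)

        def forward(self, x):
            # Residual attention, then a residual two-layer MLP, returned directly.
            x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
            return self.fc2(self.fc1(x)) + x

    x = torch.randn(49, 2, 256)  # (sequence, batch, channels) for nn.MultiheadAttention
    assert MiniTransformerLayer()(x).shape == x.shape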
@@ -185,8 +182,7 @@ class LayerNorm2d(nn.Module):
         u = x.mean(1, keepdim=True)
         s = (x - u).pow(2).mean(1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.eps)
-        x = self.weight[:, None, None] * x + self.bias[:, None, None]
-        return x
+        return self.weight[:, None, None] * x + self.bias[:, None, None]
 
 
 class MSDeformAttn(nn.Module):
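
LayerNorm2d normalizes over the channel dimension of an NCHW tensor, which is equivalent to permuting to channels-last and applying torch.nn.LayerNorm. A small sketch reproducing the math in this hunk, with an equivalence check (the class name and eps value are assumptions):

    import torch
    import torch.nn as nn

    class ChannelsFirstLayerNorm(nn.Module):
        """Sketch of LayerNorm over dim 1 of an (N, C, H, W) tensor, as in the diff."""

        def __init__(self, num_channels, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
            self.eps = eps

        def forward(self, x):
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            return self.weight[:, None, None] * x + self.bias[:, None, None]

    x = torch.randn(2, 64, 20, 20)
    a = ChannelsFirstLayerNorm(64, eps=1e-6)(x)
    # Same result as nn.LayerNorm on a channels-last view, permuted back to NCHW.
    b = nn.LayerNorm(64, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(a, b, atol=1e-5)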
@@ -271,8 +267,7 @@ class MSDeformAttn(nn.Module):
         else:
             raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
         output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
-        output = self.output_proj(output)
-        return output
+        return self.output_proj(output)
 
 
 class DeformableTransformerDecoderLayer(nn.Module):
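
Only the output projection changed here; multi_scale_deformable_attn_pytorch itself is untouched. As intuition for what that call does, below is a toy single-level illustration of deformable sampling. It is not the library function; names and shapes are simplified assumptions:

    import torch
    import torch.nn.functional as F

    def toy_deformable_sample(feat, points, weights):
        """Sample a feature map at learned points and blend the samples with attention weights.

        feat:    (B, C, H, W) single-level value features
        points:  (B, Q, P, 2) sampling locations, normalized to [0, 1] as (x, y)
        weights: (B, Q, P)    attention weights per sampled point (summing to 1 over P)
        returns: (B, Q, C)
        """
        grid = 2.0 * points - 1.0  # grid_sample expects coordinates in [-1, 1]
        # (B, C, Q, P): bilinear samples of the feature map at every (query, point) location.
        sampled = F.grid_sample(feat, grid, mode='bilinear', align_corners=False)
        # Weight each sampled point, sum over points, and move channels last.
        return (sampled * weights[:, None]).sum(-1).permute(0, 2, 1)

    B, C, H, W, Q, P = 2, 32, 16, 16, 5, 4
    out = toy_deformable_sample(torch.randn(B, C, H, W),
                                torch.rand(B, Q, P, 2),
                                torch.softmax(torch.randn(B, Q, P), dim=-1))
    assert out.shape == (B, Q, C)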
@@ -309,8 +304,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
     def forward_ffn(self, tgt):
         tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
         tgt = tgt + self.dropout4(tgt2)
-        tgt = self.norm3(tgt)
-        return tgt
+        return self.norm3(tgt)
 
     def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
         # self attention
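
forward_ffn is the same post-norm FFN pattern, using the attribute names visible in the hunk. A self-contained sketch with an assumed __init__ (layer sizes, activation, and dropout rate are illustrative):

    import torch
    import torch.nn as nn

    class ToyDecoderFFN(nn.Module):
        """Sketch of the decoder layer's feed-forward sub-block, mirroring forward_ffn."""

        def __init__(self, d_model=256, d_ffn=1024, dropout=0.0):
            super().__init__()
            self.linear1 = nn.Linear(d_model, d_ffn)
            self.linear2 = nn.Linear(d_ffn, d_model)
            self.dropout3 = nn.Dropout(dropout)
            self.dropout4 = nn.Dropout(dropout)
            self.act = nn.ReLU()
            self.norm3 = nn.LayerNorm(d_model)

        def forward_ffn(self, tgt):
            # Expand -> activate -> project back, then residual add and post-norm in the return.
            tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
            return self.norm3(tgt + self.dropout4(tgt2))

    tgt = torch.randn(2, 300, 256)  # (batch, queries, channels)
    assert ToyDecoderFFN().forward_ffn(tgt).shape == tgt.shape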
@@ -327,9 +321,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         embed = self.norm2(embed)
 
         # ffn
-        embed = self.forward_ffn(embed)
-
-        return embed
+        return self.forward_ffn(embed)
 
 
 class DeformableTransformerDecoder(nn.Module):