ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-01-10 03:16:08 +01:00 committed by GitHub
parent e795277391
commit fe27db2f6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
139 changed files with 6870 additions and 5125 deletions

View file

@ -10,7 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import uniform_
__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid"
def _get_clones(module, n):
@ -27,7 +27,7 @@ def linear_init_(module):
"""Initialize the weights and biases of a linear module."""
bound = 1 / math.sqrt(module.weight.shape[0])
uniform_(module.weight, -bound, bound)
if hasattr(module, 'bias') and module.bias is not None:
if hasattr(module, "bias") and module.bias is not None:
uniform_(module.bias, -bound, bound)
@ -39,9 +39,12 @@ def inverse_sigmoid(x, eps=1e-5):
return torch.log(x1 / x2)
def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor) -> torch.Tensor:
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
"""
Multi-scale deformable attention.
@ -58,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shape
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
num_levels * num_points)
output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
bs, num_heads * embed_dims, num_queries))
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(bs, num_heads * embed_dims, num_queries)
)
return output.transpose(1, 2).contiguous()