Improve tests coverage and speed (#4340)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-13 22:24:01 +02:00 committed by GitHub
parent d704507217
commit 9f6d48d3cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 183 additions and 347 deletions

View file

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils.tal import dist2bbox, make_anchors
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors
from .block import DFL, Proto
from .conv import Conv
@ -267,9 +267,9 @@ class RTDETRDecoder(nn.Module):
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
anchors = []
for i, (h, w) in enumerate(shapes):
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
torch.arange(end=w, dtype=dtype, device=device),
indexing='ij')
sy = torch.arange(end=h, dtype=dtype, device=device)
sx = torch.arange(end=w, dtype=dtype, device=device)
grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
valid_WH = torch.tensor([h, w], dtype=dtype, device=device)