ultralytics 8.1.9 replace .size(0) with .shape[0] (#7957)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-02-01 23:33:02 +01:00 committed by GitHub
parent 66b32bb4dd
commit 70d4a3752e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 42 additions and 22 deletions

View file

@ -331,7 +331,7 @@ class ConfusionMatrix:
gt_bboxes (Array[M, 4]): Ground truth bounding boxes with xyxy format.
gt_cls (Array[M]): The class labels.
"""
if gt_cls.size(0) == 0: # Check if labels is empty
if gt_cls.shape[0] == 0: # Check if labels is empty
if detections is not None:
detections = detections[detections[:, 4] > self.conf]
detection_classes = detections[:, 5].int()

View file

@ -56,8 +56,8 @@ class TaskAlignedAssigner(nn.Module):
fg_mask (Tensor): shape(bs, num_total_anchors)
target_gt_idx (Tensor): shape(bs, num_total_anchors)
"""
self.bs = pd_scores.size(0)
self.n_max_boxes = gt_bboxes.size(1)
self.bs = pd_scores.shape[0]
self.n_max_boxes = gt_bboxes.shape[1]
if self.n_max_boxes == 0:
device = gt_bboxes.device

View file

@ -190,7 +190,7 @@ def fuse_conv_and_bn(conv, bn):
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
@ -221,7 +221,7 @@ def fuse_deconv_and_bn(deconv, bn):
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)