ultralytics 8.0.170 apply is_list fixes for torch.Tensor inputs (#4713)

Co-authored-by: Gezhi Zhang <765724965@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-09-04 02:22:05 +02:00 committed by GitHub
parent a1c1d6b483
commit aa9133bb88
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 74 additions and 36 deletions

View file

@ -771,6 +771,19 @@ def masks2segments(masks, strategy='largest'):
return segments
def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
"""
Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
Args:
batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
Returns:
(np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
"""
return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
def clean_str(s):
"""
Cleans a string by replacing special characters with underscore _