ultralytics 8.0.170 apply is_list fixes for torch.Tensor inputs (#4713)
Co-authored-by: Gezhi Zhang <765724965@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a1c1d6b483
commit
aa9133bb88
15 changed files with 74 additions and 36 deletions
|
|
@ -771,6 +771,19 @@ def masks2segments(masks, strategy='largest'):
|
|||
return segments
|
||||
|
||||
|
||||
def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
|
||||
|
||||
Args:
|
||||
batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
|
||||
"""
|
||||
return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
|
||||
def clean_str(s):
|
||||
"""
|
||||
Cleans a string by replacing special characters with underscore _
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue