From 1747efba721a1d874ffa7c8d8611a21788a5f55c Mon Sep 17 00:00:00 2001
From: AdamP <7806910+adamp87@users.noreply.github.com>
Date: Wed, 11 Dec 2024 21:06:42 +0100
Subject: [PATCH] Fix SAM CUDA hard-code (#18153)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher
---
 ultralytics/models/sam/modules/sam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 5d48ed1f..5d3326a2 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -685,11 +685,11 @@ class SAM2Model(torch.nn.Module):
                 if prev is None:
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
-                # so we load it back to GPU (it's a no-op if it's already on GPU).
-                feats = prev["maskmem_features"].cuda(non_blocking=True)
+                # so we load it back to inference device (it's a no-op if it's already on device).
+                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
-                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
                 maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                 # Temporal positional encoding
                 maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
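
Note (not part of the patch): the change replaces hard-coded .cuda() calls with .to(device=...), so offloaded memory tensors follow whatever device inference actually runs on (CPU, MPS, or a non-default GPU) instead of always being pushed to cuda:0. A minimal sketch of that pattern is below; gather_memory, cached, and current_feats are hypothetical names for illustration and not part of the Ultralytics API, which derives `device` earlier in the patched method.

import torch

def gather_memory(cached: list[torch.Tensor], current_feats: torch.Tensor) -> torch.Tensor:
    # Follow the device of the live features rather than assuming CUDA.
    device = current_feats.device
    moved = []
    for t in cached:
        # .to() is a no-op if the tensor is already on `device`; non_blocking
        # only speeds up pinned CPU -> CUDA copies and is otherwise ignored.
        moved.append(t.to(device=device, non_blocking=True))
    return torch.cat(moved, dim=0)

Keeping device resolution tied to an existing tensor (or a passed-in `device` argument) is what lets the same code path run unchanged on CPU-only machines and Apple MPS, which is the point of the fix.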