Fix SAM CUDA hard-code (#18153)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 8535d36885
commit 1747efba72
1 changed file with 3 additions and 3 deletions
@@ -685,11 +685,11 @@ class SAM2Model(torch.nn.Module):
                 if prev is None:
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
-                # so we load it back to GPU (it's a no-op if it's already on GPU).
-                feats = prev["maskmem_features"].cuda(non_blocking=True)
+                # so we load it back to inference device (it's a no-op if it's already on device).
+                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
-                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
                 maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                 # Temporal positional encoding
                 maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
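For reference, below is a minimal sketch of the device-agnostic pattern this patch adopts: resolve the target device from the model itself and move tensors with Tensor.to() instead of a hard-coded .cuda(). TinyMemoryBank is a hypothetical stand-in used only for illustration, not part of the Ultralytics SAM2Model.

import torch

class TinyMemoryBank(torch.nn.Module):
    """Hypothetical stand-in module showing device-agnostic tensor moves."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, maskmem_features):
        # Resolve the target device from the module's own parameters instead of calling .cuda().
        device = next(self.parameters()).device
        # No-op if the tensor is already on `device`, mirroring the patched lines above.
        feats = maskmem_features.to(device=device, non_blocking=True)
        return self.proj(feats)

model = TinyMemoryBank()        # stays on CPU when no CUDA device is present
out = model(torch.randn(2, 8))  # runs without a GPU, unlike a hard-coded .cuda()

With this pattern, the same inference code path works on CPU, CUDA, or MPS without modification, which is the point of replacing the hard-coded .cuda() calls in the diff above.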