Empty index [0,,1] robust device selection (#16631)
This commit is contained in:
parent
7cc40f6847
commit
5c76601d21
1 changed files with 3 additions and 1 deletions
|
|
@ -170,6 +170,8 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|||
elif device: # non-cpu device requested
|
||||
if device == "cuda":
|
||||
device = "0"
|
||||
if "," in device:
|
||||
device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
|
||||
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
|
||||
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
|
||||
|
|
@ -191,7 +193,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|||
)
|
||||
|
||||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
|
||||
n = len(devices) # device count
|
||||
if n > 1: # multi-GPU
|
||||
if batch < 1:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue