Fix DDP when device is a list (#4600)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2023-08-28 19:04:30 +08:00 committed by GitHub
parent 23b4f697c9
commit 53b4f8c713
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 2 deletions

View file

@ -28,7 +28,14 @@ def test_checks():
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_train():
device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1]
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=device) # also test AutoBatch, requires imgsz>=64
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device) # requires imgsz>=64
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_autobatch():
from ultralytics.utils.autobatch import check_train_batch_size
check_train_batch_size(YOLO(MODEL).model.cuda(), imgsz=128, amp=True)
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')