Fix IS_TMP_WRITEABLE order of operations (#16294)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-09-15 21:55:58 +02:00 committed by GitHub
parent fa6362a6f5
commit 2a17462367
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 39 additions and 8 deletions

View file

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import csv
import urllib
from copy import copy
from pathlib import Path
@ -12,7 +13,7 @@ import torch
import yaml
from PIL import Image
from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP
from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP
from ultralytics import RTDETR, YOLO
from ultralytics.cfg import MODELS, TASK2DATA, TASKS
from ultralytics.data.build import load_inference_source
@ -26,11 +27,14 @@ from ultralytics.utils import (
WEIGHTS_DIR,
WINDOWS,
checks,
is_dir_writeable,
is_github_action_running,
)
from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9
IS_TMP_WRITEABLE = is_dir_writeable(TMP) # WARNING: must be run once tests start as TMP does not exist on tests/init
def test_model_forward():
"""Test the forward pass of the YOLO model."""
@ -70,11 +74,37 @@ def test_model_profile():
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_predict_txt():
"""Tests YOLO predictions with file, directory, and pattern sources listed in a text file."""
txt_file = TMP / "sources.txt"
with open(txt_file, "w") as f:
for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]:
f.write(f"{x}\n")
_ = YOLO(MODEL)(source=txt_file, imgsz=32)
file = TMP / "sources_multi_row.txt"
with open(file, "w") as f:
for src in SOURCES_LIST:
f.write(f"{src}\n")
results = YOLO(MODEL)(source=file, imgsz=32)
assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images
@pytest.mark.skipif(True, reason="disabled for testing")
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_predict_csv_multi_row():
"""Tests YOLO predictions with sources listed in multiple rows of a CSV file."""
file = TMP / "sources_multi_row.csv"
with open(file, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["source"])
writer.writerows([[src] for src in SOURCES_LIST])
results = YOLO(MODEL)(source=file, imgsz=32)
assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images
@pytest.mark.skipif(True, reason="disabled for testing")
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_predict_csv_single_row():
"""Tests YOLO predictions with sources listed in a single row of a CSV file."""
file = TMP / "sources_single_row.csv"
with open(file, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(SOURCES_LIST)
results = YOLO(MODEL)(source=file, imgsz=32)
assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images
@pytest.mark.parametrize("model_name", MODELS)