ultralytics 8.1.37 fix empty sys.argv bug (#9390)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Joshua Harrison <108553645+TurbineJoshua@users.noreply.github.com>
Co-authored-by: Joshua Harrison <joshua@fido.investments>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-03-29 03:20:39 +01:00 committed by GitHub
parent ed2250cf1c
commit 4a7ccba0af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 31 additions and 15 deletions

View file

@ -1,5 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import sys
from unittest import mock
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.cfg import get_cfg from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter from ultralytics.engine.exporter import Exporter
@ -49,8 +51,10 @@ def test_detect():
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
pred.add_callback("on_predict_start", test_func) pred.add_callback("on_predict_start", test_func)
assert test_func in pred.callbacks["on_predict_start"], "callback test failed" assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
result = pred(source=ASSETS, model=f"{MODEL}.pt") # Confirm there is no issue with sys.argv being empty.
assert len(result), "predictor test failed" with mock.patch.object(sys, 'argv', []):
result = pred(source=ASSETS, model=f"{MODEL}.pt")
assert len(result), "predictor test failed"
overrides["resume"] = trainer.last overrides["resume"] = trainer.last
trainer = detect.DetectionTrainer(overrides=overrides) trainer = detect.DetectionTrainer(overrides=overrides)

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.36" __version__ = "8.1.37"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

View file

@ -54,8 +54,9 @@ TASK2METRIC = {
"obb": "metrics/mAP50-95(B)", "obb": "metrics/mAP50-95(B)",
} }
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
CLI_HELP_MSG = f""" CLI_HELP_MSG = f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax: Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS yolo TASK MODE ARGS
@ -452,7 +453,7 @@ def entrypoint(debug=""):
It uses the package's default cfg and initializes it using the passed overrides. It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg Then it calls the CLI function with the composed cfg
""" """
args = (debug.split(" ") if debug else sys.argv)[1:] args = (debug.split(" ") if debug else ARGV)[1:]
if not args: # no arguments passed if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG) LOGGER.info(CLI_HELP_MSG)
return return

View file

@ -1,7 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import inspect import inspect
import sys
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@ -11,7 +10,18 @@ import torch
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load from ultralytics.utils import (
ARGV,
ASSETS,
DEFAULT_CFG_DICT,
LOGGER,
RANK,
SETTINGS,
callbacks,
checks,
emojis,
yaml_load,
)
class Model(nn.Module): class Model(nn.Module):
@ -421,8 +431,8 @@ class Model(nn.Module):
source = ASSETS source = ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any( is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track") x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
) )
custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults

View file

@ -3,7 +3,6 @@
import os import os
import platform import platform
import random import random
import sys
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
@ -11,6 +10,7 @@ from pathlib import Path
import requests import requests
from ultralytics.utils import ( from ultralytics.utils import (
ARGV,
ENVIRONMENT, ENVIRONMENT,
LOGGER, LOGGER,
ONLINE, ONLINE,
@ -188,7 +188,7 @@ class Events:
self.rate_limit = 60.0 # rate limit (seconds) self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds) self.t = 0.0 # rate limit timer (seconds)
self.metadata = { self.metadata = {
"cli": Path(sys.argv[0]).name == "yolo", "cli": Path(ARGV[0]).name == "yolo",
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10 "python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
"version": __version__, "version": __version__,

View file

@ -30,6 +30,7 @@ RANK = int(os.getenv("RANK", -1))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
# Other Constants # Other Constants
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO ROOT = FILE.parents[1] # YOLO
ASSETS = ROOT / "assets" # default images ASSETS = ROOT / "assets" # default images
@ -522,7 +523,7 @@ def is_pytest_running():
Returns: Returns:
(bool): True if pytest is running, False otherwise. (bool): True if pytest is running, False otherwise.
""" """
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem) return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)
def is_github_action_running() -> bool: def is_github_action_running() -> bool:
@ -869,8 +870,8 @@ def set_sentry():
return None # do not send event return None # do not send event
event["tags"] = { event["tags"] = {
"sys_argv": sys.argv[0], "sys_argv": ARGV[0],
"sys_argv_name": Path(sys.argv[0]).name, "sys_argv_name": Path(ARGV[0]).name,
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"os": ENVIRONMENT, "os": ENVIRONMENT,
} }
@ -879,7 +880,7 @@ def set_sentry():
if ( if (
SETTINGS["sync"] SETTINGS["sync"]
and RANK in (-1, 0) and RANK in (-1, 0)
and Path(sys.argv[0]).name == "yolo" and Path(ARGV[0]).name == "yolo"
and not TESTS_RUNNING and not TESTS_RUNNING
and ONLINE and ONLINE
and is_pip_package() and is_pip_package()