ultralytics 8.1.37 fix empty sys.argv bug (#9390)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Joshua Harrison <108553645+TurbineJoshua@users.noreply.github.com> Co-authored-by: Joshua Harrison <joshua@fido.investments> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
ed2250cf1c
commit
4a7ccba0af
6 changed files with 31 additions and 15 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from unittest import mock
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from ultralytics.cfg import get_cfg
|
from ultralytics.cfg import get_cfg
|
||||||
from ultralytics.engine.exporter import Exporter
|
from ultralytics.engine.exporter import Exporter
|
||||||
|
|
@ -49,8 +51,10 @@ def test_detect():
|
||||||
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
|
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
|
||||||
pred.add_callback("on_predict_start", test_func)
|
pred.add_callback("on_predict_start", test_func)
|
||||||
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
|
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
|
||||||
result = pred(source=ASSETS, model=f"{MODEL}.pt")
|
# Confirm there is no issue with sys.argv being empty.
|
||||||
assert len(result), "predictor test failed"
|
with mock.patch.object(sys, 'argv', []):
|
||||||
|
result = pred(source=ASSETS, model=f"{MODEL}.pt")
|
||||||
|
assert len(result), "predictor test failed"
|
||||||
|
|
||||||
overrides["resume"] = trainer.last
|
overrides["resume"] = trainer.last
|
||||||
trainer = detect.DetectionTrainer(overrides=overrides)
|
trainer = detect.DetectionTrainer(overrides=overrides)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.1.36"
|
__version__ = "8.1.37"
|
||||||
|
|
||||||
from ultralytics.data.explorer.explorer import Explorer
|
from ultralytics.data.explorer.explorer import Explorer
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||||
|
|
|
||||||
|
|
@ -54,8 +54,9 @@ TASK2METRIC = {
|
||||||
"obb": "metrics/mAP50-95(B)",
|
"obb": "metrics/mAP50-95(B)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
|
||||||
CLI_HELP_MSG = f"""
|
CLI_HELP_MSG = f"""
|
||||||
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||||
|
|
||||||
yolo TASK MODE ARGS
|
yolo TASK MODE ARGS
|
||||||
|
|
||||||
|
|
@ -452,7 +453,7 @@ def entrypoint(debug=""):
|
||||||
It uses the package's default cfg and initializes it using the passed overrides.
|
It uses the package's default cfg and initializes it using the passed overrides.
|
||||||
Then it calls the CLI function with the composed cfg
|
Then it calls the CLI function with the composed cfg
|
||||||
"""
|
"""
|
||||||
args = (debug.split(" ") if debug else sys.argv)[1:]
|
args = (debug.split(" ") if debug else ARGV)[1:]
|
||||||
if not args: # no arguments passed
|
if not args: # no arguments passed
|
||||||
LOGGER.info(CLI_HELP_MSG)
|
LOGGER.info(CLI_HELP_MSG)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
@ -11,7 +10,18 @@ import torch
|
||||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
|
from ultralytics.utils import (
|
||||||
|
ARGV,
|
||||||
|
ASSETS,
|
||||||
|
DEFAULT_CFG_DICT,
|
||||||
|
LOGGER,
|
||||||
|
RANK,
|
||||||
|
SETTINGS,
|
||||||
|
callbacks,
|
||||||
|
checks,
|
||||||
|
emojis,
|
||||||
|
yaml_load,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
|
|
@ -421,8 +431,8 @@ class Model(nn.Module):
|
||||||
source = ASSETS
|
source = ASSETS
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||||
|
|
||||||
is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
|
is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
|
||||||
x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
|
x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
|
||||||
)
|
)
|
||||||
|
|
||||||
custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults
|
custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -11,6 +10,7 @@ from pathlib import Path
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.utils import (
|
from ultralytics.utils import (
|
||||||
|
ARGV,
|
||||||
ENVIRONMENT,
|
ENVIRONMENT,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
ONLINE,
|
ONLINE,
|
||||||
|
|
@ -188,7 +188,7 @@ class Events:
|
||||||
self.rate_limit = 60.0 # rate limit (seconds)
|
self.rate_limit = 60.0 # rate limit (seconds)
|
||||||
self.t = 0.0 # rate limit timer (seconds)
|
self.t = 0.0 # rate limit timer (seconds)
|
||||||
self.metadata = {
|
self.metadata = {
|
||||||
"cli": Path(sys.argv[0]).name == "yolo",
|
"cli": Path(ARGV[0]).name == "yolo",
|
||||||
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
|
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
|
||||||
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
|
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ RANK = int(os.getenv("RANK", -1))
|
||||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
|
|
||||||
# Other Constants
|
# Other Constants
|
||||||
|
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
|
||||||
FILE = Path(__file__).resolve()
|
FILE = Path(__file__).resolve()
|
||||||
ROOT = FILE.parents[1] # YOLO
|
ROOT = FILE.parents[1] # YOLO
|
||||||
ASSETS = ROOT / "assets" # default images
|
ASSETS = ROOT / "assets" # default images
|
||||||
|
|
@ -522,7 +523,7 @@ def is_pytest_running():
|
||||||
Returns:
|
Returns:
|
||||||
(bool): True if pytest is running, False otherwise.
|
(bool): True if pytest is running, False otherwise.
|
||||||
"""
|
"""
|
||||||
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)
|
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)
|
||||||
|
|
||||||
|
|
||||||
def is_github_action_running() -> bool:
|
def is_github_action_running() -> bool:
|
||||||
|
|
@ -869,8 +870,8 @@ def set_sentry():
|
||||||
return None # do not send event
|
return None # do not send event
|
||||||
|
|
||||||
event["tags"] = {
|
event["tags"] = {
|
||||||
"sys_argv": sys.argv[0],
|
"sys_argv": ARGV[0],
|
||||||
"sys_argv_name": Path(sys.argv[0]).name,
|
"sys_argv_name": Path(ARGV[0]).name,
|
||||||
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
|
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
|
||||||
"os": ENVIRONMENT,
|
"os": ENVIRONMENT,
|
||||||
}
|
}
|
||||||
|
|
@ -879,7 +880,7 @@ def set_sentry():
|
||||||
if (
|
if (
|
||||||
SETTINGS["sync"]
|
SETTINGS["sync"]
|
||||||
and RANK in (-1, 0)
|
and RANK in (-1, 0)
|
||||||
and Path(sys.argv[0]).name == "yolo"
|
and Path(ARGV[0]).name == "yolo"
|
||||||
and not TESTS_RUNNING
|
and not TESTS_RUNNING
|
||||||
and ONLINE
|
and ONLINE
|
||||||
and is_pip_package()
|
and is_pip_package()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue