From 51026a9a4a0b5e1ea8a1c76b16f8de12b9a71530 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Tue, 24 Dec 2024 16:26:56 +0500 Subject: [PATCH] `ultralytics 8.3.54` New Streamlit inference Solution (#18316) Signed-off-by: Glenn Jocher Signed-off-by: UltralyticsAssistant Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- README.md | 2 +- docs/README.md | 4 +- docs/en/guides/streamlit-live-inference.md | 25 +- docs/en/reference/cfg/__init__.md | 4 - .../solutions/streamlit_inference.md | 2 +- examples/RTDETR-ONNXRuntime-Python/main.py | 23 +- tests/test_solutions.py | 2 +- ultralytics/__init__.py | 2 +- ultralytics/cfg/__init__.py | 111 ++++---- ultralytics/solutions/__init__.py | 4 +- ultralytics/solutions/parking_management.py | 4 +- ultralytics/solutions/security_alarm.py | 3 +- ultralytics/solutions/streamlit_inference.py | 253 ++++++++++-------- 13 files changed, 251 insertions(+), 188 deletions(-) diff --git a/README.md b/README.md index b03e95d0..9807bcab 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ See below for a quickstart install and usage examples, and see our [Docs](https:
 <details open>
 <summary>Install</summary>
 
-Pip install the ultralytics package including all [requirements](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) in a [**Python>=3.8**](https://www.python.org/) environment with [**PyTorch>=1.8**](https://pytorch.org/get-started/locally/).
+Pip install the Ultralytics package including all [requirements](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) in a [**Python>=3.8**](https://www.python.org/) environment with [**PyTorch>=1.8**](https://pytorch.org/get-started/locally/).
 
 [![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Ultralytics Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)
 
diff --git a/docs/README.md b/docs/README.md
index 802352b5..b4eaffcc 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -18,7 +18,7 @@
 [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)
 
-To install the ultralytics package in developer mode, ensure you have Git and Python 3 installed on your system. Then, follow these steps:
+To install the Ultralytics package in developer mode, ensure you have Git and Python 3 installed on your system. Then, follow these steps:
 
 1. Clone the ultralytics repository to your local machine using Git:
 
@@ -38,7 +38,7 @@ To install the ultralytics package in developer mode, ensure you have Git and Py
    pip install -e '.[dev]'
    ```
 
-- This command installs the ultralytics package along with all development dependencies, allowing you to modify the package code and have the changes immediately reflected in your Python environment.
+- This command installs the Ultralytics package along with all development dependencies, allowing you to modify the package code and have the changes immediately reflected in your Python environment.
 
 ## 🚀 Building and Serving Locally
 
diff --git a/docs/en/guides/streamlit-live-inference.md b/docs/en/guides/streamlit-live-inference.md
index 5a822b4c..708e1b20 100644
--- a/docs/en/guides/streamlit-live-inference.md
+++ b/docs/en/guides/streamlit-live-inference.md
@@ -43,7 +43,9 @@ Streamlit makes it simple to build and deploy interactive web applications. Comb
 === "CLI"
 
     ```bash
-    yolo streamlit-predict
+    yolo solutions inference
+
+    yolo solutions inference model="path/to/model/file.pt"
     ```
 
 === "Python"
@@ -51,7 +53,11 @@ Streamlit makes it simple to build and deploy interactive web applications. Comb
     ```python
     from ultralytics import solutions
 
-    solutions.inference()
+    inf = solutions.Inference(
+        model="yolo11n.pt",  # You can use any model that Ultralytics supports, e.g. YOLO11 or a custom-trained model
+    )
+
+    inf.inference()
 
     ### Make sure to run the file using command `streamlit run <file-name.py>`
     ```
 
@@ -67,8 +73,11 @@ You can optionally supply a specific model in Python:
 
     ```python
     from ultralytics import solutions
 
-    # Pass a model as an argument
-    solutions.inference(model="path/to/model.pt")
+    inf = solutions.Inference(
+        model="yolo11n.pt",  # You can use any model that Ultralytics supports, e.g. YOLO11 or YOLOv10
+    )
+
+    inf.inference()
 
     ### Make sure to run the file using command `streamlit run <file-name.py>`
     ```
 
@@ -111,7 +120,11 @@ Then, you can create a basic Streamlit application to run live inference:
     ```python
     from ultralytics import solutions
 
-    solutions.inference()
+    inf = solutions.Inference(
+        model="yolo11n.pt",  # You can use any model that Ultralytics supports, e.g. YOLO11 or YOLOv10
+    )
+
+    inf.inference()
 
     ### Make sure to run the file using command `streamlit run <file-name.py>`
     ```
 
@@ -119,7 +132,7 @@ Then, you can create a basic Streamlit application to run live inference:
 === "CLI"
 
     ```bash
-    yolo streamlit-predict
+    yolo solutions inference
     ```
 
 For more details on the practical setup, refer to the [Streamlit Application Code section](#streamlit-application-code) of the documentation.
diff --git a/docs/en/reference/cfg/__init__.md b/docs/en/reference/cfg/__init__.md
index 92320b12..6a59b4c4 100644
--- a/docs/en/reference/cfg/__init__.md
+++ b/docs/en/reference/cfg/__init__.md
@@ -51,10 +51,6 @@ keywords: Ultralytics, YOLO, configuration, cfg2dict, get_cfg, check_cfg, save_d
 
 <br><br><hr><br>
 
-## ::: ultralytics.cfg.handle_streamlit_inference
-
-<br><br><hr><br>
-
 ## ::: ultralytics.cfg.parse_key_value_pair
 
 <br><br><hr><br>
 
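[Editor's note] The reference entry above disappears because this patch folds the Streamlit launcher into the generic solutions dispatcher: `inference` becomes just another `SOLUTION_MAP` entry that `handle_yolo_solutions` resolves by name. A minimal, self-contained sketch of that dispatch pattern follows; `DemoCounter` is a hypothetical stand-in for a real solution class such as `ObjectCounter` (the real code resolves classes with `getattr(solutions, cls)` on the `ultralytics.solutions` module):

```python
# Sketch of the (class name, method name) dispatch behind `yolo solutions <name>`.
# DemoCounter is a hypothetical stand-in, not part of the Ultralytics package.

SOLUTION_MAP = {
    "count": ("DemoCounter", "count"),
    "inference": ("Inference", "inference"),  # the entry this patch adds
}


class DemoCounter:
    """Stand-in with the (class, method) shape that SOLUTION_MAP expects."""

    def __init__(self, **overrides):
        self.overrides = overrides  # CLI key=value pairs end up here

    def count(self):
        return f"counting with overrides={self.overrides}"


def dispatch(name: str, **overrides):
    """Resolve a solution name to its class and entry method, then invoke it."""
    cls_name, method_name = SOLUTION_MAP[name]
    solution = globals()[cls_name](**overrides)  # real code: getattr(solutions, cls_name)
    return getattr(solution, method_name)()


print(dispatch("count", conf=0.25))  # -> counting with overrides={'conf': 0.25}
```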
diff --git a/docs/en/reference/solutions/streamlit_inference.md b/docs/en/reference/solutions/streamlit_inference.md index 368d69e3..92aac750 100644 --- a/docs/en/reference/solutions/streamlit_inference.md +++ b/docs/en/reference/solutions/streamlit_inference.md @@ -11,6 +11,6 @@ keywords: Ultralytics, YOLOv8, live inference, real-time object detection, Strea
 
 <br><br><hr><br>
 
-## ::: ultralytics.solutions.streamlit_inference.inference
+## ::: ultralytics.solutions.streamlit_inference.Inference
 
 <br><br><hr><br>
 
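[Editor's note] For downstream users, the practical effect of the reference-page change above is an API rename: the module-level `solutions.inference()` helper becomes the `solutions.Inference` class (the patched test below uses exactly `solutions.Inference().inference()`). A minimal migration sketch, to be saved as e.g. `app.py` and launched with `streamlit run app.py` as the updated docs note:

```python
# Migration sketch for the API rename in ultralytics 8.3.54.
from ultralytics import solutions

# Before this patch (removed):
# solutions.inference(model="yolo11n.pt")

# After this patch:
inf = solutions.Inference(model="yolo11n.pt")  # the model argument is optional
inf.inference()
```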
diff --git a/examples/RTDETR-ONNXRuntime-Python/main.py b/examples/RTDETR-ONNXRuntime-Python/main.py
index f25ad608..e8d41df0 100644
--- a/examples/RTDETR-ONNXRuntime-Python/main.py
+++ b/examples/RTDETR-ONNXRuntime-Python/main.py
@@ -8,13 +8,14 @@ import torch
 from ultralytics.utils import ASSETS, yaml_load
 from ultralytics.utils.checks import check_requirements, check_yaml
 
+
 class RTDETR:
     """RTDETR object detection model class for handling inference and visualization."""
-    
+
     def __init__(self, model_path, img_path, conf_thres=0.5, iou_thres=0.5):
         """
         Initializes the RTDETR object with the specified parameters.
-        
+
         Args:
             model_path: Path to the ONNX model file.
             img_path: Path to the input image.
@@ -71,11 +72,17 @@ class RTDETR:
 
         # Draw a filled rectangle as the background for the label text
         cv2.rectangle(
-            self.img, (int(label_x), int(label_y - label_height)), (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED
+            self.img,
+            (int(label_x), int(label_y - label_height)),
+            (int(label_x + label_width), int(label_y + label_height)),
+            color,
+            cv2.FILLED,
         )
 
         # Draw the label text on the image
-        cv2.putText(self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
+        cv2.putText(
+            self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA
+        )
 
     def preprocess(self):
         """
@@ -110,8 +117,7 @@ class RTDETR:
 
     def bbox_cxcywh_to_xyxy(self, boxes):
         """
-        Converts bounding boxes from (center x, center y, width, height) format
-        to (x_min, y_min, x_max, y_max) format.
+        Converts bounding boxes from (center x, center y, width, height) format to (x_min, y_min, x_max, y_max) format.
 
         Args:
             boxes (numpy.ndarray): An array of shape (N, 4) where each row represents
@@ -176,7 +182,7 @@ class RTDETR:
     def main(self):
         """
         Executes the detection on the input image using the ONNX model.
-        
+
         Returns:
             np.array: Output image with annotations.
         """
@@ -189,6 +195,7 @@ class RTDETR:
         # Process and return the model output
         return self.postprocess(model_output)
 
+
 if __name__ == "__main__":
     # Set up argument parser for command-line arguments
     parser = argparse.ArgumentParser()
@@ -210,4 +217,4 @@ if __name__ == "__main__":
     # Display the annotated output image
     cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
     cv2.imshow("Output", output_image)
-    cv2.waitKey(0)
\ No newline at end of file
+    cv2.waitKey(0)
diff --git a/tests/test_solutions.py b/tests/test_solutions.py
index fbf6b954..730ba95d 100644
--- a/tests/test_solutions.py
+++ b/tests/test_solutions.py
@@ -82,4 +82,4 @@ def test_instance_segmentation():
 @pytest.mark.slow
 def test_streamlit_predict():
     """Test streamlit predict live inference solution."""
-    solutions.inference()
+    solutions.Inference().inference()
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 824a1aa7..ad0326ec 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.53"
+__version__ = "8.3.54"
 
 import os
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index b36418fb..ca35aff0 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -42,6 +42,7 @@ SOLUTION_MAP = {
     "workout": ("AIGym", "monitor"),
     "analytics": ("Analytics", "process_data"),
     "trackzone": ("TrackZone", "trackzone"),
+    "inference": ("Inference", "inference"),
     "help": None,
 }
 
@@ -97,7 +98,10 @@ SOLUTIONS_HELP_MSG = f"""
         yolo solutions analytics analytics_type="pie"
 
     6. Track objects within specific zones
-        yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)]
+        yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)]
+
+    7. Streamlit real-time webcam inference GUI
+        yolo solutions inference
     """
 
 CLI_HELP_MSG = f"""
     Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
@@ -121,13 +125,10 @@ CLI_HELP_MSG = f"""
     4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
         yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128
 
-    5. Streamlit real-time webcam inference GUI
-        yolo streamlit-predict
-
-    6. Ultralytics solutions usage
+    5. Ultralytics solutions usage
         yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video/file.mp4"
 
-    7. Run special commands:
+    6. Run special commands:
         yolo help
         yolo checks
         yolo version
@@ -636,6 +637,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
         Run analytics with custom configuration:
         >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])
 
+        Run inference with custom configuration (requires Streamlit>=1.29.0):
+        >>> handle_yolo_solutions(["inference", "model=yolo11n.pt"])
+
     Notes:
         - Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT
         - Arguments can be provided in the format 'key=value' or as boolean flags
@@ -645,7 +649,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
         - For 'analytics' solution, frame numbers are tracked for generating analytical graphs
         - Video processing can be interrupted by pressing 'q'
         - Processes video frames sequentially and saves output in .avi format
         - If no source is specified, downloads and uses a default sample video
+        - The 'inference' solution is launched in a subprocess via the 'streamlit run' command
+        - The Streamlit app file is located in the Ultralytics package directory
""" full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} # arguments dictionary overrides = {} @@ -678,60 +684,56 @@ def handle_yolo_solutions(args: List[str]) -> None: if args and args[0] == "help": # Add check for return if user call `yolo solutions help` return - cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source + if s_n == "inference": + checks.check_requirements("streamlit>=1.29.0") + LOGGER.info("💡 Loading Ultralytics live inference app...") + subprocess.run( + [ # Run subprocess with Streamlit custom argument + "streamlit", + "run", + str(ROOT / "solutions/streamlit_inference.py"), + "--server.headless", + "true", + overrides["model"], + ] + ) + else: + cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source - from ultralytics import solutions # import ultralytics solutions + from ultralytics import solutions # import ultralytics solutions - solution = getattr(solutions, cls)(IS_CLI=True, **overrides) # get solution class i.e ObjectCounter - process = getattr(solution, method) # get specific function of class for processing i.e, count from ObjectCounter + solution = getattr(solutions, cls)(IS_CLI=True, **overrides) # get solution class i.e ObjectCounter + process = getattr( + solution, method + ) # get specific function of class for processing i.e, count from ObjectCounter - cap = cv2.VideoCapture(solution.CFG["source"]) # read the video file + cap = cv2.VideoCapture(solution.CFG["source"]) # read the video file - # extract width, height and fps of the video file, create save directory and initialize video writer - import os # for directory creation - from pathlib import Path + # extract width, height and fps of the video file, create save directory and initialize video writer + import os # for directory creation + from pathlib import Path - from ultralytics.utils.files import increment_path # for output directory path update + from ultralytics.utils.files import increment_path # for output directory path update - w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) - if s_n == "analytics": # analytical graphs follow fixed shape for output i.e w=1920, h=1080 - w, h = 1920, 1080 - save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False) - save_dir.mkdir(parents=True, exist_ok=True) # create the output directory - vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + if s_n == "analytics": # analytical graphs follow fixed shape for output i.e w=1920, h=1080 + w, h = 1920, 1080 + save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False) + save_dir.mkdir(parents=True, exist_ok=True) # create the output directory + vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) - try: # Process video frames - f_n = 0 # frame number, required for analytical graphs - while cap.isOpened(): - success, frame = cap.read() - if not success: - break - frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame) - vw.write(frame) - if cv2.waitKey(1) & 0xFF == ord("q"): - break - finally: - cap.release() - - -def handle_streamlit_inference(): - """ - Open the Ultralytics Live Inference Streamlit app for real-time object detection. 
-
-    This function initializes and runs a Streamlit application designed for performing live object detection using
-    Ultralytics models. It checks for the required Streamlit package and launches the app.
-
-    Examples:
-        >>> handle_streamlit_inference()
-
-    Notes:
-        - Requires Streamlit version 1.29.0 or higher.
-        - The app is launched using the 'streamlit run' command.
-        - The Streamlit app file is located in the Ultralytics package directory.
-    """
-    checks.check_requirements("streamlit>=1.29.0")
-    LOGGER.info("💡 Loading Ultralytics Live Inference app...")
-    subprocess.run(["streamlit", "run", ROOT / "solutions/streamlit_inference.py", "--server.headless", "true"])
+        try:  # Process video frames
+            f_n = 0  # frame number, required for analytical graphs
+            while cap.isOpened():
+                success, frame = cap.read()
+                if not success:
+                    break
+                frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
+                vw.write(frame)
+                if cv2.waitKey(1) & 0xFF == ord("q"):
+                    break
+        finally:
+            cap.release()
 
 
 def parse_key_value_pair(pair: str = "key=value"):
@@ -853,7 +855,6 @@ def entrypoint(debug=""):
         "login": lambda: handle_yolo_hub(args),
         "logout": lambda: handle_yolo_hub(args),
         "copy-cfg": copy_default_cfg,
-        "streamlit-predict": lambda: handle_streamlit_inference(),
         "solutions": lambda: handle_yolo_solutions(args[1:]),
     }
     full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
diff --git a/ultralytics/solutions/__init__.py b/ultralytics/solutions/__init__.py
index 3e36d167..25b7d8fa 100644
--- a/ultralytics/solutions/__init__.py
+++ b/ultralytics/solutions/__init__.py
@@ -10,7 +10,7 @@ from .queue_management import QueueManager
 from .region_counter import RegionCounter
 from .security_alarm import SecurityAlarm
 from .speed_estimation import SpeedEstimator
-from .streamlit_inference import inference
+from .streamlit_inference import Inference
 from .trackzone import TrackZone
 
 __all__ = (
@@ -23,7 +23,7 @@ __all__ = (
     "QueueManager",
     "SpeedEstimator",
     "Analytics",
-    "inference",
+    "Inference",
     "RegionCounter",
     "TrackZone",
     "SecurityAlarm",
diff --git a/ultralytics/solutions/parking_management.py b/ultralytics/solutions/parking_management.py
index 8b5d4922..5a5c1bdb 100644
--- a/ultralytics/solutions/parking_management.py
+++ b/ultralytics/solutions/parking_management.py
@@ -5,7 +5,9 @@ import json
 import cv2
 import numpy as np
 
-from ultralytics.solutions.solutions import LOGGER, BaseSolution, check_requirements
+from ultralytics.solutions.solutions import BaseSolution
+from ultralytics.utils import LOGGER
+from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.plotting import Annotator
 
diff --git a/ultralytics/solutions/security_alarm.py b/ultralytics/solutions/security_alarm.py
index 7e014e9f..534cfb9c 100644
--- a/ultralytics/solutions/security_alarm.py
+++ b/ultralytics/solutions/security_alarm.py
@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from ultralytics.solutions.solutions import LOGGER, BaseSolution
+from ultralytics.solutions.solutions import BaseSolution
+from ultralytics.utils import LOGGER
 from ultralytics.utils.plotting import Annotator, colors
 
diff --git a/ultralytics/solutions/streamlit_inference.py b/ultralytics/solutions/streamlit_inference.py
index dcae3add..cf09269c 100644
--- a/ultralytics/solutions/streamlit_inference.py
+++ b/ultralytics/solutions/streamlit_inference.py
@@ -4,145 +4,188 @@ import io
 import time
 
 import cv2
-import torch
 
+from ultralytics import YOLO
+from ultralytics.utils import LOGGER
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
 
 
-def inference(model=None):
-    """Performs real-time object detection on video input using YOLO in a Streamlit web application."""
-    check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
-    import streamlit as st
+class Inference:
+    """
+    A class to perform object detection, image classification, image segmentation and pose estimation inference
+    using Streamlit and Ultralytics YOLO models. It provides functionalities such as loading models, configuring
+    settings, uploading video files, and performing real-time inference.
 
-    from ultralytics import YOLO
+    Attributes:
+        st (module): Streamlit module for UI creation.
+        temp_dict (dict): Temporary dictionary to store the model path.
+        model_path (str): Path to the loaded model.
+        model (YOLO): The YOLO model instance.
+        source (str): Selected video source.
+        enable_trk (str): Enable tracking option.
+        conf (float): Confidence threshold.
+        iou (float): IoU threshold for non-max suppression.
+        vid_file_name (str): Name of the uploaded video file.
+        selected_ind (list): List of selected class indices.
 
-    # Hide main menu style
-    menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""
+    Methods:
+        web_ui: Sets up the Streamlit web interface with custom HTML elements.
+        sidebar: Configures the Streamlit sidebar for model and inference settings.
+        source_upload: Handles video file uploads through the Streamlit interface.
+        configure: Configures the model and loads selected classes for inference.
+        inference: Performs real-time object detection inference.
 
-    # Main title of streamlit application
-    main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px;
-                             font-family: 'Archivo', sans-serif; margin-top:-50px;margin-bottom:20px;">
-                    Ultralytics YOLO Streamlit Application
-                    </h1></div>"""
+    Examples:
+        >>> inf = solutions.Inference(model="path/to/model/file.pt")  # the model argument is optional
+        >>> inf.inference()
+    """
 
-    # Subtitle of streamlit application
-    sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center;
-                    font-family: 'Archivo', sans-serif; margin-top:-15px; margin-bottom:50px;">
-                    Experience real-time object detection on your webcam with the power of Ultralytics YOLO! 🚀</h4>
-                    </div>"""
+    def __init__(self, **kwargs):
+        """
+        Initializes the Inference class, checking Streamlit requirements and setting up the model path.
 
-    # Set html page configuration
-    st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
+        Args:
+            **kwargs (Any): Additional keyword arguments for model configuration.
+        """
+        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
+        import streamlit as st
 
-    # Append the custom HTML
-    st.markdown(menu_style_cfg, unsafe_allow_html=True)
-    st.markdown(main_title_cfg, unsafe_allow_html=True)
-    st.markdown(sub_title_cfg, unsafe_allow_html=True)
+        self.st = st
 
-    # Add ultralytics logo in sidebar
-    with st.sidebar:
-        logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
-        st.image(logo, width=250)
+        self.temp_dict = {"model": None}  # Temporary dict to store the model path
+        self.temp_dict.update(kwargs)
 
-    # Add elements to vertical setting menu
-    st.sidebar.title("User Configuration")
+        self.model_path = None  # Store model file name with path
+        if self.temp_dict["model"] is not None:
+            self.model_path = self.temp_dict["model"]
 
-    # Add video source selection dropdown
-    source = st.sidebar.selectbox(
-        "Video",
-        ("webcam", "video"),
-    )
+        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
 
-    vid_file_name = ""
-    if source == "video":
-        vid_file = st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
-        if vid_file is not None:
-            g = io.BytesIO(vid_file.read())  # BytesIO Object
-            vid_location = "ultralytics.mp4"
-            with open(vid_location, "wb") as out:  # Open temporary file as bytes
-                out.write(g.read())  # Read bytes into file
-            vid_file_name = "ultralytics.mp4"
-    elif source == "webcam":
-        vid_file_name = 0
+    def web_ui(self):
+        """Sets up the Streamlit web interface with custom HTML elements."""
+        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style
 
-    # Add dropdown menu for model selection
-    available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
-    if model:
-        available_models.insert(0, model.split(".pt")[0])  # insert model without suffix as *.pt is added later
+        # Main title of streamlit application
+        main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
+        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
 
-    selected_model = st.sidebar.selectbox("Model", available_models)
-    with st.spinner("Model is downloading..."):
-        model = YOLO(f"{selected_model.lower()}.pt")  # Load the YOLO model
-        class_names = list(model.names.values())  # Convert dictionary to list of class names
-    st.success("Model loaded successfully!")
+        # Subtitle of streamlit application
+        sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
+        of Ultralytics YOLO! 🚀</h4></div>"""
 
-    # Multiselect box with class names and get indices of selected classes
-    selected_classes = st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
-    selected_ind = [class_names.index(option) for option in selected_classes]
+        # Set html page configuration and append custom HTML
+        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
+        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
+        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
+        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
 
-    if not isinstance(selected_ind, list):  # Ensure selected_options is a list
-        selected_ind = list(selected_ind)
+    def sidebar(self):
+        """Configures the Streamlit sidebar for model and inference settings."""
+        with self.st.sidebar:  # Add Ultralytics LOGO
+            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
+            self.st.image(logo, width=250)
 
-    enable_trk = st.sidebar.radio("Enable Tracking", ("Yes", "No"))
-    conf = float(st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))
-    iou = float(st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))
+        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
+        self.source = self.st.sidebar.selectbox(
+            "Video",
+            ("webcam", "video"),
+        )  # Add source selection dropdown
+        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
+        self.conf = float(self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))  # Slider for confidence
+        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))  # Slider for NMS threshold
 
-    col1, col2 = st.columns(2)
-    org_frame = col1.empty()
-    ann_frame = col2.empty()
+        col1, col2 = self.st.columns(2)
+        self.org_frame = col1.empty()
+        self.ann_frame = col2.empty()
+        self.fps_display = self.st.sidebar.empty()  # Placeholder for FPS display
 
-    fps_display = st.sidebar.empty()  # Placeholder for FPS display
+    def source_upload(self):
+        """Handles video file uploads through the Streamlit interface."""
+        self.vid_file_name = ""
+        if self.source == "video":
+            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
+            if vid_file is not None:
+                g = io.BytesIO(vid_file.read())  # BytesIO Object
+                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
+                    out.write(g.read())  # Read bytes into file
+                self.vid_file_name = "ultralytics.mp4"
+        elif self.source == "webcam":
+            self.vid_file_name = 0
 
-    if st.sidebar.button("Start"):
-        videocapture = cv2.VideoCapture(vid_file_name)  # Capture the video
+    def configure(self):
+        """Configures the model and loads selected classes for inference."""
+        # Add dropdown menu for model selection
+        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
+        if self.model_path:  # If user provided the custom model, insert model without suffix as *.pt is added later
+            available_models.insert(0, self.model_path.split(".pt")[0])
+        selected_model = self.st.sidebar.selectbox("Model", available_models)
 
-        if not videocapture.isOpened():
-            st.error("Could not open webcam.")
+        with self.st.spinner("Model is downloading..."):
+            self.model = YOLO(f"{selected_model.lower()}.pt")  # Load the YOLO model
+            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
+        self.st.success("Model loaded successfully!")
 
-        stop_button = st.button("Stop")  # Button to stop the inference
+        # Multiselect box with class names and get indices of selected classes
+        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
+        self.selected_ind = [class_names.index(option) for option in selected_classes]
 
-        while videocapture.isOpened():
-            success, frame = videocapture.read()
-            if not success:
-                st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
-                break
+        if not isinstance(self.selected_ind, list):  # Ensure selected_options is a list
+            self.selected_ind = list(self.selected_ind)
 
-            prev_time = time.time()  # Store initial time for FPS calculation
+    def inference(self):
+        """Performs real-time object detection inference."""
+        self.web_ui()  # Initialize the web interface
+        self.sidebar()  # Create the sidebar
+        self.source_upload()  # Upload the video source
+        self.configure()  # Configure the app
 
-            # Store model predictions
-            if enable_trk == "Yes":
-                results = model.track(frame, conf=conf, iou=iou, classes=selected_ind, persist=True)
-            else:
-                results = model(frame, conf=conf, iou=iou, classes=selected_ind)
-            annotated_frame = results[0].plot()  # Add annotations on frame
+        if self.st.sidebar.button("Start"):
+            stop_button = self.st.button("Stop")  # Button to stop the inference
+            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
+            if not cap.isOpened():
+                self.st.error("Could not open webcam.")
+            while cap.isOpened():
+                success, frame = cap.read()
+                if not success:
+                    self.st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
+                    break
 
-            # Calculate model FPS
-            curr_time = time.time()
-            fps = 1 / (curr_time - prev_time)
+                prev_time = time.time()  # Store initial time for FPS calculation
 
-            # display frame
-            org_frame.image(frame, channels="BGR")
-            ann_frame.image(annotated_frame, channels="BGR")
+                # Store model predictions
+                if self.enable_trk == "Yes":
+                    results = self.model.track(
+                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
+                    )
+                else:
+                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
+                annotated_frame = results[0].plot()  # Add annotations on frame
 
-            if stop_button:
-                videocapture.release()  # Release the capture
-                torch.cuda.empty_cache()  # Clear CUDA memory
-                st.stop()  # Stop streamlit app
+                fps = 1 / (time.time() - prev_time)  # Calculate model FPS
 
-            # Display FPS in sidebar
-            fps_display.metric("FPS", f"{fps:.2f}")
+                if stop_button:
+                    cap.release()  # Release the capture
+                    self.st.stop()  # Stop streamlit app
 
-    # Release the capture
-    videocapture.release()
+                self.fps_display.metric("FPS", f"{fps:.2f}")  # Display FPS in sidebar
+                self.org_frame.image(frame, channels="BGR")  # Display original frame
+                self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
 
-    # Clear CUDA memory
-    torch.cuda.empty_cache()
+            cap.release()  # Release the capture
+            cv2.destroyAllWindows()  # Destroy window
 
-    # Destroy window
-    cv2.destroyAllWindows()
 
-# Main function call
 if __name__ == "__main__":
-    inference()
+    import sys  # Import the sys module for accessing command-line arguments
+
+    model = None  # Initialize the model variable as None
+
+    # Check if a model name is provided as a command-line argument
+    args = len(sys.argv)
+    if args > 1:
+        model = sys.argv[1]  # Assign the first argument as the model name
+
+    # Create an instance of the Inference class and run inference
+    Inference(model=model).inference()
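[Editor's note] The `__main__` block above reads an optional model name from `sys.argv`, which is how `handle_yolo_solutions` feeds it through its subprocess call. A minimal sketch that reproduces that launch outside the `yolo` CLI, assuming `ultralytics>=8.3.54` and `streamlit>=1.29.0` are installed (`yolo11n.pt` is just an example weight):

```python
# launch_app.py -- sketch of what `yolo solutions inference model=yolo11n.pt` runs under the hood.
import subprocess

from ultralytics.utils import ROOT  # directory of the installed ultralytics package

app = ROOT / "solutions" / "streamlit_inference.py"  # the bundled Streamlit app added by this patch

# Positional arguments after the script path land in the app's sys.argv,
# where the __main__ block reads sys.argv[1] as the model name.
subprocess.run(["streamlit", "run", str(app), "--server.headless", "true", "yolo11n.pt"])
```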