ultralytics 8.3.54 New Streamlit inference Solution (#18316)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Muhammad Rizwan Munawar 2024-12-24 16:26:56 +05:00 committed by GitHub
parent 5b76bed7d0
commit 51026a9a4a
13 changed files with 251 additions and 188 deletions

View file

@@ -55,7 +55,7 @@ See below for a quickstart install and usage examples, and see our [Docs](https:
 <details open>
 <summary>Install</summary>

-Pip install the ultralytics package including all [requirements](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) in a [**Python>=3.8**](https://www.python.org/) environment with [**PyTorch>=1.8**](https://pytorch.org/get-started/locally/).
+Pip install the Ultralytics package including all [requirements](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) in a [**Python>=3.8**](https://www.python.org/) environment with [**PyTorch>=1.8**](https://pytorch.org/get-started/locally/).

 [![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Ultralytics Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)

View file

@@ -18,7 +18,7 @@
 [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics)
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)

-To install the ultralytics package in developer mode, ensure you have Git and Python 3 installed on your system. Then, follow these steps:
+To install the Ultralytics package in developer mode, ensure you have Git and Python 3 installed on your system. Then, follow these steps:

 1. Clone the ultralytics repository to your local machine using Git:
@@ -38,7 +38,7 @@ To install the ultralytics package in developer mode, ensure you have Git and Py
    pip install -e '.[dev]'
    ```

-   This command installs the ultralytics package along with all development dependencies, allowing you to modify the package code and have the changes immediately reflected in your Python environment.
+   This command installs the Ultralytics package along with all development dependencies, allowing you to modify the package code and have the changes immediately reflected in your Python environment.

 ## 🚀 Building and Serving Locally

View file

@@ -43,7 +43,9 @@ Streamlit makes it simple to build and deploy interactive web applications. Comb
 === "CLI"

     ```bash
-    yolo streamlit-predict
+    yolo solutions inference
+
+    yolo solutions inference model="path/to/model/file.pt"
     ```

 === "Python"
@@ -51,7 +53,11 @@ Streamlit makes it simple to build and deploy interactive web applications. Comb
     ```python
     from ultralytics import solutions

-    solutions.inference()
+    inf = solutions.Inference(
+        model="yolo11n.pt",  # you can use any model that Ultralytics supports, e.g. YOLO11 or a custom-trained model
+    )
+    inf.inference()

     ### Make sure to run the file using command `streamlit run <file-name.py>`
     ```
@@ -67,8 +73,11 @@ You can optionally supply a specific model in Python:
     ```python
     from ultralytics import solutions

-    # Pass a model as an argument
-    solutions.inference(model="path/to/model.pt")
+    inf = solutions.Inference(
+        model="yolo11n.pt",  # you can use any model that Ultralytics supports, e.g. YOLO11, YOLOv10
+    )
+    inf.inference()

     ### Make sure to run the file using command `streamlit run <file-name.py>`
     ```
@@ -111,7 +120,11 @@ Then, you can create a basic Streamlit application to run live inference:
     ```python
     from ultralytics import solutions

-    solutions.inference()
+    inf = solutions.Inference(
+        model="yolo11n.pt",  # you can use any model that Ultralytics supports, e.g. YOLO11, YOLOv10
+    )
+    inf.inference()

     ### Make sure to run the file using command `streamlit run <file-name.py>`
     ```
@@ -119,7 +132,7 @@ Then, you can create a basic Streamlit application to run live inference:
 === "CLI"

     ```bash
-    yolo streamlit-predict
+    yolo solutions inference
     ```

 For more details on the practical setup, refer to the [Streamlit Application Code section](#streamlit-application-code) of the documentation.
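
Taken together, the guide now documents the `Inference` class instead of the removed `solutions.inference()` helper. A minimal end-to-end sketch assembled from the snippets above (the file name `app.py` is a hypothetical choice; launch it with `streamlit run app.py`):

```python
# app.py - minimal sketch of the new usage documented in this guide
from ultralytics import solutions

inf = solutions.Inference(model="yolo11n.pt")  # the model argument is optional
inf.inference()  # builds the Streamlit UI and runs live inference
```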

View file

@@ -51,10 +51,6 @@ keywords: Ultralytics, YOLO, configuration, cfg2dict, get_cfg, check_cfg, save_d
 <br><br><hr><br>

-## ::: ultralytics.cfg.handle_streamlit_inference
-
-<br><br><hr><br>
-
 ## ::: ultralytics.cfg.parse_key_value_pair

 <br><br><hr><br>

View file

@@ -11,6 +11,6 @@ keywords: Ultralytics, YOLOv8, live inference, real-time object detection, Strea
 <br>

-## ::: ultralytics.solutions.streamlit_inference.inference
+## ::: ultralytics.solutions.streamlit_inference.Inference

 <br><br>

View file

@@ -8,6 +8,7 @@ import torch
 from ultralytics.utils import ASSETS, yaml_load
 from ultralytics.utils.checks import check_requirements, check_yaml

+
 class RTDETR:
     """RTDETR object detection model class for handling inference and visualization."""
@@ -71,11 +72,17 @@ class RTDETR:
         # Draw a filled rectangle as the background for the label text
         cv2.rectangle(
-            self.img, (int(label_x), int(label_y - label_height)), (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED
+            self.img,
+            (int(label_x), int(label_y - label_height)),
+            (int(label_x + label_width), int(label_y + label_height)),
+            color,
+            cv2.FILLED,
         )

         # Draw the label text on the image
-        cv2.putText(self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
+        cv2.putText(
+            self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA
+        )

     def preprocess(self):
         """
@@ -110,8 +117,7 @@ class RTDETR:
     def bbox_cxcywh_to_xyxy(self, boxes):
         """
-        Converts bounding boxes from (center x, center y, width, height) format
-        to (x_min, y_min, x_max, y_max) format.
+        Converts bounding boxes from (center x, center y, width, height) format to (x_min, y_min, x_max, y_max) format.

         Args:
             boxes (numpy.ndarray): An array of shape (N, 4) where each row represents
@@ -189,6 +195,7 @@ class RTDETR:
         # Process and return the model output
         return self.postprocess(model_output)

+
 if __name__ == "__main__":
     # Set up argument parser for command-line arguments
     parser = argparse.ArgumentParser()
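
The reflowed docstring above describes the box-format conversion that `bbox_cxcywh_to_xyxy` performs. For reference, a self-contained sketch of that standard conversion (illustrative only, not code copied from the example file):

```python
import numpy as np

def bbox_cxcywh_to_xyxy(boxes: np.ndarray) -> np.ndarray:
    """Convert (center_x, center_y, width, height) boxes to (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    return np.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), axis=1)
```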

View file

@@ -82,4 +82,4 @@ def test_instance_segmentation():
 @pytest.mark.slow
 def test_streamlit_predict():
     """Test streamlit predict live inference solution."""
-    solutions.inference()
+    solutions.Inference().inference()

View file

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.53"
+__version__ = "8.3.54"

 import os

View file

@@ -42,6 +42,7 @@ SOLUTION_MAP = {
     "workout": ("AIGym", "monitor"),
     "analytics": ("Analytics", "process_data"),
     "trackzone": ("TrackZone", "trackzone"),
+    "inference": ("Inference", "inference"),
     "help": None,
 }
@@ -98,6 +99,9 @@ SOLUTIONS_HELP_MSG = f"""
     6. Track objects within specific zones
         yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)]

+    7. Streamlit real-time webcam inference GUI
+        yolo streamlit-predict
     """

 CLI_HELP_MSG = f"""
     Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
@@ -121,13 +125,10 @@ CLI_HELP_MSG = f"""
     4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
         yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128

-    5. Streamlit real-time webcam inference GUI
-        yolo streamlit-predict
-
-    6. Ultralytics solutions usage
+    5. Ultralytics solutions usage
         yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video/file.mp4"

-    7. Run special commands:
+    6. Run special commands:
         yolo help
         yolo checks
         yolo version
@@ -636,6 +637,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
         Run analytics with custom configuration:
         >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])

+        Run inference with custom configuration, requires Streamlit version 1.29.0 or higher.
+        >>> handle_yolo_solutions(["inference", "model=yolo11n.pt"])
+
     Notes:
         - Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT
         - Arguments can be provided in the format 'key=value' or as boolean flags
@@ -645,7 +649,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
         - For 'analytics' solution, frame numbers are tracked for generating analytical graphs
        - Video processing can be interrupted by pressing 'q'
        - Processes video frames sequentially and saves output in .avi format
-        - If no source is specified, downloads and uses a default sample video
+        - If no source is specified, downloads and uses a default sample video\
+        - The inference solution will be launched using the 'streamlit run' command.
+        - The Streamlit app file is located in the Ultralytics package directory.
     """
     full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}  # arguments dictionary
     overrides = {}
@@ -678,60 +684,56 @@ def handle_yolo_solutions(args: List[str]) -> None:
     if args and args[0] == "help":  # Add check for return if user call `yolo solutions help`
         return

-    cls, method = SOLUTION_MAP[s_n]  # solution class name, method name and default source
+    if s_n == "inference":
+        checks.check_requirements("streamlit>=1.29.0")
+        LOGGER.info("💡 Loading Ultralytics live inference app...")
+        subprocess.run(
+            [  # Run subprocess with Streamlit custom argument
+                "streamlit",
+                "run",
+                str(ROOT / "solutions/streamlit_inference.py"),
+                "--server.headless",
+                "true",
+                overrides["model"],
+            ]
+        )
+    else:
+        cls, method = SOLUTION_MAP[s_n]  # solution class name, method name and default source

-    from ultralytics import solutions  # import ultralytics solutions
+        from ultralytics import solutions  # import ultralytics solutions

-    solution = getattr(solutions, cls)(IS_CLI=True, **overrides)  # get solution class i.e ObjectCounter
-    process = getattr(solution, method)  # get specific function of class for processing i.e, count from ObjectCounter
+        solution = getattr(solutions, cls)(IS_CLI=True, **overrides)  # get solution class i.e ObjectCounter
+        process = getattr(
+            solution, method
+        )  # get specific function of class for processing i.e, count from ObjectCounter

-    cap = cv2.VideoCapture(solution.CFG["source"])  # read the video file
+        cap = cv2.VideoCapture(solution.CFG["source"])  # read the video file

-    # extract width, height and fps of the video file, create save directory and initialize video writer
-    import os  # for directory creation
-    from pathlib import Path
+        # extract width, height and fps of the video file, create save directory and initialize video writer
+        import os  # for directory creation
+        from pathlib import Path

-    from ultralytics.utils.files import increment_path  # for output directory path update
+        from ultralytics.utils.files import increment_path  # for output directory path update

-    w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
-    if s_n == "analytics":  # analytical graphs follow fixed shape for output i.e w=1920, h=1080
-        w, h = 1920, 1080
-    save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
-    save_dir.mkdir(parents=True, exist_ok=True)  # create the output directory
-    vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
+        w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
+        if s_n == "analytics":  # analytical graphs follow fixed shape for output i.e w=1920, h=1080
+            w, h = 1920, 1080
+        save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
+        save_dir.mkdir(parents=True, exist_ok=True)  # create the output directory
+        vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

-    try:  # Process video frames
-        f_n = 0  # frame number, required for analytical graphs
-        while cap.isOpened():
-            success, frame = cap.read()
-            if not success:
-                break
-            frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
-            vw.write(frame)
-            if cv2.waitKey(1) & 0xFF == ord("q"):
-                break
-    finally:
-        cap.release()
-
-
-def handle_streamlit_inference():
-    """
-    Open the Ultralytics Live Inference Streamlit app for real-time object detection.
-
-    This function initializes and runs a Streamlit application designed for performing live object detection using
-    Ultralytics models. It checks for the required Streamlit package and launches the app.
-
-    Examples:
-        >>> handle_streamlit_inference()
-
-    Notes:
-        - Requires Streamlit version 1.29.0 or higher.
-        - The app is launched using the 'streamlit run' command.
-        - The Streamlit app file is located in the Ultralytics package directory.
-    """
-    checks.check_requirements("streamlit>=1.29.0")
-    LOGGER.info("💡 Loading Ultralytics Live Inference app...")
-    subprocess.run(["streamlit", "run", ROOT / "solutions/streamlit_inference.py", "--server.headless", "true"])
+        try:  # Process video frames
+            f_n = 0  # frame number, required for analytical graphs
+            while cap.isOpened():
+                success, frame = cap.read()
+                if not success:
+                    break
+                frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
+                vw.write(frame)
+                if cv2.waitKey(1) & 0xFF == ord("q"):
+                    break
+        finally:
+            cap.release()


 def parse_key_value_pair(pair: str = "key=value"):
@@ -853,7 +855,6 @@ def entrypoint(debug=""):
         "login": lambda: handle_yolo_hub(args),
         "logout": lambda: handle_yolo_hub(args),
         "copy-cfg": copy_default_cfg,
-        "streamlit-predict": lambda: handle_streamlit_inference(),
         "solutions": lambda: handle_yolo_solutions(args[1:]),
     }
     full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
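
The net effect of these cfg changes: `inference` joins SOLUTION_MAP, but unlike the other solutions it is dispatched to a Streamlit subprocess instead of the OpenCV frame loop. A condensed, illustrative sketch of that branch (`launch_inference` is a hypothetical wrapper, and the hard-coded `ROOT` stands in for the package root constant used in `ultralytics/cfg/__init__.py`; only the subprocess invocation mirrors the diff):

```python
import subprocess
from pathlib import Path

ROOT = Path("ultralytics")  # stand-in for the real package ROOT constant

def launch_inference(model_arg: str = "model=yolo11n.pt") -> None:
    """Mirror the new 'inference' branch: run the bundled app headlessly via 'streamlit run'."""
    subprocess.run(
        ["streamlit", "run", str(ROOT / "solutions/streamlit_inference.py"), "--server.headless", "true", model_arg]
    )
```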

View file

@@ -10,7 +10,7 @@ from .queue_management import QueueManager
 from .region_counter import RegionCounter
 from .security_alarm import SecurityAlarm
 from .speed_estimation import SpeedEstimator
-from .streamlit_inference import inference
+from .streamlit_inference import Inference
 from .trackzone import TrackZone

 __all__ = (
@@ -23,7 +23,7 @@ __all__ = (
     "QueueManager",
     "SpeedEstimator",
     "Analytics",
-    "inference",
+    "Inference",
     "RegionCounter",
     "TrackZone",
     "SecurityAlarm",

View file

@@ -5,7 +5,9 @@ import json
 import cv2
 import numpy as np

-from ultralytics.solutions.solutions import LOGGER, BaseSolution, check_requirements
+from ultralytics.solutions.solutions import BaseSolution
+from ultralytics.utils import LOGGER
+from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.plotting import Annotator

View file

@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from ultralytics.solutions.solutions import LOGGER, BaseSolution
+from ultralytics.solutions.solutions import BaseSolution
+from ultralytics.utils import LOGGER
 from ultralytics.utils.plotting import Annotator, colors

View file

@@ -4,145 +4,188 @@ import io
 import time

 import cv2
-import torch

+from ultralytics import YOLO
+from ultralytics.utils import LOGGER
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS


-def inference(model=None):
-    """Performs real-time object detection on video input using YOLO in a Streamlit web application."""
-    check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
-    import streamlit as st
-
-    from ultralytics import YOLO
-
-    # Hide main menu style
-    menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""
-
-    # Main title of streamlit application
-    main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px;
-                             font-family: 'Archivo', sans-serif; margin-top:-50px;margin-bottom:20px;">
-                    Ultralytics YOLO Streamlit Application
-                    </h1></div>"""
-
-    # Subtitle of streamlit application
-    sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center;
-                    font-family: 'Archivo', sans-serif; margin-top:-15px; margin-bottom:50px;">
-                    Experience real-time object detection on your webcam with the power of Ultralytics YOLO! 🚀</h4>
-                    </div>"""
-
-    # Set html page configuration
-    st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
-
-    # Append the custom HTML
-    st.markdown(menu_style_cfg, unsafe_allow_html=True)
-    st.markdown(main_title_cfg, unsafe_allow_html=True)
-    st.markdown(sub_title_cfg, unsafe_allow_html=True)
-
-    # Add ultralytics logo in sidebar
-    with st.sidebar:
-        logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
-        st.image(logo, width=250)
-
-    # Add elements to vertical setting menu
-    st.sidebar.title("User Configuration")
-
-    # Add video source selection dropdown
-    source = st.sidebar.selectbox(
-        "Video",
-        ("webcam", "video"),
-    )
-
-    vid_file_name = ""
-    if source == "video":
-        vid_file = st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
-        if vid_file is not None:
-            g = io.BytesIO(vid_file.read())  # BytesIO Object
-            vid_location = "ultralytics.mp4"
-            with open(vid_location, "wb") as out:  # Open temporary file as bytes
-                out.write(g.read())  # Read bytes into file
-            vid_file_name = "ultralytics.mp4"
-    elif source == "webcam":
-        vid_file_name = 0
-
-    # Add dropdown menu for model selection
-    available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
-    if model:
-        available_models.insert(0, model.split(".pt")[0])  # insert model without suffix as *.pt is added later
-
-    selected_model = st.sidebar.selectbox("Model", available_models)
-    with st.spinner("Model is downloading..."):
-        model = YOLO(f"{selected_model.lower()}.pt")  # Load the YOLO model
-        class_names = list(model.names.values())  # Convert dictionary to list of class names
-    st.success("Model loaded successfully!")
-
-    # Multiselect box with class names and get indices of selected classes
-    selected_classes = st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
-    selected_ind = [class_names.index(option) for option in selected_classes]
-
-    if not isinstance(selected_ind, list):  # Ensure selected_options is a list
-        selected_ind = list(selected_ind)
-
-    enable_trk = st.sidebar.radio("Enable Tracking", ("Yes", "No"))
-    conf = float(st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))
-    iou = float(st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))
-
-    col1, col2 = st.columns(2)
-    org_frame = col1.empty()
-    ann_frame = col2.empty()
-
-    fps_display = st.sidebar.empty()  # Placeholder for FPS display
-
-    if st.sidebar.button("Start"):
-        videocapture = cv2.VideoCapture(vid_file_name)  # Capture the video
-
-        if not videocapture.isOpened():
-            st.error("Could not open webcam.")
-
-        stop_button = st.button("Stop")  # Button to stop the inference
-
-        while videocapture.isOpened():
-            success, frame = videocapture.read()
-            if not success:
-                st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
-                break
-
-            prev_time = time.time()  # Store initial time for FPS calculation
-
-            # Store model predictions
-            if enable_trk == "Yes":
-                results = model.track(frame, conf=conf, iou=iou, classes=selected_ind, persist=True)
-            else:
-                results = model(frame, conf=conf, iou=iou, classes=selected_ind)
-            annotated_frame = results[0].plot()  # Add annotations on frame
-
-            # Calculate model FPS
-            curr_time = time.time()
-            fps = 1 / (curr_time - prev_time)
-
-            # display frame
-            org_frame.image(frame, channels="BGR")
-            ann_frame.image(annotated_frame, channels="BGR")
-
-            if stop_button:
-                videocapture.release()  # Release the capture
-                torch.cuda.empty_cache()  # Clear CUDA memory
-                st.stop()  # Stop streamlit app
-
-            # Display FPS in sidebar
-            fps_display.metric("FPS", f"{fps:.2f}")
-
-        # Release the capture
-        videocapture.release()
-
-        # Clear CUDA memory
-        torch.cuda.empty_cache()
-
-        # Destroy window
-        cv2.destroyAllWindows()
+class Inference:
+    """
+    A class to perform object detection, image classification, image segmentation and pose estimation inference
+    using Streamlit and Ultralytics YOLO models. It provides functionalities such as loading models, configuring
+    settings, uploading video files, and performing real-time inference.
+
+    Attributes:
+        st (module): Streamlit module for UI creation.
+        temp_dict (dict): Temporary dictionary to store the model path.
+        model_path (str): Path to the loaded model.
+        model (YOLO): The YOLO model instance.
+        source (str): Selected video source.
+        enable_trk (str): Enable tracking option.
+        conf (float): Confidence threshold.
+        iou (float): IoU threshold for non-max suppression.
+        vid_file_name (str): Name of the uploaded video file.
+        selected_ind (list): List of selected class indices.
+
+    Methods:
+        web_ui: Sets up the Streamlit web interface with custom HTML elements.
+        sidebar: Configures the Streamlit sidebar for model and inference settings.
+        source_upload: Handles video file uploads through the Streamlit interface.
+        configure: Configures the model and loads selected classes for inference.
+        inference: Performs real-time object detection inference.
+
+    Examples:
+        >>> inf = solutions.Inference(model="path/to/model/file.pt")  # Model is not a necessary argument.
+        >>> inf.inference()
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Initializes the Inference class, checking Streamlit requirements and setting up the model path.
+
+        Args:
+            **kwargs (Dict): Additional keyword arguments for model configuration.
+        """
+        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
+        import streamlit as st
+
+        self.st = st
+
+        self.temp_dict = {"model": None}  # Temporary dict to store the model path
+        self.temp_dict.update(kwargs)
+
+        self.model_path = None  # Store model file name with path
+        if self.temp_dict["model"] is not None:
+            self.model_path = self.temp_dict["model"]
+            LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
+
+    def web_ui(self):
+        """Sets up the Streamlit web interface with custom HTML elements."""
+        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style
+
+        # Main title of streamlit application
+        main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
+        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
+
+        # Subtitle of streamlit application
+        sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
+        of Ultralytics YOLO! 🚀</h4></div>"""
+
+        # Set html page configuration and append custom HTML
+        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
+        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
+        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
+        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
+
+    def sidebar(self):
+        """Configures the Streamlit sidebar for model and inference settings."""
+        with self.st.sidebar:  # Add Ultralytics LOGO
+            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
+            self.st.image(logo, width=250)
+
+        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
+        self.source = self.st.sidebar.selectbox(
+            "Video",
+            ("webcam", "video"),
+        )  # Add source selection dropdown
+        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
+        self.conf = float(self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))  # Slider for confidence
+        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))  # Slider for NMS threshold
+
+        col1, col2 = self.st.columns(2)
+        self.org_frame = col1.empty()
+        self.ann_frame = col2.empty()
+        self.fps_display = self.st.sidebar.empty()  # Placeholder for FPS display
+
+    def source_upload(self):
+        """Handles video file uploads through the Streamlit interface."""
+        self.vid_file_name = ""
+        if self.source == "video":
+            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
+            if vid_file is not None:
+                g = io.BytesIO(vid_file.read())  # BytesIO Object
+                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
+                    out.write(g.read())  # Read bytes into file
+                self.vid_file_name = "ultralytics.mp4"
+        elif self.source == "webcam":
+            self.vid_file_name = 0
+
+    def configure(self):
+        """Configures the model and loads selected classes for inference."""
+        # Add dropdown menu for model selection
+        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
+        if self.model_path:  # If user provided the custom model, insert model without suffix as *.pt is added later
+            available_models.insert(0, self.model_path.split(".pt")[0])
+        selected_model = self.st.sidebar.selectbox("Model", available_models)
+
+        with self.st.spinner("Model is downloading..."):
+            self.model = YOLO(f"{selected_model.lower()}.pt")  # Load the YOLO model
+            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
+        self.st.success("Model loaded successfully!")
+
+        # Multiselect box with class names and get indices of selected classes
+        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
+        self.selected_ind = [class_names.index(option) for option in selected_classes]
+
+        if not isinstance(self.selected_ind, list):  # Ensure selected_options is a list
+            self.selected_ind = list(self.selected_ind)
+
+    def inference(self):
+        """Performs real-time object detection inference."""
+        self.web_ui()  # Initialize the web interface
+        self.sidebar()  # Create the sidebar
+        self.source_upload()  # Upload the video source
+        self.configure()  # Configure the app
+
+        if self.st.sidebar.button("Start"):
+            stop_button = self.st.button("Stop")  # Button to stop the inference
+            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
+            if not cap.isOpened():
+                self.st.error("Could not open webcam.")
+            while cap.isOpened():
+                success, frame = cap.read()
+                if not success:
+                    self.st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
+                    break
+
+                prev_time = time.time()  # Store initial time for FPS calculation
+
+                # Store model predictions
+                if self.enable_trk == "Yes":
+                    results = self.model.track(
+                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
+                    )
+                else:
+                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
+                annotated_frame = results[0].plot()  # Add annotations on frame
+
+                fps = 1 / (time.time() - prev_time)  # Calculate model FPS
+
+                if stop_button:
+                    cap.release()  # Release the capture
+                    self.st.stop()  # Stop streamlit app
+
+                self.fps_display.metric("FPS", f"{fps:.2f}")  # Display FPS in sidebar
+                self.org_frame.image(frame, channels="BGR")  # Display original frame
+                self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
+
+            cap.release()  # Release the capture
+        cv2.destroyAllWindows()  # Destroy window


-# Main function call
 if __name__ == "__main__":
-    inference()
+    import sys  # Import the sys module for accessing command-line arguments
+
+    model = None  # Initialize the model variable as None
+    # Check if a model name is provided as a command-line argument
+    args = len(sys.argv)
+    if args > 1:
+        model = sys.argv[1]  # Assign the first argument as the model name
+
+    # Create an instance of the Inference class and run inference
+    Inference(model=model).inference()
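
One detail worth noting from the new `__main__` block: `streamlit run` forwards trailing positional arguments to the target script, which is how the CLI branch in `ultralytics/cfg/__init__.py` passes `overrides["model"]` through to `sys.argv[1]` here. A minimal sketch of that argument plumbing (assumes Streamlit's argument-forwarding behavior):

```python
# e.g. launched as: streamlit run streamlit_inference.py yolo11n.pt
import sys

model = sys.argv[1] if len(sys.argv) > 1 else None  # trailing CLI argument becomes the weights name
# Inference(model=model).inference()  # as in the __main__ block above
```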