ultralytics 8.3.54 New Streamlit inference Solution (#18316)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
5b76bed7d0
commit
51026a9a4a
13 changed files with 251 additions and 188 deletions
|
|
@ -55,7 +55,7 @@ See below for a quickstart install and usage examples, and see our [Docs](https:
|
||||||
<details open>
|
<details open>
|
||||||
<summary>Install</summary>
|
<summary>Install</summary>
|
||||||
|
|
||||||
Pip install the ultralytics package including all [requirements](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) in a [**Python>=3.8**](https://www.python.org/) environment with [**PyTorch>=1.8**](https://pytorch.org/get-started/locally/).
|
Pip install the Ultralytics package including all [requirements](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) in a [**Python>=3.8**](https://www.python.org/) environment with [**PyTorch>=1.8**](https://pytorch.org/get-started/locally/).
|
||||||
|
|
||||||
[](https://pypi.org/project/ultralytics/) [](https://www.pepy.tech/projects/ultralytics) [](https://pypi.org/project/ultralytics/)
|
[](https://pypi.org/project/ultralytics/) [](https://www.pepy.tech/projects/ultralytics) [](https://pypi.org/project/ultralytics/)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@
|
||||||
[](https://www.pepy.tech/projects/ultralytics)
|
[](https://www.pepy.tech/projects/ultralytics)
|
||||||
[](https://pypi.org/project/ultralytics/)
|
[](https://pypi.org/project/ultralytics/)
|
||||||
|
|
||||||
To install the ultralytics package in developer mode, ensure you have Git and Python 3 installed on your system. Then, follow these steps:
|
To install the Ultralytics package in developer mode, ensure you have Git and Python 3 installed on your system. Then, follow these steps:
|
||||||
|
|
||||||
1. Clone the ultralytics repository to your local machine using Git:
|
1. Clone the ultralytics repository to your local machine using Git:
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ To install the ultralytics package in developer mode, ensure you have Git and Py
|
||||||
pip install -e '.[dev]'
|
pip install -e '.[dev]'
|
||||||
```
|
```
|
||||||
|
|
||||||
- This command installs the ultralytics package along with all development dependencies, allowing you to modify the package code and have the changes immediately reflected in your Python environment.
|
- This command installs the Ultralytics package along with all development dependencies, allowing you to modify the package code and have the changes immediately reflected in your Python environment.
|
||||||
|
|
||||||
## 🚀 Building and Serving Locally
|
## 🚀 Building and Serving Locally
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,9 @@ Streamlit makes it simple to build and deploy interactive web applications. Comb
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
yolo streamlit-predict
|
yolo solutions inference
|
||||||
|
|
||||||
|
yolo solutions inference model="path/to/model/file.pt"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
@ -51,7 +53,11 @@ Streamlit makes it simple to build and deploy interactive web applications. Comb
|
||||||
```python
|
```python
|
||||||
from ultralytics import solutions
|
from ultralytics import solutions
|
||||||
|
|
||||||
solutions.inference()
|
inf = solutions.Inference(
|
||||||
|
model="yolo11n.pt", # You can use any model that Ultralytics support, i.e. YOLO11, or custom trained model
|
||||||
|
)
|
||||||
|
|
||||||
|
inf.inference()
|
||||||
|
|
||||||
### Make sure to run the file using command `streamlit run <file-name.py>`
|
### Make sure to run the file using command `streamlit run <file-name.py>`
|
||||||
```
|
```
|
||||||
|
|
@ -67,8 +73,11 @@ You can optionally supply a specific model in Python:
|
||||||
```python
|
```python
|
||||||
from ultralytics import solutions
|
from ultralytics import solutions
|
||||||
|
|
||||||
# Pass a model as an argument
|
inf = solutions.Inference(
|
||||||
solutions.inference(model="path/to/model.pt")
|
model="yolo11n.pt", # You can use any model that Ultralytics support, i.e. YOLO11, YOLOv10
|
||||||
|
)
|
||||||
|
|
||||||
|
inf.inference()
|
||||||
|
|
||||||
### Make sure to run the file using command `streamlit run <file-name.py>`
|
### Make sure to run the file using command `streamlit run <file-name.py>`
|
||||||
```
|
```
|
||||||
|
|
@ -111,7 +120,11 @@ Then, you can create a basic Streamlit application to run live inference:
|
||||||
```python
|
```python
|
||||||
from ultralytics import solutions
|
from ultralytics import solutions
|
||||||
|
|
||||||
solutions.inference()
|
inf = solutions.Inference(
|
||||||
|
model="yolo11n.pt", # You can use any model that Ultralytics support, i.e. YOLO11, YOLOv10
|
||||||
|
)
|
||||||
|
|
||||||
|
inf.inference()
|
||||||
|
|
||||||
### Make sure to run the file using command `streamlit run <file-name.py>`
|
### Make sure to run the file using command `streamlit run <file-name.py>`
|
||||||
```
|
```
|
||||||
|
|
@ -119,7 +132,7 @@ Then, you can create a basic Streamlit application to run live inference:
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
yolo streamlit-predict
|
yolo solutions inference
|
||||||
```
|
```
|
||||||
|
|
||||||
For more details on the practical setup, refer to the [Streamlit Application Code section](#streamlit-application-code) of the documentation.
|
For more details on the practical setup, refer to the [Streamlit Application Code section](#streamlit-application-code) of the documentation.
|
||||||
|
|
|
||||||
|
|
@ -51,10 +51,6 @@ keywords: Ultralytics, YOLO, configuration, cfg2dict, get_cfg, check_cfg, save_d
|
||||||
|
|
||||||
<br><br><hr><br>
|
<br><br><hr><br>
|
||||||
|
|
||||||
## ::: ultralytics.cfg.handle_streamlit_inference
|
|
||||||
|
|
||||||
<br><br><hr><br>
|
|
||||||
|
|
||||||
## ::: ultralytics.cfg.parse_key_value_pair
|
## ::: ultralytics.cfg.parse_key_value_pair
|
||||||
|
|
||||||
<br><br><hr><br>
|
<br><br><hr><br>
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,6 @@ keywords: Ultralytics, YOLOv8, live inference, real-time object detection, Strea
|
||||||
|
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
## ::: ultralytics.solutions.streamlit_inference.inference
|
## ::: ultralytics.solutions.streamlit_inference.Inference
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import torch
|
||||||
from ultralytics.utils import ASSETS, yaml_load
|
from ultralytics.utils import ASSETS, yaml_load
|
||||||
from ultralytics.utils.checks import check_requirements, check_yaml
|
from ultralytics.utils.checks import check_requirements, check_yaml
|
||||||
|
|
||||||
|
|
||||||
class RTDETR:
|
class RTDETR:
|
||||||
"""RTDETR object detection model class for handling inference and visualization."""
|
"""RTDETR object detection model class for handling inference and visualization."""
|
||||||
|
|
||||||
|
|
@ -71,11 +72,17 @@ class RTDETR:
|
||||||
|
|
||||||
# Draw a filled rectangle as the background for the label text
|
# Draw a filled rectangle as the background for the label text
|
||||||
cv2.rectangle(
|
cv2.rectangle(
|
||||||
self.img, (int(label_x), int(label_y - label_height)), (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED
|
self.img,
|
||||||
|
(int(label_x), int(label_y - label_height)),
|
||||||
|
(int(label_x + label_width), int(label_y + label_height)),
|
||||||
|
color,
|
||||||
|
cv2.FILLED,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Draw the label text on the image
|
# Draw the label text on the image
|
||||||
cv2.putText(self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
|
cv2.putText(
|
||||||
|
self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA
|
||||||
|
)
|
||||||
|
|
||||||
def preprocess(self):
|
def preprocess(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -110,8 +117,7 @@ class RTDETR:
|
||||||
|
|
||||||
def bbox_cxcywh_to_xyxy(self, boxes):
|
def bbox_cxcywh_to_xyxy(self, boxes):
|
||||||
"""
|
"""
|
||||||
Converts bounding boxes from (center x, center y, width, height) format
|
Converts bounding boxes from (center x, center y, width, height) format to (x_min, y_min, x_max, y_max) format.
|
||||||
to (x_min, y_min, x_max, y_max) format.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
boxes (numpy.ndarray): An array of shape (N, 4) where each row represents
|
boxes (numpy.ndarray): An array of shape (N, 4) where each row represents
|
||||||
|
|
@ -189,6 +195,7 @@ class RTDETR:
|
||||||
# Process and return the model output
|
# Process and return the model output
|
||||||
return self.postprocess(model_output)
|
return self.postprocess(model_output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Set up argument parser for command-line arguments
|
# Set up argument parser for command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
||||||
|
|
@ -82,4 +82,4 @@ def test_instance_segmentation():
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_streamlit_predict():
|
def test_streamlit_predict():
|
||||||
"""Test streamlit predict live inference solution."""
|
"""Test streamlit predict live inference solution."""
|
||||||
solutions.inference()
|
solutions.Inference().inference()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.3.53"
|
__version__ = "8.3.54"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ SOLUTION_MAP = {
|
||||||
"workout": ("AIGym", "monitor"),
|
"workout": ("AIGym", "monitor"),
|
||||||
"analytics": ("Analytics", "process_data"),
|
"analytics": ("Analytics", "process_data"),
|
||||||
"trackzone": ("TrackZone", "trackzone"),
|
"trackzone": ("TrackZone", "trackzone"),
|
||||||
|
"inference": ("Inference", "inference"),
|
||||||
"help": None,
|
"help": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -98,6 +99,9 @@ SOLUTIONS_HELP_MSG = f"""
|
||||||
|
|
||||||
6. Track objects within specific zones
|
6. Track objects within specific zones
|
||||||
yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)]
|
yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)]
|
||||||
|
|
||||||
|
7. Streamlit real-time webcam inference GUI
|
||||||
|
yolo streamlit-predict
|
||||||
"""
|
"""
|
||||||
CLI_HELP_MSG = f"""
|
CLI_HELP_MSG = f"""
|
||||||
Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||||
|
|
@ -121,13 +125,10 @@ CLI_HELP_MSG = f"""
|
||||||
4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||||
yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128
|
yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128
|
||||||
|
|
||||||
5. Streamlit real-time webcam inference GUI
|
5. Ultralytics solutions usage
|
||||||
yolo streamlit-predict
|
|
||||||
|
|
||||||
6. Ultralytics solutions usage
|
|
||||||
yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video/file.mp4"
|
yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video/file.mp4"
|
||||||
|
|
||||||
7. Run special commands:
|
6. Run special commands:
|
||||||
yolo help
|
yolo help
|
||||||
yolo checks
|
yolo checks
|
||||||
yolo version
|
yolo version
|
||||||
|
|
@ -636,6 +637,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
|
||||||
Run analytics with custom configuration:
|
Run analytics with custom configuration:
|
||||||
>>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])
|
>>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])
|
||||||
|
|
||||||
|
Run inference with custom configuration, requires Streamlit version 1.29.0 or higher.
|
||||||
|
>>> handle_yolo_solutions(["inference", "model=yolo11n.pt"])
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT
|
- Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT
|
||||||
- Arguments can be provided in the format 'key=value' or as boolean flags
|
- Arguments can be provided in the format 'key=value' or as boolean flags
|
||||||
|
|
@ -645,7 +649,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
|
||||||
- For 'analytics' solution, frame numbers are tracked for generating analytical graphs
|
- For 'analytics' solution, frame numbers are tracked for generating analytical graphs
|
||||||
- Video processing can be interrupted by pressing 'q'
|
- Video processing can be interrupted by pressing 'q'
|
||||||
- Processes video frames sequentially and saves output in .avi format
|
- Processes video frames sequentially and saves output in .avi format
|
||||||
- If no source is specified, downloads and uses a default sample video
|
- If no source is specified, downloads and uses a default sample video\
|
||||||
|
- The inference solution will be launched using the 'streamlit run' command.
|
||||||
|
- The Streamlit app file is located in the Ultralytics package directory.
|
||||||
"""
|
"""
|
||||||
full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} # arguments dictionary
|
full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} # arguments dictionary
|
||||||
overrides = {}
|
overrides = {}
|
||||||
|
|
@ -678,60 +684,56 @@ def handle_yolo_solutions(args: List[str]) -> None:
|
||||||
if args and args[0] == "help": # Add check for return if user call `yolo solutions help`
|
if args and args[0] == "help": # Add check for return if user call `yolo solutions help`
|
||||||
return
|
return
|
||||||
|
|
||||||
cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source
|
if s_n == "inference":
|
||||||
|
checks.check_requirements("streamlit>=1.29.0")
|
||||||
|
LOGGER.info("💡 Loading Ultralytics live inference app...")
|
||||||
|
subprocess.run(
|
||||||
|
[ # Run subprocess with Streamlit custom argument
|
||||||
|
"streamlit",
|
||||||
|
"run",
|
||||||
|
str(ROOT / "solutions/streamlit_inference.py"),
|
||||||
|
"--server.headless",
|
||||||
|
"true",
|
||||||
|
overrides["model"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source
|
||||||
|
|
||||||
from ultralytics import solutions # import ultralytics solutions
|
from ultralytics import solutions # import ultralytics solutions
|
||||||
|
|
||||||
solution = getattr(solutions, cls)(IS_CLI=True, **overrides) # get solution class i.e ObjectCounter
|
solution = getattr(solutions, cls)(IS_CLI=True, **overrides) # get solution class i.e ObjectCounter
|
||||||
process = getattr(solution, method) # get specific function of class for processing i.e, count from ObjectCounter
|
process = getattr(
|
||||||
|
solution, method
|
||||||
|
) # get specific function of class for processing i.e, count from ObjectCounter
|
||||||
|
|
||||||
cap = cv2.VideoCapture(solution.CFG["source"]) # read the video file
|
cap = cv2.VideoCapture(solution.CFG["source"]) # read the video file
|
||||||
|
|
||||||
# extract width, height and fps of the video file, create save directory and initialize video writer
|
# extract width, height and fps of the video file, create save directory and initialize video writer
|
||||||
import os # for directory creation
|
import os # for directory creation
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ultralytics.utils.files import increment_path # for output directory path update
|
from ultralytics.utils.files import increment_path # for output directory path update
|
||||||
|
|
||||||
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
|
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
|
||||||
if s_n == "analytics": # analytical graphs follow fixed shape for output i.e w=1920, h=1080
|
if s_n == "analytics": # analytical graphs follow fixed shape for output i.e w=1920, h=1080
|
||||||
w, h = 1920, 1080
|
w, h = 1920, 1080
|
||||||
save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
|
save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
|
||||||
save_dir.mkdir(parents=True, exist_ok=True) # create the output directory
|
save_dir.mkdir(parents=True, exist_ok=True) # create the output directory
|
||||||
vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
|
vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
|
||||||
|
|
||||||
try: # Process video frames
|
try: # Process video frames
|
||||||
f_n = 0 # frame number, required for analytical graphs
|
f_n = 0 # frame number, required for analytical graphs
|
||||||
while cap.isOpened():
|
while cap.isOpened():
|
||||||
success, frame = cap.read()
|
success, frame = cap.read()
|
||||||
if not success:
|
if not success:
|
||||||
break
|
break
|
||||||
frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
|
frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
|
||||||
vw.write(frame)
|
vw.write(frame)
|
||||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
|
|
||||||
def handle_streamlit_inference():
|
|
||||||
"""
|
|
||||||
Open the Ultralytics Live Inference Streamlit app for real-time object detection.
|
|
||||||
|
|
||||||
This function initializes and runs a Streamlit application designed for performing live object detection using
|
|
||||||
Ultralytics models. It checks for the required Streamlit package and launches the app.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> handle_streamlit_inference()
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Requires Streamlit version 1.29.0 or higher.
|
|
||||||
- The app is launched using the 'streamlit run' command.
|
|
||||||
- The Streamlit app file is located in the Ultralytics package directory.
|
|
||||||
"""
|
|
||||||
checks.check_requirements("streamlit>=1.29.0")
|
|
||||||
LOGGER.info("💡 Loading Ultralytics Live Inference app...")
|
|
||||||
subprocess.run(["streamlit", "run", ROOT / "solutions/streamlit_inference.py", "--server.headless", "true"])
|
|
||||||
|
|
||||||
|
|
||||||
def parse_key_value_pair(pair: str = "key=value"):
|
def parse_key_value_pair(pair: str = "key=value"):
|
||||||
|
|
@ -853,7 +855,6 @@ def entrypoint(debug=""):
|
||||||
"login": lambda: handle_yolo_hub(args),
|
"login": lambda: handle_yolo_hub(args),
|
||||||
"logout": lambda: handle_yolo_hub(args),
|
"logout": lambda: handle_yolo_hub(args),
|
||||||
"copy-cfg": copy_default_cfg,
|
"copy-cfg": copy_default_cfg,
|
||||||
"streamlit-predict": lambda: handle_streamlit_inference(),
|
|
||||||
"solutions": lambda: handle_yolo_solutions(args[1:]),
|
"solutions": lambda: handle_yolo_solutions(args[1:]),
|
||||||
}
|
}
|
||||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from .queue_management import QueueManager
|
||||||
from .region_counter import RegionCounter
|
from .region_counter import RegionCounter
|
||||||
from .security_alarm import SecurityAlarm
|
from .security_alarm import SecurityAlarm
|
||||||
from .speed_estimation import SpeedEstimator
|
from .speed_estimation import SpeedEstimator
|
||||||
from .streamlit_inference import inference
|
from .streamlit_inference import Inference
|
||||||
from .trackzone import TrackZone
|
from .trackzone import TrackZone
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
|
@ -23,7 +23,7 @@ __all__ = (
|
||||||
"QueueManager",
|
"QueueManager",
|
||||||
"SpeedEstimator",
|
"SpeedEstimator",
|
||||||
"Analytics",
|
"Analytics",
|
||||||
"inference",
|
"Inference",
|
||||||
"RegionCounter",
|
"RegionCounter",
|
||||||
"TrackZone",
|
"TrackZone",
|
||||||
"SecurityAlarm",
|
"SecurityAlarm",
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,9 @@ import json
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ultralytics.solutions.solutions import LOGGER, BaseSolution, check_requirements
|
from ultralytics.solutions.solutions import BaseSolution
|
||||||
|
from ultralytics.utils import LOGGER
|
||||||
|
from ultralytics.utils.checks import check_requirements
|
||||||
from ultralytics.utils.plotting import Annotator
|
from ultralytics.utils.plotting import Annotator
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
from ultralytics.solutions.solutions import LOGGER, BaseSolution
|
from ultralytics.solutions.solutions import BaseSolution
|
||||||
|
from ultralytics.utils import LOGGER
|
||||||
from ultralytics.utils.plotting import Annotator, colors
|
from ultralytics.utils.plotting import Annotator, colors
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,145 +4,188 @@ import io
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
|
||||||
|
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from ultralytics.utils import LOGGER
|
||||||
from ultralytics.utils.checks import check_requirements
|
from ultralytics.utils.checks import check_requirements
|
||||||
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
|
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
|
||||||
|
|
||||||
|
|
||||||
def inference(model=None):
|
class Inference:
|
||||||
"""Performs real-time object detection on video input using YOLO in a Streamlit web application."""
|
"""
|
||||||
check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds
|
A class to perform object detection, image classification, image segmentation and pose estimation inference using
|
||||||
import streamlit as st
|
Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings,
|
||||||
|
uploading video files, and performing real-time inference.
|
||||||
|
|
||||||
from ultralytics import YOLO
|
Attributes:
|
||||||
|
st (module): Streamlit module for UI creation.
|
||||||
|
temp_dict (dict): Temporary dictionary to store the model path.
|
||||||
|
model_path (str): Path to the loaded model.
|
||||||
|
model (YOLO): The YOLO model instance.
|
||||||
|
source (str): Selected video source.
|
||||||
|
enable_trk (str): Enable tracking option.
|
||||||
|
conf (float): Confidence threshold.
|
||||||
|
iou (float): IoU threshold for non-max suppression.
|
||||||
|
vid_file_name (str): Name of the uploaded video file.
|
||||||
|
selected_ind (list): List of selected class indices.
|
||||||
|
|
||||||
# Hide main menu style
|
Methods:
|
||||||
menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""
|
web_ui: Sets up the Streamlit web interface with custom HTML elements.
|
||||||
|
sidebar: Configures the Streamlit sidebar for model and inference settings.
|
||||||
|
source_upload: Handles video file uploads through the Streamlit interface.
|
||||||
|
configure: Configures the model and loads selected classes for inference.
|
||||||
|
inference: Performs real-time object detection inference.
|
||||||
|
|
||||||
# Main title of streamlit application
|
Examples:
|
||||||
main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px;
|
>>> inf = solutions.Inference(model="path/to/model/file.pt") # Model is not necessary argument.
|
||||||
font-family: 'Archivo', sans-serif; margin-top:-50px;margin-bottom:20px;">
|
>>> inf.inference()
|
||||||
Ultralytics YOLO Streamlit Application
|
"""
|
||||||
</h1></div>"""
|
|
||||||
|
|
||||||
# Subtitle of streamlit application
|
def __init__(self, **kwargs):
|
||||||
sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center;
|
"""
|
||||||
font-family: 'Archivo', sans-serif; margin-top:-15px; margin-bottom:50px;">
|
Initializes the Inference class, checking Streamlit requirements and setting up the model path.
|
||||||
Experience real-time object detection on your webcam with the power of Ultralytics YOLO! 🚀</h4>
|
|
||||||
</div>"""
|
|
||||||
|
|
||||||
# Set html page configuration
|
Args:
|
||||||
st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
|
**kwargs (Dict): Additional keyword arguments for model configuration.
|
||||||
|
"""
|
||||||
|
check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
# Append the custom HTML
|
self.st = st
|
||||||
st.markdown(menu_style_cfg, unsafe_allow_html=True)
|
|
||||||
st.markdown(main_title_cfg, unsafe_allow_html=True)
|
|
||||||
st.markdown(sub_title_cfg, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Add ultralytics logo in sidebar
|
self.temp_dict = {"model": None} # Temporary dict to store the model path
|
||||||
with st.sidebar:
|
self.temp_dict.update(kwargs)
|
||||||
logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
|
|
||||||
st.image(logo, width=250)
|
|
||||||
|
|
||||||
# Add elements to vertical setting menu
|
self.model_path = None # Store model file name with path
|
||||||
st.sidebar.title("User Configuration")
|
if self.temp_dict["model"] is not None:
|
||||||
|
self.model_path = self.temp_dict["model"]
|
||||||
|
|
||||||
# Add video source selection dropdown
|
LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
|
||||||
source = st.sidebar.selectbox(
|
|
||||||
"Video",
|
|
||||||
("webcam", "video"),
|
|
||||||
)
|
|
||||||
|
|
||||||
vid_file_name = ""
|
def web_ui(self):
|
||||||
if source == "video":
|
"""Sets up the Streamlit web interface with custom HTML elements."""
|
||||||
vid_file = st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
|
menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>""" # Hide main menu style
|
||||||
if vid_file is not None:
|
|
||||||
g = io.BytesIO(vid_file.read()) # BytesIO Object
|
|
||||||
vid_location = "ultralytics.mp4"
|
|
||||||
with open(vid_location, "wb") as out: # Open temporary file as bytes
|
|
||||||
out.write(g.read()) # Read bytes into file
|
|
||||||
vid_file_name = "ultralytics.mp4"
|
|
||||||
elif source == "webcam":
|
|
||||||
vid_file_name = 0
|
|
||||||
|
|
||||||
# Add dropdown menu for model selection
|
# Main title of streamlit application
|
||||||
available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
|
main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
|
||||||
if model:
|
font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
|
||||||
available_models.insert(0, model.split(".pt")[0]) # insert model without suffix as *.pt is added later
|
|
||||||
|
|
||||||
selected_model = st.sidebar.selectbox("Model", available_models)
|
# Subtitle of streamlit application
|
||||||
with st.spinner("Model is downloading..."):
|
sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
|
||||||
model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
|
margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
|
||||||
class_names = list(model.names.values()) # Convert dictionary to list of class names
|
of Ultralytics YOLO! 🚀</h4></div>"""
|
||||||
st.success("Model loaded successfully!")
|
|
||||||
|
|
||||||
# Multiselect box with class names and get indices of selected classes
|
# Set html page configuration and append custom HTML
|
||||||
selected_classes = st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
|
self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
|
||||||
selected_ind = [class_names.index(option) for option in selected_classes]
|
self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
|
||||||
|
self.st.markdown(main_title_cfg, unsafe_allow_html=True)
|
||||||
|
self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
|
||||||
|
|
||||||
if not isinstance(selected_ind, list): # Ensure selected_options is a list
|
def sidebar(self):
|
||||||
selected_ind = list(selected_ind)
|
"""Configures the Streamlit sidebar for model and inference settings."""
|
||||||
|
with self.st.sidebar: # Add Ultralytics LOGO
|
||||||
|
logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
|
||||||
|
self.st.image(logo, width=250)
|
||||||
|
|
||||||
enable_trk = st.sidebar.radio("Enable Tracking", ("Yes", "No"))
|
self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu
|
||||||
conf = float(st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))
|
self.source = self.st.sidebar.selectbox(
|
||||||
iou = float(st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))
|
"Video",
|
||||||
|
("webcam", "video"),
|
||||||
|
) # Add source selection dropdown
|
||||||
|
self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking
|
||||||
|
self.conf = float(self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01)) # Slider for confidence
|
||||||
|
self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01)) # Slider for NMS threshold
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = self.st.columns(2)
|
||||||
org_frame = col1.empty()
|
self.org_frame = col1.empty()
|
||||||
ann_frame = col2.empty()
|
self.ann_frame = col2.empty()
|
||||||
|
self.fps_display = self.st.sidebar.empty() # Placeholder for FPS display
|
||||||
|
|
||||||
fps_display = st.sidebar.empty() # Placeholder for FPS display
|
def source_upload(self):
|
||||||
|
"""Handles video file uploads through the Streamlit interface."""
|
||||||
|
self.vid_file_name = ""
|
||||||
|
if self.source == "video":
|
||||||
|
vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
|
||||||
|
if vid_file is not None:
|
||||||
|
g = io.BytesIO(vid_file.read()) # BytesIO Object
|
||||||
|
with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes
|
||||||
|
out.write(g.read()) # Read bytes into file
|
||||||
|
self.vid_file_name = "ultralytics.mp4"
|
||||||
|
elif self.source == "webcam":
|
||||||
|
self.vid_file_name = 0
|
||||||
|
|
||||||
if st.sidebar.button("Start"):
|
def configure(self):
|
||||||
videocapture = cv2.VideoCapture(vid_file_name) # Capture the video
|
"""Configures the model and loads selected classes for inference."""
|
||||||
|
# Add dropdown menu for model selection
|
||||||
|
available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
|
||||||
|
if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later
|
||||||
|
available_models.insert(0, self.model_path.split(".pt")[0])
|
||||||
|
selected_model = self.st.sidebar.selectbox("Model", available_models)
|
||||||
|
|
||||||
if not videocapture.isOpened():
|
with self.st.spinner("Model is downloading..."):
|
||||||
st.error("Could not open webcam.")
|
self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
|
||||||
|
class_names = list(self.model.names.values()) # Convert dictionary to list of class names
|
||||||
|
self.st.success("Model loaded successfully!")
|
||||||
|
|
||||||
stop_button = st.button("Stop") # Button to stop the inference
|
# Multiselect box with class names and get indices of selected classes
|
||||||
|
selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
|
||||||
|
self.selected_ind = [class_names.index(option) for option in selected_classes]
|
||||||
|
|
||||||
while videocapture.isOpened():
|
if not isinstance(self.selected_ind, list): # Ensure selected_options is a list
|
||||||
success, frame = videocapture.read()
|
self.selected_ind = list(self.selected_ind)
|
||||||
if not success:
|
|
||||||
st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
|
|
||||||
break
|
|
||||||
|
|
||||||
prev_time = time.time() # Store initial time for FPS calculation
|
def inference(self):
|
||||||
|
"""Performs real-time object detection inference."""
|
||||||
|
self.web_ui() # Initialize the web interface
|
||||||
|
self.sidebar() # Create the sidebar
|
||||||
|
self.source_upload() # Upload the video source
|
||||||
|
self.configure() # Configure the app
|
||||||
|
|
||||||
# Store model predictions
|
if self.st.sidebar.button("Start"):
|
||||||
if enable_trk == "Yes":
|
stop_button = self.st.button("Stop") # Button to stop the inference
|
||||||
results = model.track(frame, conf=conf, iou=iou, classes=selected_ind, persist=True)
|
cap = cv2.VideoCapture(self.vid_file_name) # Capture the video
|
||||||
else:
|
if not cap.isOpened():
|
||||||
results = model(frame, conf=conf, iou=iou, classes=selected_ind)
|
self.st.error("Could not open webcam.")
|
||||||
annotated_frame = results[0].plot() # Add annotations on frame
|
while cap.isOpened():
|
||||||
|
success, frame = cap.read()
|
||||||
|
if not success:
|
||||||
|
st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
|
||||||
|
break
|
||||||
|
|
||||||
# Calculate model FPS
|
prev_time = time.time() # Store initial time for FPS calculation
|
||||||
curr_time = time.time()
|
|
||||||
fps = 1 / (curr_time - prev_time)
|
|
||||||
|
|
||||||
# display frame
|
# Store model predictions
|
||||||
org_frame.image(frame, channels="BGR")
|
if self.enable_trk == "Yes":
|
||||||
ann_frame.image(annotated_frame, channels="BGR")
|
results = self.model.track(
|
||||||
|
frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
|
||||||
|
annotated_frame = results[0].plot() # Add annotations on frame
|
||||||
|
|
||||||
if stop_button:
|
fps = 1 / (time.time() - prev_time) # Calculate model FPS
|
||||||
videocapture.release() # Release the capture
|
|
||||||
torch.cuda.empty_cache() # Clear CUDA memory
|
|
||||||
st.stop() # Stop streamlit app
|
|
||||||
|
|
||||||
# Display FPS in sidebar
|
if stop_button:
|
||||||
fps_display.metric("FPS", f"{fps:.2f}")
|
cap.release() # Release the capture
|
||||||
|
self.st.stop() # Stop streamlit app
|
||||||
|
|
||||||
# Release the capture
|
self.fps_display.metric("FPS", f"{fps:.2f}") # Display FPS in sidebar
|
||||||
videocapture.release()
|
self.org_frame.image(frame, channels="BGR") # Display original frame
|
||||||
|
self.ann_frame.image(annotated_frame, channels="BGR") # Display processed frame
|
||||||
|
|
||||||
# Clear CUDA memory
|
cap.release() # Release the capture
|
||||||
torch.cuda.empty_cache()
|
cv2.destroyAllWindows() # Destroy window
|
||||||
|
|
||||||
# Destroy window
|
|
||||||
cv2.destroyAllWindows()
|
|
||||||
|
|
||||||
|
|
||||||
# Main function call
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
inference()
|
import sys # Import the sys module for accessing command-line arguments
|
||||||
|
|
||||||
|
model = None # Initialize the model variable as None
|
||||||
|
|
||||||
|
# Check if a model name is provided as a command-line argument
|
||||||
|
args = len(sys.argv)
|
||||||
|
if args > 1:
|
||||||
|
model = args # Assign the first argument as the model name
|
||||||
|
|
||||||
|
# Create an instance of the Inference class and run inference
|
||||||
|
Inference(model=model).inference()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue