ultralytics 8.2.50 new Streamlit live inference Solution (#14210)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: RizwanMunawar <chr043416@gmail.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Glenn Jocher authored 2024-07-05 22:02:38 +02:00, committed by GitHub
parent 5f0fd710a4
commit 26a664f636
20 changed files with 350 additions and 22 deletions

ultralytics/__init__.py

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.49"
+__version__ = "8.2.50"

import os

ultralytics/cfg/__init__.py

@@ -78,10 +78,13 @@ CLI_HELP_MSG = f"""
    4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
        yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128

-   6. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
+   5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
        yolo explorer

-   5. Run special commands:
+   6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8
+       yolo streamlit-predict
+
+   7. Run special commands:
        yolo help
        yolo checks
        yolo version
@@ -514,6 +517,13 @@ def handle_explorer():
    subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])


+def handle_streamlit_inference():
+    """Open the Ultralytics Live Inference streamlit app for real time object detection."""
+    checks.check_requirements(["streamlit", "opencv-python", "torch"])
+    LOGGER.info("💡 Loading Ultralytics Live Inference app...")
+    subprocess.run(["streamlit", "run", ROOT / "solutions/streamlit_inference.py", "--server.headless", "true"])
+
+
def parse_key_value_pair(pair):
    """Parse one 'key=value' pair and return key and value."""
    k, v = pair.split("=", 1)  # split on first '=' sign
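
A note on the split-on-first-`=` rule above: it is what lets argument values themselves contain `=` characters. A standalone illustration (simplified; the real parser also strips whitespace and validates the value):

```python
def parse_key_value_pair(pair):
    """Split one 'key=value' token on the first '=' only."""
    k, v = pair.split("=", 1)
    return k, v


print(parse_key_value_pair("model=yolov8n.pt"))  # ('model', 'yolov8n.pt')
print(parse_key_value_pair("source=rtsp://u:p@host?x=1"))  # value keeps its own '=' signs
```
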
@@ -582,6 +592,7 @@ def entrypoint(debug=""):
        "login": lambda: handle_yolo_hub(args),
        "copy-cfg": copy_default_cfg,
        "explorer": lambda: handle_explorer(),
+        "streamlit-predict": lambda: handle_streamlit_inference(),
    }
    full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
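
The new entry hooks into the same name-to-callable dispatch used by the other special commands: the first CLI token is looked up in `special` and its handler is invoked. A minimal self-contained sketch of that pattern (handler bodies are placeholders, not the real implementations):

```python
def handle_help():
    print("usage: yolo TASK MODE ARGS, or a special command like 'streamlit-predict'")


def handle_streamlit_inference():
    print("launching the Streamlit live inference app...")  # placeholder body


special = {
    "help": handle_help,
    "streamlit-predict": handle_streamlit_inference,
}

cmd = "streamlit-predict"  # first CLI token
special.get(cmd, handle_help)()  # dispatch, falling back to help
```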

ultralytics/data/augment.py

@@ -686,7 +686,7 @@ class RandomFlip:
            flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
        """
        assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
-        assert 0 <= p <= 1.0
+        assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."

        self.p = p
        self.direction = direction
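
The added message turns a bare `AssertionError` on an out-of-range probability into an actionable one. A toy reproduction of the guard (a standalone stand-in, not the Ultralytics class):

```python
class Flip:
    """Standalone stand-in showing the argument guards, not ultralytics.data.augment.RandomFlip."""

    def __init__(self, p=0.5, direction="horizontal"):
        assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
        assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
        self.p, self.direction = p, direction


Flip(p=0.5)  # ok
try:
    Flip(p=1.5)
except AssertionError as e:
    print(e)  # The probability should be in range [0, 1], but got 1.5.
```
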
@@ -1210,7 +1210,7 @@ def classify_transforms(
    import torchvision.transforms as T  # scope for faster 'import ultralytics'

    if isinstance(size, (tuple, list)):
-        assert len(size) == 2
+        assert len(size) == 2, f"'size' tuples must be length 2, not length {len(size)}"
        scale_size = tuple(math.floor(x / crop_fraction) for x in size)
    else:
        scale_size = math.floor(size / crop_fraction)
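
For intuition on the `scale_size` arithmetic: dividing the target size by `crop_fraction` gives the pre-crop resize, so the center crop discards exactly the intended border. A worked example using a typical crop fraction of 0.875 (illustrative value, not necessarily this function's default):

```python
import math

crop_fraction = 0.875  # fraction of the resized image the crop keeps (assumed value)
size = (224, 224)

if isinstance(size, (tuple, list)):
    assert len(size) == 2, f"'size' tuples must be length 2, not length {len(size)}"
    scale_size = tuple(math.floor(x / crop_fraction) for x in size)
else:
    scale_size = math.floor(size / crop_fraction)

print(scale_size)  # (256, 256): resize to 256x256, then center-crop to 224x224
```
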
@@ -1288,7 +1288,7 @@ def classify_augmentations(
    secondary_tfl = []
    disable_color_jitter = False
    if auto_augment:
-        assert isinstance(auto_augment, str)
+        assert isinstance(auto_augment, str), f"Provided argument should be string, but got type {type(auto_augment)}"
        # color jitter is typically disabled if AA/RA on,
        # this allows override without breaking old hparm cfgs
        disable_color_jitter = not force_color_jitter
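Context for the surrounding lines: color jitter is dropped whenever an auto-augment policy is active, unless `force_color_jitter` keeps it. A condensed sketch of that gating:

```python
def jitter_enabled(auto_augment, force_color_jitter=False):
    """Return True if color jitter should stay in the pipeline (sketch of the gating above)."""
    if auto_augment:
        assert isinstance(auto_augment, str), f"Provided argument should be string, but got type {type(auto_augment)}"
        return force_color_jitter  # jitter survives only if explicitly forced
    return True


print(jitter_enabled("randaugment"))  # False
print(jitter_enabled("randaugment", force_color_jitter=True))  # True
print(jitter_enabled(None))  # True
```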

ultralytics/engine/results.py

@@ -42,7 +42,7 @@ class BaseTensor(SimpleClass):
        base_tensor = BaseTensor(data, orig_shape)
        ```
        """
-        assert isinstance(data, (torch.Tensor, np.ndarray))
+        assert isinstance(data, (torch.Tensor, np.ndarray)), "data must be torch.Tensor or np.ndarray"
        self.data = data
        self.orig_shape = orig_shape
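The message makes a common mistake, passing a plain Python list instead of a tensor or array, fail readably. The guard in isolation:

```python
import numpy as np
import torch

for data in (torch.zeros(2, 4), np.zeros((2, 4)), [[0.0, 0.0, 1.0, 1.0]]):
    try:
        assert isinstance(data, (torch.Tensor, np.ndarray)), "data must be torch.Tensor or np.ndarray"
        print(f"{type(data).__name__}: accepted")
    except AssertionError as e:
        print(f"{type(data).__name__}: rejected ({e})")
```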

ultralytics/models/fastsam/prompt.py

@@ -286,7 +286,7 @@ class FastSAMPrompt:
    def box_prompt(self, bbox):
        """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
        if self.results[0].masks is not None:
-            assert bbox[2] != 0 and bbox[3] != 0
+            assert bbox[2] != 0 and bbox[3] != 0, "Bounding box width and height should not be zero"
            masks = self.results[0].masks.data
            target_height, target_width = self.results[0].orig_shape
            h = masks.shape[1]
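The new message reads `bbox[2]` and `bbox[3]` as width and height; a zero there would make the mask/box IoU computed afterwards meaningless. A standalone illustration, assuming an (x, y, w, h) layout as the message suggests:

```python
def check_box(bbox):
    """Reject degenerate boxes before IoU computation (assumes (x, y, w, h) layout, per the message above)."""
    assert bbox[2] != 0 and bbox[3] != 0, "Bounding box width and height should not be zero"
    return bbox


check_box((10, 20, 50, 30))  # ok
try:
    check_box((10, 20, 0, 30))
except AssertionError as e:
    print(e)  # Bounding box width and height should not be zero
```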

ultralytics/models/sam/amg.py

@@ -133,7 +133,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
    """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
    import cv2  # type: ignore

-    assert mode in {"holes", "islands"}
+    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
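For intuition, the same `cv2.connectedComponentsWithStats` pattern on a toy mask, islands mode only (the XOR trick for holes above is omitted): components below an area threshold are dropped.

```python
import cv2
import numpy as np

mask = np.zeros((8, 8), dtype=np.uint8)
mask[1:5, 1:5] = 1  # 16-pixel island
mask[6, 6] = 1  # 1-pixel island

n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(mask, 8)
sizes = stats[1:, cv2.CC_STAT_AREA]  # per-component areas, skipping background label 0
keep = [i + 1 for i, s in enumerate(sizes) if s >= 4]  # area_thresh = 4
cleaned = np.isin(regions, keep).astype(np.uint8)
print(int(cleaned.sum()))  # 16: the single-pixel island was removed
```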

ultralytics/models/sam/modules/tiny_encoder.py

@@ -261,7 +261,7 @@ class Attention(torch.nn.Module):
        """
        super().__init__()

-        assert isinstance(resolution, tuple) and len(resolution) == 2
+        assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
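The `(H, W)` tuple defines the grid of positions this layer precomputes relative-position biases over, which is why a scalar or a longer tuple must be rejected. A tiny illustration of the quantity the guard protects (the grid enumeration mirrors how such bias tables are typically built):

```python
import itertools

resolution = (7, 7)
assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"

points = list(itertools.product(range(resolution[0]), range(resolution[1])))
print(len(points))  # 49 grid positions whose pairwise offsets index the bias table
```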

ultralytics/models/yolo/world/train_world.py

@@ -72,8 +72,8 @@ class WorldTrainerFromScratch(WorldTrainer):
        """
        final_data = {}
        data_yaml = self.args.data
-        assert data_yaml.get("train", False)  # object365.yaml
-        assert data_yaml.get("val", False)  # lvis.yaml
+        assert data_yaml.get("train", False), "train dataset not found"  # object365.yaml
+        assert data_yaml.get("val", False), "validation dataset not found"  # lvis.yaml
        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
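For reference, a hypothetical config dict of the shape these asserts expect: a train split plus exactly one validation dataset. Dataset names below are illustrative only:

```python
data_yaml = {
    "train": {"yolo_data": ["Objects365.yaml"]},  # illustrative dataset names
    "val": {"yolo_data": ["lvis.yaml"]},
}

assert data_yaml.get("train", False), "train dataset not found"
assert data_yaml.get("val", False), "validation dataset not found"
# Simplified: upstream counts datasets after resolving them with check_det_dataset().
assert len(data_yaml["val"]["yolo_data"]) == 1, "Only support validating on 1 dataset for now"
print("data config OK")
```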

ultralytics/solutions/__init__.py

@@ -8,6 +8,7 @@ from .object_counter import ObjectCounter
from .parking_management import ParkingManagement, ParkingPtsSelection
from .queue_management import QueueManager
from .speed_estimation import SpeedEstimator
+from .streamlit_inference import inference

__all__ = (
    "AIGym",

ultralytics/solutions/streamlit_inference.py (new file)

@@ -0,0 +1,154 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import io
+import time
+
+import cv2
+import torch
+
+
+def inference():
+    """Runs real-time object detection on video input using Ultralytics YOLOv8 in a Streamlit application."""
+    # Scope imports for faster ultralytics package load speeds
+    import streamlit as st
+
+    from ultralytics import YOLO
+
+    # Hide main menu style
+    menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""
+
+    # Main title of streamlit application
+    main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px;
+                             font-family: 'Archivo', sans-serif; margin-top:-50px;margin-bottom:20px;">
+                    Ultralytics YOLOv8 Streamlit Application
+                    </h1></div>"""
+
+    # Subtitle of streamlit application
+    sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center;
+                    font-family: 'Archivo', sans-serif; margin-top:-15px; margin-bottom:50px;">
+                    Experience real-time object detection on your webcam with the power of Ultralytics YOLOv8! 🚀</h4>
+                    </div>"""
+
+    # Set html page configuration
+    st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
+
+    # Append the custom HTML
+    st.markdown(menu_style_cfg, unsafe_allow_html=True)
+    st.markdown(main_title_cfg, unsafe_allow_html=True)
+    st.markdown(sub_title_cfg, unsafe_allow_html=True)
+
+    # Add ultralytics logo in sidebar
+    with st.sidebar:
+        logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
+        st.image(logo, width=250)
+
+    # Add elements to vertical setting menu
+    st.sidebar.title("User Configuration")
+
+    # Add video source selection dropdown
+    source = st.sidebar.selectbox(
+        "Video",
+        ("webcam", "video"),
+    )
+
+    vid_file_name = ""
+    if source == "video":
+        vid_file = st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
+        if vid_file is not None:
+            g = io.BytesIO(vid_file.read())  # BytesIO Object
+            vid_location = "ultralytics.mp4"
+            with open(vid_location, "wb") as out:  # Open temporary file as bytes
+                out.write(g.read())  # Read bytes into file
+            vid_file_name = "ultralytics.mp4"
+    elif source == "webcam":
+        vid_file_name = 0
+
+    # Add dropdown menu for model selection
+    yolov8_model = st.sidebar.selectbox(
+        "Model",
+        (
+            "YOLOv8n",
+            "YOLOv8s",
+            "YOLOv8m",
+            "YOLOv8l",
+            "YOLOv8x",
+            "YOLOv8n-Seg",
+            "YOLOv8s-Seg",
+            "YOLOv8m-Seg",
+            "YOLOv8l-Seg",
+            "YOLOv8x-Seg",
+            "YOLOv8n-Pose",
+            "YOLOv8s-Pose",
+            "YOLOv8m-Pose",
+            "YOLOv8l-Pose",
+            "YOLOv8x-Pose",
+        ),
+    )
+    model = YOLO(f"{yolov8_model.lower()}.pt")  # Load the yolov8 model
+    class_names = list(model.names.values())  # Convert dictionary to list of class names
+
+    # Multiselect box with class names and get indices of selected classes
+    selected_classes = st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
+    selected_ind = [class_names.index(option) for option in selected_classes]
+
+    if not isinstance(selected_ind, list):  # Ensure selected_options is a list
+        selected_ind = list(selected_ind)
+
+    conf_thres = st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01)
+    nms_thres = st.sidebar.slider("NMS Threshold", 0.0, 1.0, 0.45, 0.01)
+
+    col1, col2 = st.columns(2)
+    org_frame = col1.empty()
+    ann_frame = col2.empty()
+
+    fps_display = st.sidebar.empty()  # Placeholder for FPS display
+
+    if st.sidebar.button("Start"):
+        videocapture = cv2.VideoCapture(vid_file_name)  # Capture the video
+
+        if not videocapture.isOpened():
+            st.error("Could not open webcam.")
+
+        stop_button = st.button("Stop")  # Button to stop the inference
+
+        prev_time = 0
+        while videocapture.isOpened():
+            success, frame = videocapture.read()
+            if not success:
+                st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
+                break
+
+            curr_time = time.time()
+            fps = 1 / (curr_time - prev_time)
+            prev_time = curr_time
+
+            # Store model predictions
+            results = model(frame, conf=float(conf_thres), iou=float(nms_thres), classes=selected_ind)
+            annotated_frame = results[0].plot()  # Add annotations on frame
+
+            # display frame
+            org_frame.image(frame, channels="BGR")
+            ann_frame.image(annotated_frame, channels="BGR")
+
+            if stop_button:
+                videocapture.release()  # Release the capture
+                torch.cuda.empty_cache()  # Clear CUDA memory
+                st.stop()  # Stop streamlit app
+
+            # Display FPS in sidebar
+            fps_display.metric("FPS", f"{fps:.2f}")
+
+        # Release the capture
+        videocapture.release()
+
+    # Clear CUDA memory
+    torch.cuda.empty_cache()
+
+    # Destroy window
+    cv2.destroyAllWindows()
+
+
+# Main function call
+if __name__ == "__main__":
+    inference()
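Two ways to launch the app defined above: the new `yolo streamlit-predict` command, or calling Streamlit directly the way `handle_streamlit_inference()` does. A sketch of the direct route; the script path assumes a source checkout and may differ for installed packages:

```python
import subprocess

# Mirrors handle_streamlit_inference(); adjust the path to your environment.
subprocess.run(["streamlit", "run", "ultralytics/solutions/streamlit_inference.py", "--server.headless", "true"])
```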

ultralytics/utils/plotting.py

@@ -740,18 +740,18 @@ class Annotator:
            cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
            label = f"Track ID: {track_label}" if track_label else det_label
-            text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
+            text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)

            cv2.rectangle(
                self.im,
                (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
-                (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
+                (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
                mask_color,
                -1,
            )

            cv2.putText(
-                self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
+                self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, (255, 255, 255), self.tf
            )

    def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):
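
Swapping the hard-coded `0.7`/`1` for `self.sf`/`self.tf` lets label text scale with the image instead of staying tiny on large frames. A standalone sketch of the effect; the scale/thickness derivation below is an illustrative assumption, not the Annotator's exact formula:

```python
import cv2

label = "Track ID: 3"
for h, w in ((480, 640), (2160, 3840)):
    sf = min(h, w) / 1000  # assumed: font scale grows with frame size
    tf = max(int(sf), 1)  # assumed: thickness follows the scale
    (tw, th), _ = cv2.getTextSize(label, 0, sf, tf)
    print(f"{w}x{h}: scale={sf:.2f}, thickness={tf}, text box={tw}x{th}px")
```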