PyCharm Code Inspect fixes for Solutions and Examples (#18393)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 3e65fc2421
commit b1af683d7a
13 changed files with 92 additions and 62 deletions
````diff
@@ -79,18 +79,14 @@ Queue management using [Ultralytics YOLO11](https://github.com/ultralytics/ultralytics/)
 # Process video
 while cap.isOpened():
     success, im0 = cap.read()

-    if success:
-        out = queue.process_queue(im0)
-        video_writer.write(im0)
-        if cv2.waitKey(1) & 0xFF == ord("q"):
-            break
-        continue
+    if not success:
+        print("Video frame is empty or video processing has been successfully completed.")
+        break
+    out = queue.process_queue(im0)
+    video_writer.write(im0)

 cap.release()
 video_writer.release()
 cv2.destroyAllWindows()
 ```
````
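This hunk flips the docs loop into a guard-clause pattern: the failure case exits first and the processing path stays at a single indent level. A self-contained sketch of the resulting loop (the input path and `QueueManager` setup are assumed from the surrounding docs page, not part of this diff):

```python
import cv2

from ultralytics import solutions

queue = solutions.QueueManager(model="yolo11n.pt")  # assumed setup from the docs page
cap = cv2.VideoCapture("queue_video.mp4")  # hypothetical input path

while cap.isOpened():
    success, im0 = cap.read()
    if not success:  # guard clause: handle the bad read once, then exit
        print("Video frame is empty or video processing has been successfully completed.")
        break
    out = queue.process_queue(im0)  # happy path runs unindented

cap.release()
cv2.destroyAllWindows()
```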
````diff
@@ -81,18 +81,14 @@ keywords: Ultralytics YOLO11, speed estimation, object tracking, computer vision
 # Process video
 while cap.isOpened():
     success, im0 = cap.read()

-    if success:
-        out = speed.estimate_speed(im0)
-        video_writer.write(im0)
-        if cv2.waitKey(1) & 0xFF == ord("q"):
-            break
-        continue
+    if not success:
+        print("Video frame is empty or video processing has been successfully completed.")
+        break
+    out = speed.estimate_speed(im0)
+    video_writer.write(im0)

 cap.release()
 video_writer.release()
 cv2.destroyAllWindows()
 ```
````
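Same guard-clause refactor as the queue-management page above. Note that both hunks also drop the `cv2.waitKey(1)` / `q`-to-quit polling, which only has an effect when a window is displayed; if a live preview is wanted, a pattern like this (optional, not part of the docs snippet) restores it:

```python
import cv2

cap = cv2.VideoCapture(0)  # hypothetical webcam source
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    cv2.imshow("preview", frame)  # a visible window is required for waitKey to receive keystrokes
    if cv2.waitKey(1) & 0xFF == ord("q"):  # poll for 'q' to quit
        break
cap.release()
cv2.destroyAllWindows()
```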
```diff
@@ -132,17 +132,19 @@ def run(
     model.to("cuda") if device == "0" else model.to("cpu")

     # Extract classes names
-    names = model.model.names
+    names = model.names

     # Video setup
     videocapture = cv2.VideoCapture(source)
-    frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
-    fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")
+    frame_width = int(videocapture.get(3))
+    frame_height = int(videocapture.get(4))
+    fps = int(videocapture.get(5))
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

     # Output setup
     save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok)
     save_dir.mkdir(parents=True, exist_ok=True)
-    video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))
+    video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.avi"), fourcc, fps, (frame_width, frame_height))

     # Iterate over video frames
     while videocapture.isOpened():
```
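Splitting the tuple assignments into one statement per line is a pure readability fix; the numeric property indices are unchanged. For reference, OpenCV also exposes them as named constants (`cv2.CAP_PROP_FRAME_WIDTH == 3`, `cv2.CAP_PROP_FRAME_HEIGHT == 4`, `cv2.CAP_PROP_FPS == 5`), which read better than magic numbers:

```python
import cv2

videocapture = cv2.VideoCapture("test.mp4")  # hypothetical input path
frame_width = int(videocapture.get(cv2.CAP_PROP_FRAME_WIDTH))  # equivalent to .get(3)
frame_height = int(videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT))  # equivalent to .get(4)
fps = int(videocapture.get(cv2.CAP_PROP_FPS))  # equivalent to .get(5)
```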
```diff
@@ -241,9 +243,9 @@ def parse_opt():
     return parser.parse_args()


-def main(opt):
+def main(options):
     """Main function."""
-    run(**vars(opt))
+    run(**vars(options))


 if __name__ == "__main__":
```
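Renaming `opt` to `options` clears PyCharm's "shadows name from outer scope" inspection, since scripts in this style typically bind `opt` at module level before calling `main`. A stripped-down sketch of the pattern (the `run` body is a hypothetical stand-in):

```python
import argparse


def run(source="video.mp4"):
    """Hypothetical stand-in for the example's run() pipeline."""
    print(f"Processing {source}")


def parse_opt():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", type=str, default="video.mp4", help="input video path")
    return parser.parse_args()


def main(options):  # a distinct name avoids shadowing the module-level `opt` below
    """Main function."""
    run(**vars(options))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
```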
```diff
@@ -24,11 +24,16 @@ class SAHIInference:
         yolo11_model_path = f"models/{weights}"
         download_yolo11n_model(yolo11_model_path)
         self.detection_model = AutoDetectionModel.from_pretrained(
-            model_type="ultralytics", model_path=yolo11_model_path, confidence_threshold=0.3, device="cpu"
+            model_type="ultralytics", model_path=yolo11_model_path, device="cpu"
         )

     def inference(
-        self, weights="yolo11n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False, track=False
+        self,
+        weights="yolo11n.pt",
+        source="test.mp4",
+        view_img=False,
+        save_img=False,
+        exist_ok=False,
     ):
         """
         Run object detection on a video using YOLO11 and SAHI.
```
```diff
@@ -39,7 +44,6 @@ class SAHIInference:
             view_img (bool): Show results.
             save_img (bool): Save results.
             exist_ok (bool): Overwrite existing files.
-            track (bool): Enable object tracking with SAHI
         """
         # Video setup
         cap = cv2.VideoCapture(source)
```
```diff
@@ -50,8 +54,8 @@ class SAHIInference:
         save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
         save_dir.mkdir(parents=True, exist_ok=True)
         video_writer = cv2.VideoWriter(
-            str(save_dir / f"{Path(source).stem}.mp4"),
-            cv2.VideoWriter_fourcc(*"mp4v"),
+            str(save_dir / f"{Path(source).stem}.avi"),
+            cv2.VideoWriter_fourcc(*"MJPG"),
             int(cap.get(5)),
             (frame_width, frame_height),
         )
```
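Switching the writer to MJPG frames in an `.avi` container is a portability fix: stock OpenCV builds ship the MJPG encoder, while writing `mp4v`/`.mp4` can fail silently on systems without a matching codec. A minimal sketch of the pairing:

```python
import cv2

cap = cv2.VideoCapture("test.mp4")  # hypothetical input path
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
# MJPG + .avi is broadly supported by default OpenCV builds
writer = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
```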
```diff
@@ -68,8 +72,6 @@ class SAHIInference:
             self.detection_model,
             slice_height=512,
             slice_width=512,
-            overlap_height_ratio=0.2,
-            overlap_width_ratio=0.2,
         )
         detection_data = [
             (det.category.name, det.category.id, (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy))
```
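Dropping `confidence_threshold=0.3` (above) and the two overlap ratios here leans on SAHI's defaults, which are 0.3 and 0.2/0.2 in recent SAHI releases (worth confirming against the installed version), so behavior should be unchanged. A minimal sliced-inference sketch with those defaults spelled out:

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="models/yolo11n.pt",
    confidence_threshold=0.3,  # SAHI default, shown explicitly
    device="cpu",
)
result = get_sliced_prediction(
    "frame.jpg",  # hypothetical image path
    model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,  # SAHI default
    overlap_width_ratio=0.2,  # SAHI default
)
```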
```diff
@@ -694,7 +694,7 @@ def handle_yolo_solutions(args: List[str]) -> None:
                 str(ROOT / "solutions/streamlit_inference.py"),
                 "--server.headless",
                 "true",
-                overrides["model"],
+                overrides.pop("model", "yolo11n.pt"),
             ]
         )
     else:
```
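`overrides.pop("model", "yolo11n.pt")` does two jobs at once: it falls back to a default when no model was passed, and it removes the key so the value is not forwarded a second time. The semantics in isolation:

```python
overrides = {"model": "yolo11s.pt", "conf": 0.5}
model = overrides.pop("model", "yolo11n.pt")  # -> "yolo11s.pt", and the key is removed
fallback = overrides.pop("model", "yolo11n.pt")  # key is gone now -> "yolo11n.pt"
print(model, fallback, overrides)  # yolo11s.pt yolo11n.pt {'conf': 0.5}
```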
```diff
@@ -170,7 +170,7 @@ class Analytics(BaseSolution):
         for key in count_dict.keys():
             y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key]))
             if len(y_data_dict[key]) < max_length:
-                y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant")
+                y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])))
         if len(x_data) > self.max_points:
             x_data = x_data[1:]
             for key in count_dict.keys():
```
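`np.pad` already defaults to `mode="constant"`, so dropping the third argument is purely cosmetic:

```python
import numpy as np

y = np.array([1.0, 2.0])
assert np.array_equal(np.pad(y, (0, 3)), np.pad(y, (0, 3), "constant"))  # identical output
```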
```diff
@@ -45,6 +45,8 @@ class DistanceCalculation(BaseSolution):
         self.left_mouse_count = 0
         self.selected_boxes = {}

+        self.centroids = []  # Initialize empty list to store centroids
+
     def mouse_event_for_distance(self, event, x, y, flags, param):
         """
         Handles mouse events to select regions in a real-time video stream for distance calculation.
```
```diff
@@ -41,6 +41,7 @@ class Heatmap(ObjectCounter):

         # store colormap
         self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]
+        self.heatmap = None

     def heatmap_effect(self, box):
         """
```
```diff
@@ -34,7 +34,6 @@ class ParkingPtsSelection:
         canvas_max_height (int): Maximum height of the canvas.
-
     Methods:
         setup_ui: Sets up the Tkinter UI components.
         initialize_properties: Initializes the necessary properties.
         upload_image: Uploads an image, resizes it to fit the canvas, and displays it.
         on_canvas_click: Handles mouse clicks to add points for bounding boxes.
```
```diff
@@ -55,20 +54,22 @@ class ParkingPtsSelection:
         from tkinter import filedialog, messagebox

         self.tk, self.filedialog, self.messagebox = tk, filedialog, messagebox
         self.setup_ui()
-        self.initialize_properties()
-        self.master.mainloop()

     def setup_ui(self):
         """Sets up the Tkinter UI components for the parking zone points selection interface."""
-        self.master = self.tk.Tk()
+        self.master = self.tk.Tk()  # Reference to the main application window or parent widget
         self.master.title("Ultralytics Parking Zones Points Selector")
         self.master.resizable(False, False)

         # Canvas for image display
-        self.canvas = self.tk.Canvas(self.master, bg="white")
+        self.canvas = self.tk.Canvas(self.master, bg="white")  # Canvas widget for displaying images or graphics
         self.canvas.pack(side=self.tk.BOTTOM)

+        self.image = None  # Variable to store the loaded image
+        self.canvas_image = None  # Reference to the image displayed on the canvas
+        self.canvas_max_width = None  # Maximum allowed width for the canvas
+        self.canvas_max_height = None  # Maximum allowed height for the canvas
+        self.rg_data = None  # Data related to region or annotation management
+        self.current_box = None  # Stores the currently selected or active bounding box
+        self.imgh = None  # Height of the current image
+        self.imgw = None  # Width of the current image
+
         # Button frame with buttons
         button_frame = self.tk.Frame(self.master)
         button_frame.pack(side=self.tk.TOP)
```
```diff
@@ -80,6 +81,9 @@ class ParkingPtsSelection:
         ]:
             self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT)

+        self.initialize_properties()
+        self.master.mainloop()
+
     def initialize_properties(self):
         """Initialize properties for image, canvas, bounding boxes, and dimensions."""
         self.image = self.canvas_image = None
```
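Moving `initialize_properties()` and `mainloop()` to the end of `setup_ui` keeps widget construction, state initialization, and the blocking event loop in one visible order; `mainloop()` must come last because it does not return until the window closes. A bare-bones illustration:

```python
import tkinter as tk

root = tk.Tk()
tk.Button(root, text="Upload Image").pack()  # build all widgets first
points = []  # then initialize state
root.mainloop()  # finally block on the event loop; nothing after this runs until the window closes
```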
```diff
@@ -105,7 +109,7 @@ class ParkingPtsSelection:
         )

         self.canvas.config(width=canvas_width, height=canvas_height)
-        self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height), Image.LANCZOS))
+        self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height)))
         self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
         self.canvas.bind("<Button-1>", self.on_canvas_click)
```
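Removing `Image.LANCZOS` falls back to Pillow's default `resize` filter (bicubic for most image modes in recent Pillow versions), trading a little resampling quality for speed; the sharper filter can still be requested explicitly:

```python
from PIL import Image

img = Image.open("parking_lot.png")  # hypothetical image path
fast = img.resize((640, 480))  # Pillow's default filter
sharp = img.resize((640, 480), Image.LANCZOS)  # explicit high-quality filter, as before this change
```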
```diff
@@ -144,8 +148,13 @@ class ParkingPtsSelection:
         """Saves the selected parking zone points to a JSON file with scaled coordinates."""
         scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()
         data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]
-        with open("bounding_boxes.json", "w") as f:
-            json.dump(data, f, indent=4)
+
+        from io import StringIO  # Function level import, as it's only required to store coordinates, not every frame
+
+        write_buffer = StringIO()
+        json.dump(data, write_buffer, indent=4)
+        with open("bounding_boxes.json", "w", encoding="utf-8") as f:
+            f.write(write_buffer.getvalue())
         self.messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json")
```
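Serializing into a `StringIO` buffer first means the JSON is fully built in memory before the file is opened, so the handle is held only for one short write; the explicit `encoding="utf-8"` also satisfies PyCharm's file-encoding inspection. The pattern in isolation:

```python
import json
from io import StringIO

data = [{"points": [(10, 20), (30, 40)]}]  # hypothetical payload
buffer = StringIO()
json.dump(data, buffer, indent=4)  # serialize fully in memory first
with open("bounding_boxes.json", "w", encoding="utf-8") as f:
    f.write(buffer.getvalue())  # one short write while the file is open
```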
```diff
@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 from ultralytics.solutions.solutions import BaseSolution
+from ultralytics.utils import LOGGER
 from ultralytics.utils.plotting import Annotator, colors
```
```diff
@@ -81,6 +82,9 @@ class RegionCounter(BaseSolution):

         # Draw regions and process counts for each defined area
         for idx, (region_name, reg_pts) in enumerate(regions.items(), start=1):
+            if not isinstance(reg_pts, list) or not all(isinstance(pt, tuple) for pt in reg_pts):
+                LOGGER.warning(f"Invalid region points for {region_name}: {reg_pts}")
+                continue  # Skip invalid entries
             color = colors(idx, True)
             self.annotator.draw_region(reg_pts=reg_pts, color=color, thickness=self.line_width * 2)
             self.add_region(region_name, reg_pts, color, self.annotator.get_txt_color())
```
```diff
@@ -34,6 +34,9 @@ class SecurityAlarm(BaseSolution):
         super().__init__(**kwargs)
         self.email_sent = False
         self.records = self.CFG["records"]
+        self.server = None
+        self.to_email = ""
+        self.from_email = ""

     def authenticate(self, from_email, password, to_email):
         """
```
```diff
@@ -91,7 +94,7 @@ class SecurityAlarm(BaseSolution):

         # Add the text message body
         message_body = f"Ultralytics ALERT!!! " f"{records} objects have been detected!!"
-        message.attach(MIMEText(message_body, "plain"))
+        message.attach(MIMEText(message_body))

         # Attach the image
         image_attachment = MIMEImage(img_bytes, name="ultralytics.jpg")
```
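`MIMEText`'s subtype parameter defaults to `"plain"`, so the shortened call builds an identical message:

```python
from email.mime.text import MIMEText

a = MIMEText("2 objects have been detected!!")  # subtype defaults to "plain"
b = MIMEText("2 objects have been detected!!", "plain")  # explicit and equivalent
assert a.get_content_type() == b.get_content_type() == "text/plain"
```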
```diff
@@ -56,6 +56,14 @@ class BaseSolution:
         self.Polygon = Polygon
         self.Point = Point
         self.prep = prep
+        self.annotator = None  # Initialize annotator
+        self.tracks = None
+        self.track_data = None
+        self.boxes = []
+        self.clss = []
+        self.track_ids = []
+        self.track_line = None
+        self.r_s = None

         # Load config and update with args
         DEFAULT_SOL_DICT.update(kwargs)
```
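Like the `DistanceCalculation`, `Heatmap`, and `SecurityAlarm` hunks above, this pre-declares every attribute in `__init__`, which is what PyCharm's "instance attribute defined outside `__init__`" inspection asks for. A minimal illustration of the pattern:

```python
class Counter:
    def __init__(self):
        self.boxes = []  # declared up front, so readers and inspections see the full object state

    def update(self, box):
        self.boxes.append(box)  # methods only mutate attributes that __init__ created
```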
```diff
@@ -1,7 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-import io
 import time
 from typing import Any

 import cv2
```
```diff
@@ -52,11 +51,19 @@ class Inference:
         check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
         import streamlit as st

-        self.st = st
+        self.st = st  # Reference to the Streamlit class instance
+        self.source = None  # Placeholder for video or webcam source details
+        self.enable_trk = False  # Flag to toggle object tracking
+        self.conf = 0.25  # Confidence threshold for detection
+        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
+        self.org_frame = None  # Container for the original frame to be displayed
+        self.ann_frame = None  # Container for the annotated frame to be displayed
+        self.vid_file_name = None  # Holds the name of the video file
+        self.selected_ind = []  # List of selected classes for detection or tracking
+        self.model = None  # Container for the loaded model instance

         self.temp_dict = {"model": None}  # Temporary dict to store the model path
         self.temp_dict.update(kwargs)
+
+        self.model_path = None  # Store model file name with path
         if self.temp_dict["model"] is not None:
             self.model_path = self.temp_dict["model"]
```
```diff
@@ -77,7 +84,7 @@ class Inference:
         of Ultralytics YOLO! 🚀</h4></div>"""

         # Set html page configuration and append custom HTML
-        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
+        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
         self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
         self.st.markdown(main_title_cfg, unsafe_allow_html=True)
         self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
```
```diff
@@ -94,13 +101,14 @@ class Inference:
             ("webcam", "video"),
         )  # Add source selection dropdown
         self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
-        self.conf = float(self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))  # Slider for confidence
-        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))  # Slider for NMS threshold
+        self.conf = float(
+            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
+        )  # Slider for confidence
+        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

         col1, col2 = self.st.columns(2)
         self.org_frame = col1.empty()
         self.ann_frame = col2.empty()
         self.fps_display = self.st.sidebar.empty()  # Placeholder for FPS display

     def source_upload(self):
         """Handles video file uploads through the Streamlit interface."""
```
```diff
@@ -153,8 +161,6 @@ class Inference:
                 self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
                 break

-            prev_time = time.time()  # Store initial time for FPS calculation
-
             # Store model predictions
             if self.enable_trk == "Yes":
                 results = self.model.track(
```
```diff
@@ -164,13 +170,10 @@ class Inference:
                 results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
             annotated_frame = results[0].plot()  # Add annotations on frame

-            fps = 1 / (time.time() - prev_time)  # Calculate model FPS
-
             if stop_button:
                 cap.release()  # Release the capture
                 self.st.stop()  # Stop streamlit app

-            self.fps_display.metric("FPS", f"{fps:.2f}")  # Display FPS in sidebar
             self.org_frame.image(frame, channels="BGR")  # Display original frame
             self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
```
```diff
@@ -181,8 +184,12 @@ class Inference:
 if __name__ == "__main__":
     import sys  # Import the sys module for accessing command-line arguments

+    model = None  # Initialize the model variable as None
+
     # Check if a model name is provided as a command-line argument
     args = len(sys.argv)
-    model = args if args > 1 else None
+    if args > 1:
+        model = sys.argv[1]  # Assign the first argument as the model name

     # Create an instance of the Inference class and run inference
     Inference(model=model).inference()
```
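This hunk fixes a real bug, not just style: the old line assigned the argument *count* (`args`, an `int`) to `model` whenever any argument was passed, while the fix reads the actual first argument. The corrected handling in isolation:

```python
import sys

# sys.argv[0] is the script name, so a user-supplied model path exists only when len(sys.argv) > 1
model = sys.argv[1] if len(sys.argv) > 1 else None
```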