ultralytics 8.2.95 faster checkpoint saving (#16311)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
7b19e0daa0
commit
ba438aea5a
5 changed files with 53 additions and 58 deletions
|
|
@ -292,42 +292,27 @@ Finally, after all threads have completed their task, the windows displaying the
|
||||||
|
|
||||||
# Define model names and video sources
|
# Define model names and video sources
|
||||||
MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
|
MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
|
||||||
SOURCES = ["path/to/video1.mp4", 0] # local video, 0 for webcam
|
SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam
|
||||||
|
|
||||||
|
|
||||||
def run_tracker_in_thread(model_name, filename, index):
|
def run_tracker_in_thread(model_name, filename):
|
||||||
"""
|
"""
|
||||||
Runs a video file or webcam stream concurrently with the YOLOv8 model using threading. This function captures video
|
Run YOLO tracker in its own thread for concurrent processing.
|
||||||
frames from a given file or camera source and utilizes the YOLOv8 model for object tracking. The function runs in
|
|
||||||
its own thread for concurrent processing.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
model_name (str): The YOLOv8 model object.
|
||||||
filename (str): The path to the video file or the identifier for the webcam/external camera source.
|
filename (str): The path to the video file or the identifier for the webcam/external camera source.
|
||||||
model (obj): The YOLOv8 model object.
|
|
||||||
index (int): An index to uniquely identify the file being processed, used for display purposes.
|
|
||||||
"""
|
"""
|
||||||
model = YOLO(model_name)
|
model = YOLO(model_name)
|
||||||
video = cv2.VideoCapture(filename)
|
results = model.track(filename, save=True, stream=True)
|
||||||
|
for r in results:
|
||||||
while True:
|
pass
|
||||||
ret, frame = video.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
|
|
||||||
results = model.track(frame, persist=True)
|
|
||||||
res_plotted = results[0].plot()
|
|
||||||
cv2.imshow(f"Tracking_Stream_{index}", res_plotted)
|
|
||||||
|
|
||||||
if cv2.waitKey(1) == ord("q"):
|
|
||||||
break
|
|
||||||
|
|
||||||
video.release()
|
|
||||||
|
|
||||||
|
|
||||||
# Create and start tracker threads using a for loop
|
# Create and start tracker threads using a for loop
|
||||||
tracker_threads = []
|
tracker_threads = []
|
||||||
for i, (video_file, model_name) in enumerate(zip(SOURCES, MODEL_NAMES), start=1):
|
for video_file, model_name in zip(SOURCES, MODEL_NAMES):
|
||||||
thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file, i), daemon=True)
|
thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
|
||||||
tracker_threads.append(thread)
|
tracker_threads.append(thread)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
|
|
@ -395,35 +380,37 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt
|
||||||
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# Define model names and video sources
|
||||||
def run_tracker_in_thread(filename, model, file_index):
|
MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
|
||||||
video = cv2.VideoCapture(filename)
|
SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam
|
||||||
while True:
|
|
||||||
ret, frame = video.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
results = model.track(frame, persist=True)
|
|
||||||
res_plotted = results[0].plot()
|
|
||||||
cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted)
|
|
||||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
|
||||||
break
|
|
||||||
video.release()
|
|
||||||
|
|
||||||
|
|
||||||
model1 = YOLO("yolov8n.pt")
|
def run_tracker_in_thread(model_name, filename):
|
||||||
model2 = YOLO("yolov8n-seg.pt")
|
"""
|
||||||
video_file1 = "path/to/video1.mp4"
|
Run YOLO tracker in its own thread for concurrent processing.
|
||||||
video_file2 = 0 # Path to a second video file, or 0 for a webcam
|
|
||||||
|
|
||||||
tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True)
|
Args:
|
||||||
tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True)
|
model_name (str): The YOLOv8 model object.
|
||||||
|
filename (str): The path to the video file or the identifier for the webcam/external camera source.
|
||||||
|
"""
|
||||||
|
model = YOLO(model_name)
|
||||||
|
results = model.track(filename, save=True, stream=True)
|
||||||
|
for r in results:
|
||||||
|
pass
|
||||||
|
|
||||||
tracker_thread1.start()
|
|
||||||
tracker_thread2.start()
|
|
||||||
|
|
||||||
tracker_thread1.join()
|
# Create and start tracker threads using a for loop
|
||||||
tracker_thread2.join()
|
tracker_threads = []
|
||||||
|
for video_file, model_name in zip(SOURCES, MODEL_NAMES):
|
||||||
|
thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
|
||||||
|
tracker_threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Wait for all tracker threads to finish
|
||||||
|
for thread in tracker_threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
# Clean up and close windows
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.94"
|
__version__ = "8.2.95"
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -668,13 +668,14 @@ class BaseTrainer:
|
||||||
|
|
||||||
def final_eval(self):
|
def final_eval(self):
|
||||||
"""Performs final evaluation and validation for object detection YOLO model."""
|
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||||
|
ckpt = {}
|
||||||
for f in self.last, self.best:
|
for f in self.last, self.best:
|
||||||
if f.exists():
|
if f.exists():
|
||||||
strip_optimizer(f) # strip optimizers
|
if f is self.last:
|
||||||
if f is self.best:
|
ckpt = strip_optimizer(f)
|
||||||
if self.last.is_file(): # update best.pt train_metrics from last.pt
|
elif f is self.best:
|
||||||
k = "train_results"
|
k = "train_results" # update best.pt train_metrics from last.pt
|
||||||
torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
|
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
|
||||||
LOGGER.info(f"\nValidating {f}...")
|
LOGGER.info(f"\nValidating {f}...")
|
||||||
self.validator.args.plots = self.args.plots
|
self.validator.args.plots = self.args.plots
|
||||||
self.metrics = self.validator(model=f)
|
self.metrics = self.validator(model=f)
|
||||||
|
|
|
||||||
|
|
@ -759,6 +759,10 @@ class SafeClass:
|
||||||
"""Initialize SafeClass instance, ignoring all arguments."""
|
"""Initialize SafeClass instance, ignoring all arguments."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
"""Run SafeClass instance, ignoring all arguments."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SafeUnpickler(pickle.Unpickler):
|
class SafeUnpickler(pickle.Unpickler):
|
||||||
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
||||||
|
|
|
||||||
|
|
@ -533,16 +533,17 @@ class ModelEMA:
|
||||||
copy_attr(self.ema, model, include, exclude)
|
copy_attr(self.ema, model, include, exclude)
|
||||||
|
|
||||||
|
|
||||||
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
||||||
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
||||||
|
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
(dict): The combined checkpoint dictionary.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
|
|
@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
||||||
assert "model" in x, "'model' missing from checkpoint"
|
assert "model" in x, "'model' missing from checkpoint"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
|
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
|
||||||
return
|
return {}
|
||||||
|
|
||||||
updates = {
|
metadata = {
|
||||||
"date": datetime.now().isoformat(),
|
"date": datetime.now().isoformat(),
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
||||||
|
|
@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
||||||
# x['model'].args = x['train_args']
|
# x['model'].args = x['train_args']
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right)
|
combined = {**metadata, **x, **(updates or {})}
|
||||||
|
torch.save(combined, s or f, use_dill=False) # combine dicts (prefer to the right)
|
||||||
mb = os.path.getsize(s or f) / 1e6 # file size
|
mb = os.path.getsize(s or f) / 1e6 # file size
|
||||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||||
|
return combined
|
||||||
|
|
||||||
|
|
||||||
def convert_optimizer_state_dict_to_fp16(state_dict):
|
def convert_optimizer_state_dict_to_fp16(state_dict):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue