Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Glenn Jocher 2024-06-30 22:09:02 +02:00 committed by GitHub
parent ff63a56a42
commit 691b5daccb
16 changed files with 124 additions and 113 deletions

@@ -69,7 +69,7 @@ class TorchVisionVideoClassifier:
         """
         return list(TorchVisionVideoClassifier.model_name_to_model_and_weights.keys())
 
-    def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = [224, 224]) -> torch.Tensor:
+    def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = None) -> torch.Tensor:
         """
         Preprocess a list of crops for video classification.
@@ -80,6 +80,8 @@ class TorchVisionVideoClassifier:
         Returns:
             torch.Tensor: Preprocessed crops as a tensor with dimensions (1, T, C, H, W).
         """
+        if input_size is None:
+            input_size = [224, 224]
         from torchvision.transforms import v2
 
         transform = v2.Compose(
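
Both classifier classes receive the same fix here: a mutable list default is replaced with a None sentinel that is resolved inside the function body. A minimal standalone sketch of the pitfall being avoided (the function names below are hypothetical, not from this repository):

def append_item(item, bucket=[]):  # BUG: the default list is created once and shared across calls
    bucket.append(item)
    return bucket

append_item("a")  # ['a']
append_item("b")  # ['a', 'b'] -- the "default" carried state over from the first call

def append_item_fixed(item, bucket=None):  # the None-sentinel idiom this commit applies
    if bucket is None:
        bucket = []  # a fresh list on every call
    bucket.append(item)
    return bucket

The same idiom is applied further down to the labels parameter of run().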
@@ -156,7 +158,7 @@ class HuggingFaceVideoClassifier:
             model = model.half()
         self.model = model.eval()
 
-    def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = [224, 224]) -> torch.Tensor:
+    def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = None) -> torch.Tensor:
         """
         Preprocess a list of crops for video classification.
@@ -167,6 +169,8 @@ class HuggingFaceVideoClassifier:
         Returns:
             torch.Tensor: Preprocessed crops as a tensor (1, T, C, H, W).
         """
+        if input_size is None:
+            input_size = [224, 224]
         from torchvision.transforms import v2
 
         transform = v2.Compose(
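
The hunk is truncated at the v2.Compose( call. For orientation, a sketch of what a torchvision v2 preprocessing pipeline of this shape commonly looks like; the specific transforms and normalization constants below are assumptions for illustration, not this file's actual values:

import numpy as np
import torch
from torchvision.transforms import v2

input_size = [224, 224]
transform = v2.Compose(
    [
        v2.ToImage(),  # HWC uint8 ndarray -> CHW tensor image
        v2.Resize(input_size, antialias=True),  # resize each crop to the model input size
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # placeholder ImageNet stats
    ]
)

crop = np.zeros((360, 640, 3), dtype=np.uint8)  # dummy crop standing in for a detected person
tensor = transform(crop)  # shape: (3, 224, 224)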
@@ -266,15 +270,7 @@ def run(
     video_cls_overlap_ratio: float = 0.25,
     fp16: bool = False,
     video_classifier_model: str = "microsoft/xclip-base-patch32",
-    labels: List[str] = [
-        "walking",
-        "running",
-        "brushing teeth",
-        "looking into phone",
-        "weight lifting",
-        "cooking",
-        "sitting",
-    ],
+    labels: List[str] = None,
 ) -> None:
     """
     Run action recognition on a video source using YOLO for object detection and a video classifier.
@@ -295,6 +291,16 @@ def run(
     Returns:
         None
     """
+    if labels is None:
+        labels = [
+            "walking",
+            "running",
+            "brushing teeth",
+            "looking into phone",
+            "weight lifting",
+            "cooking",
+            "sitting",
+        ]
     # Initialize models and device
     device = select_device(device)
     yolo_model = YOLO(weights).to(device)
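
One side effect of the new signature worth noting: a parameter annotated List[str] with a default of None is, strictly speaking, Optional. A sketch of the more precise annotation (a style assumption, not a change this commit makes):

from typing import List, Optional

def run_sketch(labels: Optional[List[str]] = None) -> None:  # hypothetical mirror of run()'s signature
    if labels is None:
        labels = ["walking", "running"]  # trimmed default set, for illustration only
    print(labels)

run_sketch()  # prints the fallback defaults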
@@ -312,9 +318,7 @@ def run(
     # Initialize video capture
     if source.startswith("http") and urlparse(source).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}:
         source = get_best_youtube_url(source)
-    elif source.endswith(".mp4"):
-        pass
-    else:
+    elif not source.endswith(".mp4"):
         raise ValueError("Invalid source. Supported sources are YouTube URLs and MP4 files.")
 
     cap = cv2.VideoCapture(source)
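
The control-flow change above folds an empty elif/pass branch and a trailing else into a single negated condition; behavior is unchanged. Isolated as a standalone helper for clarity (the name validate_source is hypothetical; the hostname set and error message mirror the diff):

from urllib.parse import urlparse

YOUTUBE_HOSTS = {"www.youtube.com", "youtube.com", "youtu.be"}

def validate_source(source: str) -> str:
    """Pass YouTube URLs and .mp4 paths through; reject everything else."""
    if source.startswith("http") and urlparse(source).hostname in YOUTUBE_HOSTS:
        return source  # the real code resolves this to a stream via get_best_youtube_url()
    if not source.endswith(".mp4"):
        raise ValueError("Invalid source. Supported sources are YouTube URLs and MP4 files.")
    return source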

@@ -18,6 +18,7 @@ class LetterBox:
     def __init__(
         self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32
     ):
+        """Initializes LetterBox with parameters for reshaping and transforming an image while maintaining aspect ratio."""
         self.new_shape = new_shape
         self.auto = auto
         self.scaleFill = scaleFill
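
For context, a minimal usage sketch of a letterbox transform of this shape, assuming this example's LetterBox is callable on an image the way Ultralytics' augment.LetterBox is; the shapes are illustrative:

import numpy as np

letterbox = LetterBox(new_shape=(640, 640), auto=False, scaleup=True, stride=32)
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # arbitrary input frame
padded = letterbox(image=frame)  # resized and padded to 640x640 with aspect ratio preserved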