Add --classes arg in YOLOv8 Region Counter + optimize (#6568)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Muhammad Rizwan Munawar 2023-11-24 17:45:35 +05:00 committed by GitHub
parent 4ac93d82fa
commit 8f1c3f3d1e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 15 deletions

View file

@ -69,6 +69,7 @@ def run(
view_img=False,
save_img=False,
exist_ok=False,
classes=None,
line_thickness=2,
track_thickness=2,
region_thickness=2,
@ -87,6 +88,7 @@ def run(
view_img (bool): Show results.
save_img (bool): Save results.
exist_ok (bool): Overwrite existing files.
classes (list): classes to detect and track
line_thickness (int): Bounding box thickness.
track_thickness (int): Tracking line thickness
region_thickness (int): Region thickness.
@ -101,6 +103,9 @@ def run(
model = YOLO(f'{weights}')
model.to('cuda') if device == '0' else model.to('cpu')
# Extract classes names
names = model.model.names
# Video setup
videocapture = cv2.VideoCapture(source)
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
@ -119,36 +124,29 @@ def run(
vid_frame_count += 1
# Extract the results
results = model.track(frame, persist=True)
results = model.track(frame, persist=True, classes=classes)
if results[0].boxes.id is not None:
boxes = results[0].boxes.xywh.cpu()
boxes = results[0].boxes.xyxy.cpu()
track_ids = results[0].boxes.id.int().cpu().tolist()
clss = results[0].boxes.cls.cpu().tolist()
names = results[0].names
annotator = Annotator(frame, line_width=line_thickness, example=str(names))
for box, track_id, cls in zip(boxes, track_ids, clss):
x, y, w, h = box
label = str(names[cls])
xyxy = (x - w / 2), (y - h / 2), (x + w / 2), (y + h / 2)
annotator.box_label(box, str(names[cls]), color=colors(cls, True))
bbox_center = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 # Bbox center
# Bounding box plot
bbox_color = colors(cls, True)
annotator.box_label(xyxy, label, color=bbox_color)
# Tracking Lines plot
track = track_history[track_id]
track.append((float(x), float(y)))
track = track_history[track_id] # Tracking Lines plot
track.append((float(bbox_center[0]), float(bbox_center[1])))
if len(track) > 30:
track.pop(0)
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(frame, [points], isClosed=False, color=bbox_color, thickness=track_thickness)
cv2.polylines(frame, [points], isClosed=False, color=colors(cls, True), thickness=track_thickness)
# Check if detection inside region
for region in counting_regions:
if region['polygon'].contains(Point((x, y))):
if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))):
region['counts'] += 1
# Draw regions (Polygons/Rectangles)
@ -202,6 +200,7 @@ def parse_opt():
parser.add_argument('--view-img', action='store_true', help='show results')
parser.add_argument('--save-img', action='store_true', help='save results')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness')
parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness')
parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness')