Add --classes arg in YOLOv8 Region Counter + optimize (#6568)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4ac93d82fa
commit
8f1c3f3d1e
2 changed files with 20 additions and 15 deletions
|
|
@ -69,6 +69,7 @@ def run(
|
|||
view_img=False,
|
||||
save_img=False,
|
||||
exist_ok=False,
|
||||
classes=None,
|
||||
line_thickness=2,
|
||||
track_thickness=2,
|
||||
region_thickness=2,
|
||||
|
|
@ -87,6 +88,7 @@ def run(
|
|||
view_img (bool): Show results.
|
||||
save_img (bool): Save results.
|
||||
exist_ok (bool): Overwrite existing files.
|
||||
classes (list): classes to detect and track
|
||||
line_thickness (int): Bounding box thickness.
|
||||
track_thickness (int): Tracking line thickness
|
||||
region_thickness (int): Region thickness.
|
||||
|
|
@ -101,6 +103,9 @@ def run(
|
|||
model = YOLO(f'{weights}')
|
||||
model.to('cuda') if device == '0' else model.to('cpu')
|
||||
|
||||
# Extract classes names
|
||||
names = model.model.names
|
||||
|
||||
# Video setup
|
||||
videocapture = cv2.VideoCapture(source)
|
||||
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
|
||||
|
|
@ -119,36 +124,29 @@ def run(
|
|||
vid_frame_count += 1
|
||||
|
||||
# Extract the results
|
||||
results = model.track(frame, persist=True)
|
||||
results = model.track(frame, persist=True, classes=classes)
|
||||
|
||||
if results[0].boxes.id is not None:
|
||||
boxes = results[0].boxes.xywh.cpu()
|
||||
boxes = results[0].boxes.xyxy.cpu()
|
||||
track_ids = results[0].boxes.id.int().cpu().tolist()
|
||||
clss = results[0].boxes.cls.cpu().tolist()
|
||||
names = results[0].names
|
||||
|
||||
annotator = Annotator(frame, line_width=line_thickness, example=str(names))
|
||||
|
||||
for box, track_id, cls in zip(boxes, track_ids, clss):
|
||||
x, y, w, h = box
|
||||
label = str(names[cls])
|
||||
xyxy = (x - w / 2), (y - h / 2), (x + w / 2), (y + h / 2)
|
||||
annotator.box_label(box, str(names[cls]), color=colors(cls, True))
|
||||
bbox_center = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 # Bbox center
|
||||
|
||||
# Bounding box plot
|
||||
bbox_color = colors(cls, True)
|
||||
annotator.box_label(xyxy, label, color=bbox_color)
|
||||
|
||||
# Tracking Lines plot
|
||||
track = track_history[track_id]
|
||||
track.append((float(x), float(y)))
|
||||
track = track_history[track_id] # Tracking Lines plot
|
||||
track.append((float(bbox_center[0]), float(bbox_center[1])))
|
||||
if len(track) > 30:
|
||||
track.pop(0)
|
||||
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(frame, [points], isClosed=False, color=bbox_color, thickness=track_thickness)
|
||||
cv2.polylines(frame, [points], isClosed=False, color=colors(cls, True), thickness=track_thickness)
|
||||
|
||||
# Check if detection inside region
|
||||
for region in counting_regions:
|
||||
if region['polygon'].contains(Point((x, y))):
|
||||
if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))):
|
||||
region['counts'] += 1
|
||||
|
||||
# Draw regions (Polygons/Rectangles)
|
||||
|
|
@ -202,6 +200,7 @@ def parse_opt():
|
|||
parser.add_argument('--view-img', action='store_true', help='show results')
|
||||
parser.add_argument('--save-img', action='store_true', help='save results')
|
||||
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
|
||||
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
|
||||
parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness')
|
||||
parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness')
|
||||
parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue