Parameterize kpts plotting conf=0.25 (#10044)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Datstat Consulting 2024-04-17 05:55:00 +08:00 committed by GitHub
parent a1bf4d07ef
commit c842825595
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -13,9 +13,8 @@ from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded
from .checks import check_font, check_version, is_ascii
from .files import increment_path
from ultralytics.utils.checks import check_font, check_version, is_ascii
from ultralytics.utils.files import increment_path
class Colors:
@ -241,7 +240,7 @@ class Annotator:
# Convert im back to PIL and update draw
self.fromarray(self.im)
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25):
"""
Plot keypoints on the image.
@ -267,7 +266,7 @@ class Annotator:
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
if len(k) == 3:
conf = k[2]
if conf < 0.5:
if conf < conf_thres:
continue
cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
@ -279,7 +278,7 @@ class Annotator:
if ndim == 3:
conf1 = kpts[(sk[0] - 1), 2]
conf2 = kpts[(sk[1] - 1), 2]
if conf1 < 0.5 or conf2 < 0.5:
if conf1 < conf_thres or conf2 < conf_thres:
continue
if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
continue
@ -491,7 +490,7 @@ class Annotator:
angle = 360 - angle
return angle
def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2):
def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2, conf_thres=0.25):
"""
Draw specific keypoints for gym steps counting.
@ -507,7 +506,7 @@ class Annotator:
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
if len(k) == 3:
conf = k[2]
if conf < 0.5:
if conf < conf_thres:
continue
cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)
return self.im
@ -876,7 +875,7 @@ def plot_images(
kpts_[..., 1] += y
for j in range(len(kpts_)):
if labels or conf[j] > conf_thres:
annotator.kpts(kpts_[j])
annotator.kpts(kpts_[j], conf_thres=conf_thres)
# Plot masks
if len(masks):