diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 2ba19ac5..1ffc27c2 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -13,9 +13,8 @@ from PIL import Image, ImageDraw, ImageFont from PIL import __version__ as pil_version from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded - -from .checks import check_font, check_version, is_ascii -from .files import increment_path +from ultralytics.utils.checks import check_font, check_version, is_ascii +from ultralytics.utils.files import increment_path class Colors: @@ -241,7 +240,7 @@ class Annotator: # Convert im back to PIL and update draw self.fromarray(self.im) - def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True): + def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25): """ Plot keypoints on the image. @@ -267,7 +266,7 @@ class Annotator: if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: if len(k) == 3: conf = k[2] - if conf < 0.5: + if conf < conf_thres: continue cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA) @@ -279,7 +278,7 @@ class Annotator: if ndim == 3: conf1 = kpts[(sk[0] - 1), 2] conf2 = kpts[(sk[1] - 1), 2] - if conf1 < 0.5 or conf2 < 0.5: + if conf1 < conf_thres or conf2 < conf_thres: continue if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0: continue @@ -491,7 +490,7 @@ class Annotator: angle = 360 - angle return angle - def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2): + def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2, conf_thres=0.25): """ Draw specific keypoints for gym steps counting. @@ -507,7 +506,7 @@ class Annotator: if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: if len(k) == 3: conf = k[2] - if conf < 0.5: + if conf < conf_thres: continue cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA) return self.im @@ -876,7 +875,7 @@ def plot_images( kpts_[..., 1] += y for j in range(len(kpts_)): if labels or conf[j] > conf_thres: - annotator.kpts(kpts_[j]) + annotator.kpts(kpts_[j], conf_thres=conf_thres) # Plot masks if len(masks):