Integration of v8 segmentation (#107)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2022-12-28 23:01:38 +08:00 committed by GitHub
parent 384f0ef1c6
commit 8406b49b49
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 422 additions and 224 deletions

View file

@ -12,17 +12,14 @@ class SegmentationPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_img):
masks = []
if len(preds) == 2: # eval
p, proto, = preds
else: # len(3) train
p, proto, _ = preds
# TODO: filter by classes
p = ops.non_max_suppression(p,
p = ops.non_max_suppression(preds[0],
self.args.conf_thres,
self.args.iou_thres,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nm=32)
proto = preds[1][-1]
for i, pred in enumerate(p):
shape = orig_img[i].shape if self.webcam else orig_img.shape
if not len(pred):