Integration of v8 segmentation (#107)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
384f0ef1c6
commit
8406b49b49
16 changed files with 422 additions and 224 deletions
|
|
@ -12,17 +12,14 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
masks = []
|
||||
if len(preds) == 2: # eval
|
||||
p, proto, = preds
|
||||
else: # len(3) train
|
||||
p, proto, _ = preds
|
||||
# TODO: filter by classes
|
||||
p = ops.non_max_suppression(p,
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf_thres,
|
||||
self.args.iou_thres,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nm=32)
|
||||
proto = preds[1][-1]
|
||||
for i, pred in enumerate(p):
|
||||
shape = orig_img[i].shape if self.webcam else orig_img.shape
|
||||
if not len(pred):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue