ultralytics 8.0.195 NVIDIA Triton Inference Server support (#5257)

Co-authored-by: TheConstant3 <46416203+TheConstant3@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-07 19:26:35 +02:00 committed by GitHub
parent 40e3923cfc
commit c7aa83da31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 349 additions and 98 deletions

View file

@ -15,13 +15,14 @@ class FastSAMPredictor(DetectionPredictor):
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=1, # set to 1 class since SAM has no class predictions
classes=self.args.classes)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)