Fix typo in YOLOv8-Libtorch-CPP-Inference (#9330)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
898cbcdc08
commit
1325889305
1 changed files with 2 additions and 2 deletions
|
|
@ -139,7 +139,7 @@ torch::Tensor nms(const torch::Tensor& bboxes, const torch::Tensor& scores, floa
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
torch::Tensor non_max_supperession(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300) {
|
torch::Tensor non_max_suppression(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300) {
|
||||||
auto bs = prediction.size(0);
|
auto bs = prediction.size(0);
|
||||||
auto nc = prediction.size(1) - 4;
|
auto nc = prediction.size(1) - 4;
|
||||||
auto nm = prediction.size(1) - nc - 4;
|
auto nm = prediction.size(1) - nc - 4;
|
||||||
|
|
@ -237,7 +237,7 @@ int main() {
|
||||||
torch::Tensor output = yolo_model.forward(inputs).toTensor().cpu();
|
torch::Tensor output = yolo_model.forward(inputs).toTensor().cpu();
|
||||||
|
|
||||||
// NMS
|
// NMS
|
||||||
auto keep = non_max_supperession(output)[0];
|
auto keep = non_max_suppression(output)[0];
|
||||||
auto boxes = keep.index({Slice(), Slice(None, 4)});
|
auto boxes = keep.index({Slice(), Slice(None, 4)});
|
||||||
keep.index_put_({Slice(), Slice(None, 4)}, scale_boxes({input_image.rows, input_image.cols}, boxes, {image.rows, image.cols}));
|
keep.index_put_({Slice(), Slice(None, 4)}, scale_boxes({input_image.rows, input_image.cols}, boxes, {image.rows, image.cols}));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue