-
Notifications
You must be signed in to change notification settings - Fork 131
GPU NMS kernel and refactor of NMS operator #4893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
8e3f22e
ae350e3
4ec2fe1
18ae57e
ced7e69
84c7d3b
43c10be
f2734dc
6379377
2ac67b0
5ca611f
a48c909
e1e936b
fc728f3
c2ddb73
600d9fb
d5934c0
b5c1e77
1011256
b5a9568
32c779d
c5fb107
fc7a5cc
289d5ad
49e3a2a
94c3744
22d8beb
8fc4844
229cf90
8bb7865
4c27d5f
b3765f6
0bd8d04
59b95b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,27 +36,32 @@ | |
| #include <migraphx/tensor_view.hpp> | ||
| #include <migraphx/shape_for_each.hpp> | ||
| #include <migraphx/check_shapes.hpp> | ||
| #include <migraphx/shape.hpp> | ||
| #include <migraphx/output_iterator.hpp> | ||
| #include <migraphx/argument.hpp> | ||
| #include <migraphx/par.hpp> | ||
|
|
||
| /* | ||
| https://github.com/onnx/onnx/blob/main/docs/Operators.md#NonMaxSuppression | ||
| */ | ||
| /** | ||
| * nonmaxsuppression(boxes, | ||
| * scores, | ||
| * optional(max_output_boxes_per_class), | ||
| * optional(iou_threshold), | ||
| * optional(score_threshold)); | ||
| * Outputs tuple of {tensor with dims[max_num_boxes, 3]: selected_box_indices, scalar int64_t: | ||
| * num_selected_indices} | ||
| */ | ||
| namespace migraphx { | ||
| inline namespace MIGRAPHX_INLINE_NS { | ||
| namespace op { | ||
|
|
||
| struct nonmaxsuppression | ||
| { | ||
| bool center_point_box = false; | ||
| bool use_dyn_output = false; | ||
|
|
||
| template <class Self, class F> | ||
| static auto reflect(Self& self, F f) | ||
| { | ||
| return pack(f(self.center_point_box, "center_point_box"), | ||
| f(self.use_dyn_output, "use_dyn_output")); | ||
| return pack(f(self.center_point_box, "center_point_box")); | ||
| } | ||
|
|
||
| std::string name() const { return "nonmaxsuppression"; } | ||
|
|
@@ -69,8 +74,9 @@ struct nonmaxsuppression | |
| auto max_classes = inputs.at(1).max_lens().at(1); | ||
| auto max_spatial_dimension = inputs.at(0).max_lens().at(1); | ||
| // Per ONNX spec, output is [num_selected_indices, 3] where each row is | ||
| // [batch_index, class_index, box_index]. The maximum possible | ||
| // [batch_index, class_index, box_index]. The maximum possible | ||
| // num_selected_indices = num_batches * num_classes * spatial_dimension. | ||
| // TODO: can also be limited by max_output_boxes_per_class | ||
| const auto max_num_boxes = max_batches * max_classes * max_spatial_dimension; | ||
|
|
||
| auto fixed_shape_error_check = [&]() { | ||
|
|
@@ -87,21 +93,14 @@ struct nonmaxsuppression | |
| } | ||
| }; | ||
|
|
||
| bool needs_dyn_output = use_dyn_output or inputs.at(0).dynamic() or inputs.at(1).dynamic(); | ||
|
|
||
| if(needs_dyn_output) | ||
| { | ||
| std::vector<shape::dynamic_dimension> out_lens = {}; | ||
| out_lens.push_back({0, max_num_boxes}); | ||
| out_lens.push_back({3, 3}); | ||
| return {shape::int64_type, out_lens}; | ||
| } | ||
| else | ||
| if(not(inputs.at(0).dynamic() or inputs.at(1).dynamic())) | ||
| { | ||
| fixed_shape_error_check(); | ||
| std::vector<std::size_t> out_lens = {max_num_boxes, 3}; | ||
| return {shape::int64_type, out_lens}; | ||
| } | ||
| std::vector<std::size_t> out_lens = {max_num_boxes, 3}; | ||
| shape s_ind{shape::int64_type, out_lens}; | ||
| shape s_num_selected{shape::int64_type, {1}}; | ||
| return shape({s_ind, s_num_selected}); | ||
| } | ||
|
|
||
| struct box | ||
|
|
@@ -190,7 +189,8 @@ struct nonmaxsuppression | |
| return intersection_over_union > iou_threshold; | ||
| } | ||
|
|
||
| // filter boxes below score_threshold | ||
| // Filter boxes below score_threshold. | ||
| // Don't filter for score if score_threshold == 0.f | ||
| template <class T> | ||
| std::vector<std::pair<double, int64_t>> | ||
| filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const | ||
|
|
@@ -232,10 +232,11 @@ struct nonmaxsuppression | |
| std::size_t compute_nms(Output output, | ||
| const Boxes& boxes, | ||
| const Scores& scores, | ||
| std::size_t max_output_boxes_per_class, | ||
| int64_t max_output_boxes_per_class, | ||
| double iou_threshold, | ||
| double score_threshold) const | ||
| { | ||
| // NOTE: should not need to fill with 0 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why don't we just remove this then?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's to preserve the previous behavior for now. Technically the operator after NMS should never be reading the values after |
||
| std::fill(output.begin(), output.end(), 0); | ||
| const auto& lens = scores.get_shape().lens(); | ||
| const auto num_batches = lens[0]; | ||
|
|
@@ -302,14 +303,16 @@ struct nonmaxsuppression | |
| argument compute(const shape& output_shape, std::vector<argument> args) const | ||
| { | ||
| // make buffer of maximum size | ||
| shape max_output_shape = {output_shape.type(), output_shape.max_lens()}; | ||
| auto output_shapes = flatten_shapes({output_shape}); | ||
| shape max_output_shape = {output_shapes.at(0).type(), output_shapes.at(0).max_lens()}; | ||
| argument result{max_output_shape}; | ||
| argument num_selected_result{output_shapes.at(1)}; | ||
|
|
||
| std::size_t max_output_boxes_per_class = | ||
| (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0; | ||
| int64_t max_output_boxes_per_class = (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0; | ||
| if(max_output_boxes_per_class == 0) | ||
| { | ||
| return result; | ||
| num_selected_result.visit([&](auto output) { output[0] = 0; }); | ||
| return {{result, num_selected_result}}; | ||
| } | ||
| double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f; | ||
| double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f; | ||
|
|
@@ -325,14 +328,8 @@ struct nonmaxsuppression | |
| score_threshold); | ||
| }); | ||
| }); | ||
| if(output_shape.dynamic()) | ||
| { | ||
| return result.reshape({output_shape.type(), {num_selected, 3}}); | ||
| } | ||
| else | ||
| { | ||
| return result; | ||
| } | ||
| num_selected_result.visit([&](auto output) { output[0] = num_selected; }); | ||
| return {{result, num_selected_result}}; | ||
| } | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
|
|
@@ -33,6 +33,14 @@ inline namespace MIGRAPHX_INLINE_NS { | |
| namespace gpu { | ||
| namespace device { | ||
|
|
||
| // Inclusive prefix sum within a kernel block. | ||
| // Hillis-Steele scan with double-buffered (ping-pong) shared array. | ||
| // `N`: upper bound on blockDim.x, sizes the shared buffer. | ||
| // `op`: associative binary reduce function ex. sum or max. | ||
| // `init`: initializer | ||
| // `fs`: striding function for thread work distribution. | ||
| // `input`: input with input(index_int). | ||
| // `output`: output with output(index_int, inclusive_scan_value_at_index_int). | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Appreciate the added comments here. |
||
| template <index_int N, | ||
| class Op, | ||
| class T, | ||
|
|
@@ -72,6 +80,7 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, | |
| }); | ||
| } | ||
|
|
||
| // Overload of block_scan with default local_stride up to `n`. | ||
| template <index_int N, class Op, class T, class Input, class Output> | ||
| __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output) | ||
| { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.