From e61d4c0aba1a7b6e159f3b35c323c5fe7f000d91 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 28 Apr 2026 13:51:15 -0700 Subject: [PATCH 1/9] make dynamic shape conversion semantics clear --- src/include/migraphx/shape.hpp | 37 ++++++- src/shape.cpp | 80 +++++++++++++- test/shape_test.cpp | 183 +++++++++++++++++++++++++++++++-- 3 files changed, 291 insertions(+), 9 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index e59f8f2c684..1b8d3cb4c08 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -447,9 +447,24 @@ struct MIGRAPHX_EXPORT shape shape with_type(type_t t) const; - // convert the shape to an equivalent dynamic shape with constant symbolic strides + // convert the shape to an equivalent range-based dynamic shape: each static len becomes + // dd{len, len} (strides are not carried); a symbolic shape is demoted by evaluating + // each dim's interval/optimals (symbolic strides are dropped). Idempotent on a shape + // that is already range-based dynamic. shape to_dynamic() const; + // Align a list of shapes to a single representation. If any input contains a + // range-based dynamic shape (at any nesting level), every shape is converted via + // to_dynamic() (symbolic inputs are demoted). Otherwise every shape is converted + // via to_symbolic() (static inputs are promoted to symbolic literals). Recurses + // into tuple sub-shapes. + static std::vector to_dynamic(const std::vector& shapes); + + // convert the shape to an equivalent symbolic dynamic shape: each static len becomes + // dd{sym::lit(len)} and each static stride becomes sym::lit(stride). Idempotent on a + // shape that is already symbolic. Throws on a range-based dynamic shape. + shape to_symbolic() const; + // convert the shape to a static one setting any non-fixed dynamic_dimensions to x shape to_static(std::size_t x) const; shape to_static(const std::unordered_map& symbol_map) const; @@ -572,6 +587,26 @@ struct MIGRAPHX_EXPORT shape void debug_print() const; + /// Whether a dim-like value has a single, known static integer value. + static bool is_fixed_dim(std::size_t) { return true; } + static bool is_fixed_dim(const dynamic_dimension& d) { return d.is_fixed(); } + + /// Extract the static integer value from a fixed dim-like value. Caller is + /// responsible for ensuring `is_fixed_dim(x)` first. + static std::size_t static_dim_value(std::size_t x) { return x; } + static std::size_t static_dim_value(const dynamic_dimension& d) + { + if(not d.is_fixed()) + MIGRAPHX_THROW("shape::static_dim_value: dimension is not fixed"); + return d.get_interval().max; + } + + /// Whether all dims of this shape have a single, known static integer value. + /// True for static shapes, range-based shapes with all-fixed dims, symbolic + /// shapes whose dims are all literals (or vars with collapsed bounds), and + /// tuple shapes whose sub-shapes are all fixed. + bool is_fixed() const; + private: shape(std::shared_ptr pimpl); std::shared_ptr impl; diff --git a/src/shape.cpp b/src/shape.cpp index 4aa67e157b3..228fadb6fa3 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -843,6 +843,11 @@ shape shape::with_type(type_t t) const return {c}; } +// Convert to an equivalent range-based dynamic shape: +// - static : each len becomes dd{len, len} (strides are not carried) +// - range-based dynamic : identity +// - symbolic : each dim is demoted via get_interval()/get_optimals(); symbolic +// strides are dropped (range-based shapes don't carry them) shape shape::to_dynamic() const { if(not sub_shapes().empty()) @@ -854,6 +859,17 @@ shape shape::to_dynamic() const [](auto s) { return s.to_dynamic(); }); return shape(subs); } + if(this->symbolic()) + { + std::vector dims; + dims.reserve(ndim()); + std::transform( + dyn_dims().begin(), dyn_dims().end(), std::back_inserter(dims), [](const auto& d) { + auto iv = d.get_interval(); + return dynamic_dimension{iv.min, iv.max, d.get_optimals()}; + }); + return {type(), std::move(dims)}; + } if(this->dynamic()) { return *this; @@ -863,6 +879,52 @@ shape shape::to_dynamic() const std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { return dynamic_dimension{len, len}; }); + return {type(), std::move(dims)}; +} + +static bool any_non_sym_dynamic(const shape& s) +{ + if(not s.sub_shapes().empty()) + return std::any_of(s.sub_shapes().begin(), s.sub_shapes().end(), &any_non_sym_dynamic); + return s.dynamic() and not s.symbolic(); +} + +std::vector shape::to_dynamic(const std::vector& shapes) +{ + const bool any_non_sym = std::any_of(shapes.begin(), shapes.end(), &any_non_sym_dynamic); + std::vector result; + result.reserve(shapes.size()); + std::transform(shapes.begin(), shapes.end(), std::back_inserter(result), [&](const auto& s) { + return any_non_sym ? s.to_dynamic() : s.to_symbolic(); + }); + return result; +} + +shape shape::to_symbolic() const +{ + if(not sub_shapes().empty()) + { + std::vector subs; + std::transform(sub_shapes().cbegin(), + sub_shapes().cend(), + std::back_inserter(subs), + [](auto s) { return s.to_symbolic(); }); + return shape(subs); + } + if(this->symbolic()) + { + return *this; + } + if(this->dynamic()) + { + // Range-based dynamic shapes have no clean symbolic representation + MIGRAPHX_THROW("SHAPE: to_symbolic() called on a range-based dynamic shape"); + } + std::vector dims; + dims.reserve(ndim()); + std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { + return dynamic_dimension{sym::lit(len)}; + }); std::vector dstrides; dstrides.reserve(ndim()); std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) { @@ -1274,9 +1336,25 @@ std::ostream& operator<<(std::ostream& os, const shape& x) return os; } +bool shape::is_fixed() const +{ + if(not sub_shapes().empty()) + return std::all_of( + sub_shapes().begin(), sub_shapes().end(), [](const auto& s) { return s.is_fixed(); }); + if(this->dynamic()) + return std::all_of( + dyn_dims().begin(), dyn_dims().end(), [](const auto& d) { return d.is_fixed(); }); + return true; +} + +// Fixed shapes compare by resolved static lens; otherwise compare dyn_dims directly. bool shape::same_lens(const shape& x, const shape& y) { - return x.to_dynamic().dyn_dims() == y.to_dynamic().dyn_dims(); + if(x.is_fixed() != y.is_fixed()) + return false; + if(x.is_fixed()) + return x.to_static({}).lens() == y.to_static({}).lens(); + return x.dyn_dims() == y.dyn_dims(); } shape::type_t shape::parse_type(const std::string& s) diff --git a/test/shape_test.cpp b/test/shape_test.cpp index fd67fe8aae4..f897588d04b 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -446,9 +446,7 @@ TEST_CASE(test_shape_static_to_dynamic) { migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s1 = s0.to_dynamic(); - migraphx::shape s2{migraphx::shape::float_type, - {{1, 1}, {2, 2}, {4, 4}, {4, 4}}, - {lit(32), lit(16), lit(4), lit(1)}}; + migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}}; EXPECT(s1 == s2); } @@ -468,12 +466,136 @@ TEST_CASE(test_shape_subshapes_to_dynamic) migraphx::shape s1 = s0.to_dynamic(); std::vector sub_shapes1 = {}; sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}}); - sub_shapes1.push_back(migraphx::shape{ - migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}, {lit(20), lit(5), lit(1)}}); + sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}}); migraphx::shape s2{sub_shapes1}; EXPECT(s1 == s2); } +TEST_CASE(test_shape_static_to_symbolic) +{ + migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; + migraphx::shape s1 = s0.to_symbolic(); + migraphx::shape s2{migraphx::shape::float_type, + {dd{lit(1)}, dd{lit(2)}, dd{lit(4)}, dd{lit(4)}}, + {lit(32), lit(16), lit(4), lit(1)}}; + EXPECT(s1 == s2); + EXPECT(s1.symbolic()); +} + +TEST_CASE(test_shape_symbolic_to_symbolic) +{ + auto n = var("n", {1, 8}); + auto c = var("c", {1, 16}); + migraphx::shape s0{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; + auto s1 = s0.to_symbolic(); + EXPECT(s0 == s1); +} + +TEST_CASE(test_shape_dyn_to_symbolic_throws) +{ + migraphx::shape s0{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + EXPECT(test::throws([&] { s0.to_symbolic(); })); +} + +TEST_CASE(test_shape_subshapes_to_symbolic) +{ + std::vector sub_shapes0 = {}; + sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {2, 3}}); + sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}); + migraphx::shape s0{sub_shapes0}; + migraphx::shape s1 = s0.to_symbolic(); + std::vector sub_shapes1 = {}; + sub_shapes1.push_back( + migraphx::shape{migraphx::shape::float_type, {dd{lit(2)}, dd{lit(3)}}, {lit(3), lit(1)}}); + sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, + {dd{lit(3)}, dd{lit(4)}, dd{lit(5)}}, + {lit(20), lit(5), lit(1)}}); + migraphx::shape s2{sub_shapes1}; + EXPECT(s1 == s2); +} + +TEST_CASE(test_shapes_to_dynamic_empty) +{ + auto out = migraphx::shape::to_dynamic({}); + EXPECT(out.empty()); +} + +TEST_CASE(test_shapes_to_dynamic_all_static) +{ + migraphx::shape a{migraphx::shape::float_type, {2, 3}}; + migraphx::shape b{migraphx::shape::float_type, {3, 4}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out.size() == 2); + EXPECT(out[0] == a.to_symbolic()); + EXPECT(out[1] == b.to_symbolic()); + EXPECT(out[0].symbolic()); + EXPECT(out[1].symbolic()); +} + +TEST_CASE(test_shapes_to_dynamic_all_symbolic) +{ + auto n = var("n", {1, 8}); + migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape b{migraphx::shape::float_type, {dd{lit(2)}, dd{n}}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b); +} + +TEST_CASE(test_shapes_to_dynamic_all_range) +{ + migraphx::shape a{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + migraphx::shape b{migraphx::shape::float_type, {{2, 2}, {3, 8}}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b); +} + +TEST_CASE(test_shapes_to_dynamic_sym_and_static) +{ + auto n = var("n", {1, 8}); + migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape b{migraphx::shape::float_type, {2, 4}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b.to_symbolic()); + EXPECT(out[1].symbolic()); +} + +TEST_CASE(test_shapes_to_dynamic_range_and_static) +{ + migraphx::shape a{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; + migraphx::shape b{migraphx::shape::float_type, {2, 4}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(out[0] == a); + EXPECT(out[1] == b.to_dynamic()); + EXPECT(not out[1].symbolic()); + EXPECT(out[1].dynamic()); +} + +TEST_CASE(test_shapes_to_dynamic_sym_and_range_demotes) +{ + auto n = var("n", {1, 8}); + migraphx::shape a{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape b{migraphx::shape::float_type, {{2, 2}, {3, 8}}}; + auto out = migraphx::shape::to_dynamic({a, b}); + EXPECT(not out[0].symbolic()); + EXPECT(out[0].dynamic()); + EXPECT(out[0] == a.to_dynamic()); + EXPECT(out[1] == b); +} + +TEST_CASE(test_shapes_to_dynamic_subshapes_recurse) +{ + migraphx::shape inner_static{migraphx::shape::float_type, {2, 3}}; + migraphx::shape inner_range{migraphx::shape::float_type, {{1, 4}, {3, 3}}}; + migraphx::shape tuple_with_range{std::vector{inner_range, inner_static}}; + migraphx::shape plain_static{migraphx::shape::float_type, {3, 4}}; + auto out = migraphx::shape::to_dynamic({tuple_with_range, plain_static}); + EXPECT(out[0] == tuple_with_range.to_dynamic()); + EXPECT(out[1] == plain_static.to_dynamic()); +} + TEST_CASE(test_shape_dyn_to_static) { migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 2}, {2, 10}, {2, 10}}}; @@ -1280,6 +1402,50 @@ TEST_CASE(shape_same_lens_static_dynamic) EXPECT(not migraphx::shape::same_lens(s1, s3)); } +TEST_CASE(shape_same_lens_symbolic_fixed) +{ + auto n = var("n", {4, 4}); + migraphx::shape s_static{migraphx::shape::float_type, {1, 4, 8}}; + migraphx::shape s_sym_lit{migraphx::shape::half_type, {dd{lit(1)}, dd{lit(4)}, dd{lit(8)}}}; + migraphx::shape s_sym_fixed_var{migraphx::shape::float_type, {dd{lit(1)}, dd{n}, dd{lit(8)}}}; + migraphx::shape s_dyn_fixed{migraphx::shape::float_type, {{1, 1}, {4, 4}, {8, 8}}}; + EXPECT(migraphx::shape::same_lens(s_static, s_sym_lit)); + EXPECT(migraphx::shape::same_lens(s_static, s_sym_fixed_var)); + EXPECT(migraphx::shape::same_lens(s_sym_lit, s_dyn_fixed)); + EXPECT(migraphx::shape::same_lens(s_sym_fixed_var, s_dyn_fixed)); +} + +TEST_CASE(shape_same_lens_symbolic_nonfixed) +{ + auto n = var("n", {1, 8}); + auto m = var("m", {1, 8}); + migraphx::shape s_n{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape s_n_again{migraphx::shape::half_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape s_m{migraphx::shape::float_type, {dd{m}, dd{lit(4)}}}; + migraphx::shape s_range{migraphx::shape::float_type, {{1, 8}, {4, 4}}}; + EXPECT(migraphx::shape::same_lens(s_n, s_n_again)); + EXPECT(not migraphx::shape::same_lens(s_n, s_m)); + EXPECT(not migraphx::shape::same_lens(s_n, s_range)); +} + +TEST_CASE(shape_is_fixed) +{ + migraphx::shape s_static{migraphx::shape::float_type, {1, 2, 8}}; + migraphx::shape s_dyn_fixed{migraphx::shape::float_type, {{1, 1}, {2, 2}, {8, 8}}}; + migraphx::shape s_dyn_range{migraphx::shape::float_type, {{1, 4}, {2, 2}, {8, 8}}}; + migraphx::shape s_sym_lit{migraphx::shape::float_type, {dd{lit(1)}, dd{lit(2)}}}; + migraphx::shape s_sym_var{migraphx::shape::float_type, {dd{var("n", {1, 8})}, dd{lit(2)}}}; + EXPECT(s_static.is_fixed()); + EXPECT(s_dyn_fixed.is_fixed()); + EXPECT(not s_dyn_range.is_fixed()); + EXPECT(s_sym_lit.is_fixed()); + EXPECT(not s_sym_var.is_fixed()); + migraphx::shape s_tuple_fixed{{s_static, s_sym_lit}}; + migraphx::shape s_tuple_mixed{{s_static, s_sym_var}}; + EXPECT(s_tuple_fixed.is_fixed()); + EXPECT(not s_tuple_mixed.is_fixed()); +} + // =================================================================== // Symbolic dynamic_dimension tests // =================================================================== @@ -1578,13 +1744,16 @@ TEST_CASE(test_symbolic_transposed) EXPECT(not s.broadcasted()); } -TEST_CASE(test_symbolic_to_dynamic_identity) +TEST_CASE(test_symbolic_to_dynamic_demotes) { auto n = var("n", {1, 8}); auto c = var("c", {1, 16}); migraphx::shape s{migraphx::shape::float_type, {dd{n}, dd{c}, dd{lit(4)}}}; auto s2 = s.to_dynamic(); - EXPECT(s == s2); + EXPECT(not s2.symbolic()); + EXPECT(s2.dynamic()); + migraphx::shape expected{migraphx::shape::float_type, {{1, 8}, {1, 16}, {4, 4}}}; + EXPECT(s2 == expected); } TEST_CASE(test_symbolic_overlap) From f1c088ab161b3873cdccbad87ab2c76d325fb8e4 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 28 Apr 2026 13:59:23 -0700 Subject: [PATCH 2/9] revert to_dynamic logic to before sym_dim_integration change --- src/shape.cpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index 228fadb6fa3..e0fb44f2aeb 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -874,12 +874,7 @@ shape shape::to_dynamic() const { return *this; } - std::vector dims; - dims.reserve(ndim()); - std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { - return dynamic_dimension{len, len}; - }); - return {type(), std::move(dims)}; + return {type(), lens(), lens(), {}}; } static bool any_non_sym_dynamic(const shape& s) From 44c9b153b0d460df24f3a79c4e9c6ad961f21765 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 28 Apr 2026 14:07:36 -0700 Subject: [PATCH 3/9] style --- src/shape.cpp | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/shape.cpp b/src/shape.cpp index e0fb44f2aeb..9f7647d80f6 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -861,13 +861,11 @@ shape shape::to_dynamic() const } if(this->symbolic()) { - std::vector dims; - dims.reserve(ndim()); - std::transform( - dyn_dims().begin(), dyn_dims().end(), std::back_inserter(dims), [](const auto& d) { - auto iv = d.get_interval(); - return dynamic_dimension{iv.min, iv.max, d.get_optimals()}; - }); + std::vector dims(ndim()); + std::transform(dyn_dims().begin(), dyn_dims().end(), dims.begin(), [](const auto& d) { + auto iv = d.get_interval(); + return dynamic_dimension{iv.min, iv.max, d.get_optimals()}; + }); return {type(), std::move(dims)}; } if(this->dynamic()) @@ -887,9 +885,8 @@ static bool any_non_sym_dynamic(const shape& s) std::vector shape::to_dynamic(const std::vector& shapes) { const bool any_non_sym = std::any_of(shapes.begin(), shapes.end(), &any_non_sym_dynamic); - std::vector result; - result.reserve(shapes.size()); - std::transform(shapes.begin(), shapes.end(), std::back_inserter(result), [&](const auto& s) { + std::vector result(shapes.size()); + std::transform(shapes.begin(), shapes.end(), result.begin(), [&](const auto& s) { return any_non_sym ? s.to_dynamic() : s.to_symbolic(); }); return result; @@ -915,16 +912,13 @@ shape shape::to_symbolic() const // Range-based dynamic shapes have no clean symbolic representation MIGRAPHX_THROW("SHAPE: to_symbolic() called on a range-based dynamic shape"); } - std::vector dims; - dims.reserve(ndim()); - std::transform(lens().begin(), lens().end(), std::back_inserter(dims), [](auto len) { + std::vector dims(ndim()); + std::transform(lens().begin(), lens().end(), dims.begin(), [](auto len) { return dynamic_dimension{sym::lit(len)}; }); - std::vector dstrides; - dstrides.reserve(ndim()); - std::transform(strides().begin(), strides().end(), std::back_inserter(dstrides), [](auto s) { - return sym::lit(s); - }); + std::vector dstrides(ndim()); + std::transform( + strides().begin(), strides().end(), dstrides.begin(), [](auto s) { return sym::lit(s); }); return {type(), std::move(dims), std::move(dstrides)}; } From 1db4eccf4d74069f843ac4e9e5dfc8ba96a95202 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 29 Apr 2026 09:42:14 -0700 Subject: [PATCH 4/9] fix bug with to_static --- src/include/migraphx/shape.hpp | 2 ++ src/shape.cpp | 9 ++++++++- src/sym.cpp | 3 +++ test/sym_test.cpp | 14 ++++++++++++-- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 1b8d3cb4c08..1a8c1f9d53e 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -468,6 +468,8 @@ struct MIGRAPHX_EXPORT shape // convert the shape to a static one setting any non-fixed dynamic_dimensions to x shape to_static(std::size_t x) const; shape to_static(const std::unordered_map& symbol_map) const; + // Collapse a fully-fixed shape to a static one; throws on non-fixed dimensions. + shape to_static() const; MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); diff --git a/src/shape.cpp b/src/shape.cpp index 9f7647d80f6..0cf5470e5a6 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -980,6 +980,13 @@ shape shape::to_static(const std::unordered_map& symbol_ return {type(), static_lens, static_strides}; } +shape shape::to_static() const +{ + if(not this->is_fixed()) + MIGRAPHX_THROW("SHAPE: to_static() requires fully-fixed dimensions"); + return this->to_static(std::unordered_map{}); +} + std::size_t shape::element_space() const { return impl->element_space(); } std::string shape::type_string() const { return name(this->type()); } @@ -1342,7 +1349,7 @@ bool shape::same_lens(const shape& x, const shape& y) if(x.is_fixed() != y.is_fixed()) return false; if(x.is_fixed()) - return x.to_static({}).lens() == y.to_static({}).lens(); + return x.to_static().lens() == y.to_static().lens(); return x.dyn_dims() == y.dyn_dims(); } diff --git a/src/sym.cpp b/src/sym.cpp index 4aa68370812..c54b4ff532f 100644 --- a/src/sym.cpp +++ b/src/sym.cpp @@ -615,6 +615,9 @@ static int64_t eval_direct(const expr_ptr& e, const binding_map& bindings) auto it = bindings.find(node); if(it != bindings.end()) return it->second; + // Fall back to the symbol's own bounds when fixed (min == max). + if(d.min == d.max) + return d.min; MIGRAPHX_THROW("sym::expr::eval_uint: unbound symbol '" + d.name + "'"); }); } diff --git a/test/sym_test.cpp b/test/sym_test.cpp index ff5a5e77043..3134ff82eba 100644 --- a/test/sym_test.cpp +++ b/test/sym_test.cpp @@ -499,12 +499,22 @@ TEST_CASE(eval_trunc_division) TEST_CASE(eval_unbound_throws) { - auto h = var("h"); - auto w = var("w"); + auto h = var("h", {1, 8}); + auto w = var("w", {1, 8}); EXPECT(test::throws([&] { h.eval_uint({}); })); EXPECT(test::throws([&] { (h + w).eval_uint({{h, 1}}); })); } +TEST_CASE(eval_uint_falls_back_to_fixed_bounds) +{ + // Fixed-bound vars (min == max) are resolved from their own bounds. + auto n = var("n", {4, 4}); + EXPECT(n.eval_uint({}) == 4); + EXPECT((n * 8).eval_uint({}) == 32); + auto h = var("h", {1, 8}); + EXPECT((h + n).eval_uint({{h, 2}}) == 6); +} + TEST_CASE(eval_division_by_zero_throws) { auto h = var("h"); From 9a5f2819ed240b5892aa52f7966fed88b9405117 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 29 Apr 2026 16:14:14 -0700 Subject: [PATCH 5/9] refactor dim attribute of reshape op to support --- src/CMakeLists.txt | 1 + src/dim_like.cpp | 62 +++++ src/include/migraphx/dim_like.hpp | 91 +++++++ src/include/migraphx/op/reshape.hpp | 10 +- src/include/migraphx/op/reshape_lazy.hpp | 13 +- test/dim_like_test.cpp | 291 +++++++++++++++++++++++ test/onnx/parse/reshape_test.cpp | 2 +- 7 files changed, 462 insertions(+), 8 deletions(-) create mode 100644 src/dim_like.cpp create mode 100644 src/include/migraphx/dim_like.hpp create mode 100644 test/dim_like_test.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7a1d00a4faa..2ba0f39e012 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -61,6 +61,7 @@ add_library(migraphx convert_to_json.cpp cpp_generator.cpp dead_code_elimination.cpp + dim_like.cpp dom_info.cpp dynamic_loader.cpp eliminate_allocation.cpp diff --git a/src/dim_like.cpp b/src/dim_like.cpp new file mode 100644 index 00000000000..07aa76aa08b --- /dev/null +++ b/src/dim_like.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::ostream& operator<<(std::ostream& os, const dim_like& d) +{ + std::visit([&](const auto& x) { os << x; }, d.value); + return os; +} + +migraphx::value dim_like::to_value() const +{ + return std::visit([](const auto& x) { return migraphx::to_value(x); }, this->value); +} + +void dim_like::from_value(const migraphx::value& v) +{ + // Backward-compatible path: integer-valued entries (signed or unsigned) + // route through the int alternative so old .mxr files and call sites that + // pass plain integer arrays both decode without going through the + // dynamic_dimension reflect path. + if(v.is_int64() or v.is_uint64()) + { + this->value = v.to(); + return; + } + shape::dynamic_dimension d; + migraphx::from_value(v, d); + this->value = std::move(d); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/include/migraphx/dim_like.hpp b/src/include/migraphx/dim_like.hpp new file mode 100644 index 00000000000..b34c03826fc --- /dev/null +++ b/src/include/migraphx/dim_like.hpp @@ -0,0 +1,91 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DIM_LIKE_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_DIM_LIKE_HPP + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct value; + +// A dim attribute entry that may be either a plain int64_t or a +// dynamic_dimension. Used by ops whose dim-valued attributes need to carry +// either static integers or dynamic/symbolic dimensions. +struct MIGRAPHX_EXPORT dim_like +{ + std::variant value = int64_t{0}; + + constexpr dim_like() = default; + + template {})> + constexpr dim_like(T v) : value{static_cast(v)} // NOLINT(google-explicit-constructor) + { + } + + dim_like(shape::dynamic_dimension d) // NOLINT(google-explicit-constructor) + : value{std::move(d)} + { + } + + friend bool operator==(const dim_like& a, const dim_like& b) { return a.value == b.value; } + friend bool operator!=(const dim_like& a, const dim_like& b) { return not(a == b); } + + MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const dim_like& d); + + migraphx::value to_value() const; + void from_value(const migraphx::value& v); +}; + +template +bool holds_alternative(const dim_like& d) +{ + return std::holds_alternative(d.value); +} + +template +const T& get(const dim_like& d) +{ + return std::get(d.value); +} + +template +T& get(dim_like& d) +{ + return std::get(d.value); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp index 51c68d4c924..e95a39a30ea 100644 --- a/src/include/migraphx/op/reshape.hpp +++ b/src/include/migraphx/op/reshape.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -56,7 +57,7 @@ namespace op { */ struct reshape { - std::vector dims; + std::vector dims; template static auto reflect(Self& self, F f) @@ -90,7 +91,7 @@ struct reshape } else { - std::size_t u_dim = d; + std::size_t u_dim = get(d); output_dyn_dims.at(i) = {u_dim, u_dim}; } } @@ -138,7 +139,10 @@ struct reshape { check_shapes{inputs, *this}.has(1); auto&& idims = inputs.front().lens(); - std::vector rdims(dims.begin(), dims.end()); + std::vector rdims(dims.size()); + std::transform(dims.begin(), dims.end(), rdims.begin(), [](const dim_like& d) { + return get(d); + }); for(std::size_t i = 0; i < dims.size(); i++) { diff --git a/src/include/migraphx/op/reshape_lazy.hpp b/src/include/migraphx/op/reshape_lazy.hpp index 845fb5cd43e..a3c5da832cc 100644 --- a/src/include/migraphx/op/reshape_lazy.hpp +++ b/src/include/migraphx/op/reshape_lazy.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -37,7 +38,7 @@ namespace op { struct reshape_lazy { - std::vector dims; + std::vector dims; template static auto reflect(Self& self, F f) @@ -65,7 +66,7 @@ struct reshape_lazy { if(dyn_dims[i].is_fixed()) { - num_dims_ele *= dims[i]; + num_dims_ele *= get(dims[i]); num_dd_ele *= dyn_dims[i].get_interval().min; } else @@ -89,9 +90,10 @@ struct reshape_lazy dims.cend(), dyn_dims.cbegin(), output_dyn_dims.begin(), - [](std::size_t dim, auto dyn_dim) { + [](const dim_like& d, auto dyn_dim) { if(not dyn_dim.is_fixed()) return dyn_dim; + std::size_t dim = get(d); return shape::dynamic_dimension{dim, dim}; }); return {s0.type(), output_dyn_dims}; @@ -255,7 +257,10 @@ struct reshape_lazy { check_shapes{inputs, *this}.has(1); auto&& idims = inputs.front().lens(); - std::vector rdims(dims.begin(), dims.end()); + std::vector rdims(dims.size()); + std::transform(dims.begin(), dims.end(), rdims.begin(), [](const dim_like& d) { + return get(d); + }); for(std::size_t i = 0; i < dims.size(); i++) { diff --git a/test/dim_like_test.cpp b/test/dim_like_test.cpp new file mode 100644 index 00000000000..1952be5a8e7 --- /dev/null +++ b/test/dim_like_test.cpp @@ -0,0 +1,291 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "test.hpp" + +using migraphx::dim_like; +using dd = migraphx::shape::dynamic_dimension; + +static dim_like round_trip(const dim_like& d) +{ + auto v = migraphx::to_value(d); + return migraphx::from_value(v); +} + +// =================================================================== +// Construction and alternative inspection +// =================================================================== + +TEST_CASE(construct_default) +{ + dim_like d; + EXPECT(std::holds_alternative(d.value)); + EXPECT(std::get(d.value) == 0); +} + +TEST_CASE(construct_int_marker_zero) +{ + dim_like d = 0; + EXPECT(std::holds_alternative(d.value)); + EXPECT(std::get(d.value) == 0); +} + +TEST_CASE(construct_int_marker_neg_one) +{ + dim_like d = -1; + EXPECT(std::holds_alternative(d.value)); + EXPECT(std::get(d.value) == -1); +} + +TEST_CASE(construct_int_value) +{ + dim_like d = 42; + EXPECT(std::holds_alternative(d.value)); + EXPECT(std::get(d.value) == 42); +} + +TEST_CASE(construct_from_size_t) +{ + std::size_t n = 7; + dim_like d = n; + EXPECT(std::holds_alternative(d.value)); + EXPECT(std::get(d.value) == 7); +} + +TEST_CASE(construct_from_dynamic_dimension_range) +{ + dim_like d = dd{1, 4}; + EXPECT(std::holds_alternative
(d.value)); + EXPECT(std::get
(d.value) == dd{1, 4}); +} + +TEST_CASE(construct_from_dynamic_dimension_symbolic) +{ + dim_like d = dd{migraphx::sym::var("n", {1, 8})}; + EXPECT(std::holds_alternative
(d.value)); + EXPECT(std::get
(d.value).is_symbolic()); +} + +// =================================================================== +// Equality / count semantics for legacy 0/-1 marker patterns +// =================================================================== + +TEST_CASE(equality_int_marker_zero) +{ + dim_like d = 0; + EXPECT(d == 0); + EXPECT(0 == d); + EXPECT(not(d == -1)); +} + +TEST_CASE(equality_int_marker_neg_one) +{ + dim_like d = -1; + EXPECT(d == -1); + EXPECT(-1 == d); + EXPECT(not(d == 0)); +} + +TEST_CASE(equality_int_value) +{ + dim_like d = 5; + EXPECT(d == 5); + EXPECT(5 == d); + EXPECT(d != 4); +} + +TEST_CASE(equality_dd_alternative_never_matches_marker) +{ + dim_like d = dd{0, 4}; + EXPECT(d != 0); + EXPECT(d != -1); +} + +TEST_CASE(equality_between_alternatives) +{ + dim_like a = 3; + dim_like b = dd{3, 3}; + EXPECT(a != b); +} + +TEST_CASE(adl_get_and_holds_alternative) +{ + using migraphx::get; + using migraphx::holds_alternative; + + dim_like d_int = 42; + EXPECT(holds_alternative(d_int)); + EXPECT(get(d_int) == 42); + + dim_like d_dd = dd{1, 4}; + EXPECT(holds_alternative
(d_dd)); + EXPECT(get
(d_dd) == dd{1, 4}); + + EXPECT(test::throws([&] { (void)get
(d_int); })); +} + +TEST_CASE(std_count_marker) +{ + std::vector dims = {0, 0, 6, -1}; + EXPECT(std::count(dims.begin(), dims.end(), -1) == 1); + EXPECT(std::count(dims.begin(), dims.end(), 0) == 2); +} + +// =================================================================== +// Streaming +// =================================================================== + +TEST_CASE(stream_int) +{ + std::ostringstream ss; + ss << dim_like{42}; + EXPECT(ss.str() == "42"); +} + +TEST_CASE(stream_neg_one) +{ + std::ostringstream ss; + ss << dim_like{-1}; + EXPECT(ss.str() == "-1"); +} + +TEST_CASE(stream_dd) +{ + std::ostringstream ss; + ss << dim_like{dd{1, 4}}; + std::ostringstream expected; + expected << dd{1, 4}; + EXPECT(ss.str() == expected.str()); +} + +TEST_CASE(stream_dd_symbolic) +{ + auto sd = dd{migraphx::sym::var("n", {1, 8})}; + std::ostringstream ss; + ss << dim_like{sd}; + std::ostringstream expected; + expected << sd; + EXPECT(ss.str() == expected.str()); +} + +// =================================================================== +// Serialization round-trip +// =================================================================== + +TEST_CASE(serialize_int_zero) +{ + dim_like d = 0; + auto rt = round_trip(d); + EXPECT(rt == d); + EXPECT(std::holds_alternative(rt.value)); +} + +TEST_CASE(serialize_int_neg_one) +{ + dim_like d = -1; + auto rt = round_trip(d); + EXPECT(rt == d); + EXPECT(std::holds_alternative(rt.value)); +} + +TEST_CASE(serialize_int_value) +{ + dim_like d = 42; + auto rt = round_trip(d); + EXPECT(rt == d); + EXPECT(std::holds_alternative(rt.value)); +} + +TEST_CASE(serialize_dd_range) +{ + dim_like d = dd{1, 4}; + auto rt = round_trip(d); + EXPECT(rt == d); + EXPECT(std::holds_alternative
(rt.value)); +} + +TEST_CASE(serialize_dd_symbolic) +{ + dim_like d = dd{migraphx::sym::var("n", {1, 8})}; + auto rt = round_trip(d); + EXPECT(rt == d); + EXPECT(std::holds_alternative
(rt.value)); +} + +// =================================================================== +// Backward-compat: legacy serialized models stored dims as a plain int64 +// array. Decoding such a value into vector must succeed and +// produce the int alternative for every entry. +// =================================================================== + +TEST_CASE(from_value_legacy_int_array) +{ + std::vector legacy = {0, 0, 6, -1}; + auto loaded = migraphx::from_value>(migraphx::to_value(legacy)); + EXPECT(loaded == std::vector{0, 0, 6, -1}); +} + +TEST_CASE(from_value_size_t_array) +{ + // Common path at op-construction sites: make_op("...", {{"dims", lens()}}) + // where lens() is vector. The value layer routes that through uint64. + std::vector lens = {4, 24, 1}; + auto loaded = migraphx::from_value>(migraphx::to_value(lens)); + EXPECT(loaded == std::vector{4, 24, 1}); +} + +TEST_CASE(to_value_int_array_byte_compat) +{ + // A vector holding only int alternatives must serialize to the + // same value as the equivalent vector, so models with no symbolic + // dims save byte-identical to today. + std::vector dims = {0, 0, 6, -1}; + std::vector legacy{0, 0, 6, -1}; + EXPECT(migraphx::to_value(dims) == migraphx::to_value(legacy)); +} + +TEST_CASE(round_trip_mixed_vector) +{ + std::vector dims = { + dim_like{0}, + dim_like{42}, + dim_like{dd{1, 4}}, + dim_like{-1}, + }; + auto v = migraphx::to_value(dims); + auto loaded = migraphx::from_value>(v); + EXPECT(loaded == dims); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/onnx/parse/reshape_test.cpp b/test/onnx/parse/reshape_test.cpp index 7ee9e492198..ff1ed432da4 100644 --- a/test/onnx/parse/reshape_test.cpp +++ b/test/onnx/parse/reshape_test.cpp @@ -34,7 +34,7 @@ TEST_CASE(reshape_test) mm->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); - op.dims = reshape_dims; + op.dims.assign(reshape_dims.begin(), reshape_dims.end()); mm->add_instruction(op, l0); mm->add_instruction(op, l0); auto prog = optimize_onnx("reshape_test.onnx"); From 36439ce952a96729dec39d7c54974009e8d63625 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 29 Apr 2026 19:50:17 -0700 Subject: [PATCH 6/9] fix license --- test/onnx/parse/reshape_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/parse/reshape_test.cpp b/test/onnx/parse/reshape_test.cpp index ff1ed432da4..be51e4b74dd 100644 --- a/test/onnx/parse/reshape_test.cpp +++ b/test/onnx/parse/reshape_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 4498317b640d62b5e6f8199dd5731c63d899bde8 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 20 May 2026 15:25:06 -0700 Subject: [PATCH 7/9] refactor to use picked_variant --- .gitignore | 3 + src/dim_like.cpp | 28 +- src/include/migraphx/dim_like.hpp | 53 +- src/include/migraphx/op/reshape.hpp | 22 +- src/include/migraphx/op/reshape_lazy.hpp | 16 +- src/include/migraphx/picked_variant.hpp | 138 ++++++ test/dim_like_test.cpp | 129 ++--- test/picked_variant.cpp | 585 +++++++++++++++++++++++ 8 files changed, 842 insertions(+), 132 deletions(-) create mode 100644 src/include/migraphx/picked_variant.hpp create mode 100644 test/picked_variant.cpp diff --git a/.gitignore b/.gitignore index ba8698ec024..b0294825de2 100644 --- a/.gitignore +++ b/.gitignore @@ -97,3 +97,6 @@ docs/_toc.yml # Ignore CMake user presets CMakeUserPresets.json +# local Python virtual environment +.venv/ +.cache/ diff --git a/src/dim_like.cpp b/src/dim_like.cpp index 07aa76aa08b..7c375a51085 100644 --- a/src/dim_like.cpp +++ b/src/dim_like.cpp @@ -26,36 +26,24 @@ #include #include -#include - namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -std::ostream& operator<<(std::ostream& os, const dim_like& d) -{ - std::visit([&](const auto& x) { os << x; }, d.value); - return os; -} - -migraphx::value dim_like::to_value() const +void migraphx_to_value(value& v, const dim_like& d) { - return std::visit([](const auto& x) { return migraphx::to_value(x); }, this->value); + v = visit([](const auto& x) { return migraphx::to_value(x); }, d); } -void dim_like::from_value(const migraphx::value& v) +void migraphx_from_value(const value& v, dim_like& d) { - // Backward-compatible path: integer-valued entries (signed or unsigned) - // route through the int alternative so old .mxr files and call sites that - // pass plain integer arrays both decode without going through the - // dynamic_dimension reflect path. - if(v.is_int64() or v.is_uint64()) + if(v.is_object()) { - this->value = v.to(); + shape::dynamic_dimension dd; + migraphx::from_value(v, dd); + d = std::move(dd); return; } - shape::dynamic_dimension d; - migraphx::from_value(v, d); - this->value = std::move(d); + d = v.to(); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/dim_like.hpp b/src/include/migraphx/dim_like.hpp index b34c03826fc..ca24b4648b7 100644 --- a/src/include/migraphx/dim_like.hpp +++ b/src/include/migraphx/dim_like.hpp @@ -27,10 +27,9 @@ #include #include #include -#include -#include #include +#include #include #include @@ -39,51 +38,33 @@ inline namespace MIGRAPHX_INLINE_NS { struct value; -// A dim attribute entry that may be either a plain int64_t or a -// dynamic_dimension. Used by ops whose dim-valued attributes need to carry -// either static integers or dynamic/symbolic dimensions. -struct MIGRAPHX_EXPORT dim_like +// Routes any integral type through int64_t so call sites don't need casts. +struct dim_like_picker { - std::variant value = int64_t{0}; - - constexpr dim_like() = default; - template {})> - constexpr dim_like(T v) : value{static_cast(v)} // NOLINT(google-explicit-constructor) - { - } - - dim_like(shape::dynamic_dimension d) // NOLINT(google-explicit-constructor) - : value{std::move(d)} + static int64_t apply(T v) { + return static_cast(v); } - friend bool operator==(const dim_like& a, const dim_like& b) { return a.value == b.value; } - friend bool operator!=(const dim_like& a, const dim_like& b) { return not(a == b); } - - MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const dim_like& d); - - migraphx::value to_value() const; - void from_value(const migraphx::value& v); + static shape::dynamic_dimension apply(shape::dynamic_dimension d) { return d; } }; -template -bool holds_alternative(const dim_like& d) -{ - return std::holds_alternative(d.value); -} +// A dim attribute entry that may be either a plain int64_t or a dynamic_dimension. +using dim_like = picked_variant; -template -const T& get(const dim_like& d) +// Templated to hide from ADL on unrelated types: a non-template overload would +// be probed during overload resolution for things like vector, which +// would instantiate Picker::apply(vector<...>) and hard-fail. +template {})> +inline std::ostream& operator<<(std::ostream& os, const T& d) { - return std::get(d.value); + visit([&](const auto& x) { os << x; }, d); + return os; } -template -T& get(dim_like& d) -{ - return std::get(d.value); -} +MIGRAPHX_EXPORT void migraphx_to_value(value& v, const dim_like& d); +MIGRAPHX_EXPORT void migraphx_from_value(const value& v, dim_like& d); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp index 4c47c844293..7b10ae62c88 100644 --- a/src/include/migraphx/op/reshape.hpp +++ b/src/include/migraphx/op/reshape.hpp @@ -72,9 +72,9 @@ struct reshape // Makes no checks for the validity of the `dims` attribute for the given input shape. shape dyn_1arg_compute_shape(shape s0) const { - auto input_dyn_dims = s0.dyn_dims(); - const auto neg_dim_num = - std::distance(this->dims.begin(), std::find(this->dims.begin(), this->dims.end(), -1)); + auto input_dyn_dims = s0.dyn_dims(); + const auto neg_dim_num = std::distance( + this->dims.begin(), std::find(this->dims.begin(), this->dims.end(), dim_like{-1})); const bool has_negative_dim_attr = neg_dim_num < dims.size(); // construct output dynamic shape from dims attribute std::vector output_dyn_dims(dims.size()); @@ -82,17 +82,17 @@ struct reshape for(std::size_t i = 0; i < dims.size(); ++i) { auto d = dims.at(i); - if(d == 0) + if(d == dim_like{0}) { output_dyn_dims.at(i) = input_dyn_dims.at(i); } - else if(d == -1) + else if(d == dim_like{-1}) { output_dyn_dims.at(i) = {1, 1}; } else { - std::size_t u_dim = get(d); + std::size_t u_dim = std::get(d); output_dyn_dims.at(i) = {u_dim, u_dim}; } } @@ -142,16 +142,16 @@ struct reshape auto&& idims = inputs.front().lens(); std::vector rdims(dims.size()); std::transform(dims.begin(), dims.end(), rdims.begin(), [](const dim_like& d) { - return get(d); + return std::get(d); }); for(std::size_t i = 0; i < dims.size(); i++) { - if(dims[i] == 0) + if(dims[i] == dim_like{0}) rdims[i] = idims[i]; // convert -1 to 1 for rdims since rdims uses size_t (-1 is max_int for size_t) - if(dims[i] == -1) + if(dims[i] == dim_like{-1}) rdims[i] = 1; } @@ -162,7 +162,7 @@ struct reshape std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies()); for(std::size_t i = 0; i < rdims.size(); i++) { - if(dims[i] == -1) + if(dims[i] == dim_like{-1}) rdims[i] = missing_dim; } } @@ -186,7 +186,7 @@ struct reshape { check_shapes{inputs, *this, true}.has(1, 2); - auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); + auto n_neg_dims = std::count(dims.begin(), dims.end(), dim_like{-1}); if(n_neg_dims > 1) MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); diff --git a/src/include/migraphx/op/reshape_lazy.hpp b/src/include/migraphx/op/reshape_lazy.hpp index 9c10cbf949b..40fc8612ad1 100644 --- a/src/include/migraphx/op/reshape_lazy.hpp +++ b/src/include/migraphx/op/reshape_lazy.hpp @@ -66,12 +66,12 @@ struct reshape_lazy { if(dyn_dims[i].is_fixed()) { - num_dims_ele *= get(dims[i]); + num_dims_ele *= std::get(dims[i]); num_dd_ele *= dyn_dims[i].get_interval().min; } else { - if(dims[i] != 0 and dims[i] != -1) + if(dims[i] != dim_like{0} and dims[i] != dim_like{-1}) { MIGRAPHX_THROW( "reshape_lazy: Non-fixed dynamic_dimension doesn't match with 0 or -1 " @@ -93,7 +93,7 @@ struct reshape_lazy [](const dim_like& d, auto dyn_dim) { if(not dyn_dim.is_fixed()) return dyn_dim; - std::size_t dim = get(d); + std::size_t dim = std::get(d); return shape::dynamic_dimension{dim, dim}; }); return {s0.type(), output_dyn_dims}; @@ -105,17 +105,17 @@ struct reshape_lazy auto&& idims = inputs.front().lens(); std::vector rdims(dims.size()); std::transform(dims.begin(), dims.end(), rdims.begin(), [](const dim_like& d) { - return get(d); + return std::get(d); }); for(std::size_t i = 0; i < dims.size(); i++) { - if(dims[i] == 0) + if(dims[i] == dim_like{0}) rdims[i] = idims[i]; // since rdims using size_t type, -1 is the max value // is size_t that cause later compuation incorrect - if(dims[i] == -1) + if(dims[i] == dim_like{-1}) rdims[i] = 1; } @@ -126,7 +126,7 @@ struct reshape_lazy std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies()); for(std::size_t i = 0; i < rdims.size(); i++) { - if(dims[i] == -1) + if(dims[i] == dim_like{-1}) rdims[i] = missing_dim; } } @@ -148,7 +148,7 @@ struct reshape_lazy shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this, true}.has(1); - auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); + auto n_neg_dims = std::count(dims.begin(), dims.end(), dim_like{-1}); if(n_neg_dims > 1) MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim"); const auto& s0 = inputs[0]; diff --git a/src/include/migraphx/picked_variant.hpp b/src/include/migraphx/picked_variant.hpp new file mode 100644 index 00000000000..38736f1af17 --- /dev/null +++ b/src/include/migraphx/picked_variant.hpp @@ -0,0 +1,138 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_PICKED_VARIANT_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_PICKED_VARIANT_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +struct picked_variant; + +template +struct is_picked_variant : std::false_type +{ +}; +template +struct is_picked_variant> : std::true_type +{ +}; + +// Namespace-scope as_variant overloads for plain std::variant. Declared +// before picked_variant so its hidden-friend visit body can find them via +// ordinary lookup. For picked_variant arguments the class's hidden-friend +// as_variant wins (exact-type match beats the derived-to-base conversion +// these overloads would otherwise need). +template +constexpr std::variant& as_variant(std::variant& v) +{ + return v; +} +template +constexpr const std::variant& as_variant(const std::variant& v) +{ + return v; +} +template +constexpr std::variant&& as_variant(std::variant&& v) +{ + return std::move(v); +} + +template +struct picked_variant : std::variant +{ + using base_t = std::variant; + using base_t::base_t; // inherit default, in_place_type, in_place_index ctors + + template >{})> + constexpr picked_variant(T&& x) : base_t(Picker::apply(std::forward(x))) + { + } + + friend constexpr base_t& as_variant(picked_variant& x) { return x; } + friend constexpr const base_t& as_variant(const picked_variant& x) { return x; } + friend constexpr base_t&& as_variant(picked_variant&& x) { return std::move(x); } + + // Hidden friends. + // + // One overload per "position" of the picked_variant argument. Each takes a + // forwarding reference constrained via SFINAE on the decayed type so a + // single overload covers `&`, `const&`, and `&&`. + // + // Return type is `decltype(auto)` rather than a trailing `decltype(...)` + // so the body is only substituted when the overload is actually selected. + // std::visit is not required by the standard to be SFINAE-friendly and on + // older libstdc++ (e.g. GCC 7 on SLES) it isn't, so a trailing decltype + // around a std::visit call would produce hard errors during overload + // resolution rather than cleanly removing the overload. + // + // Every variant argument is routed through `as_variant`. The namespace- + // scope overloads above cover plain std::variant; together they ensure + // std::visit only ever sees std::variant arguments and never the raw + // picked_variant (which doesn't have std::variant_size specialized for + // it). + // + // The position-2 overload handles calls where picked_variant is the + // second variant (e.g. visit(f, std_v, pv)). Its SFINAE excludes the case + // where the first variant is also a picked_variant -- those are handled + // by the position-1 overload. + + template , picked_variant>{})> + friend constexpr decltype(auto) visit(Visitor&& vis, V&& pv, Variants&&... vars) + { + return std::visit(std::forward(vis), + as_variant(std::forward(pv)), + as_variant(std::forward(vars))...); + } + + template >{} and + std::is_same, picked_variant>{})> + friend constexpr decltype(auto) visit(Visitor&& vis, V0&& v0, V1&& v1, Variants&&... vars) + { + return std::visit(std::forward(vis), + as_variant(std::forward(v0)), + as_variant(std::forward(v1)), + as_variant(std::forward(vars))...); + } +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_MIGRAPHX_PICKED_VARIANT_HPP diff --git a/test/dim_like_test.cpp b/test/dim_like_test.cpp index 1952be5a8e7..424e7d23959 100644 --- a/test/dim_like_test.cpp +++ b/test/dim_like_test.cpp @@ -50,86 +50,92 @@ static dim_like round_trip(const dim_like& d) TEST_CASE(construct_default) { dim_like d; - EXPECT(std::holds_alternative(d.value)); - EXPECT(std::get(d.value) == 0); + EXPECT(std::holds_alternative(d)); + EXPECT(std::get(d) == 0); } TEST_CASE(construct_int_marker_zero) { dim_like d = 0; - EXPECT(std::holds_alternative(d.value)); - EXPECT(std::get(d.value) == 0); + EXPECT(std::holds_alternative(d)); + EXPECT(std::get(d) == 0); } TEST_CASE(construct_int_marker_neg_one) { dim_like d = -1; - EXPECT(std::holds_alternative(d.value)); - EXPECT(std::get(d.value) == -1); + EXPECT(std::holds_alternative(d)); + EXPECT(std::get(d) == -1); } TEST_CASE(construct_int_value) { dim_like d = 42; - EXPECT(std::holds_alternative(d.value)); - EXPECT(std::get(d.value) == 42); + EXPECT(std::holds_alternative(d)); + EXPECT(std::get(d) == 42); } TEST_CASE(construct_from_size_t) { std::size_t n = 7; dim_like d = n; - EXPECT(std::holds_alternative(d.value)); - EXPECT(std::get(d.value) == 7); + EXPECT(std::holds_alternative(d)); + EXPECT(std::get(d) == 7); } TEST_CASE(construct_from_dynamic_dimension_range) { dim_like d = dd{1, 4}; - EXPECT(std::holds_alternative
(d.value)); - EXPECT(std::get
(d.value) == dd{1, 4}); + EXPECT(std::holds_alternative
(d)); + EXPECT(std::get
(d) == dd{1, 4}); } TEST_CASE(construct_from_dynamic_dimension_symbolic) { dim_like d = dd{migraphx::sym::var("n", {1, 8})}; - EXPECT(std::holds_alternative
(d.value)); - EXPECT(std::get
(d.value).is_symbolic()); + EXPECT(std::holds_alternative
(d)); + EXPECT(std::get
(d).is_symbolic()); +} + +TEST_CASE(get_throws_on_wrong_alternative) +{ + dim_like d = 42; + EXPECT(test::throws([&] { (void)std::get
(d); })); } // =================================================================== -// Equality / count semantics for legacy 0/-1 marker patterns +// Equality and std::count on 0 / -1 markers // =================================================================== TEST_CASE(equality_int_marker_zero) { dim_like d = 0; - EXPECT(d == 0); - EXPECT(0 == d); - EXPECT(not(d == -1)); + EXPECT(d == dim_like{0}); + EXPECT(dim_like{0} == d); + EXPECT(not(d == dim_like{-1})); } TEST_CASE(equality_int_marker_neg_one) { dim_like d = -1; - EXPECT(d == -1); - EXPECT(-1 == d); - EXPECT(not(d == 0)); + EXPECT(d == dim_like{-1}); + EXPECT(dim_like{-1} == d); + EXPECT(not(d == dim_like{0})); } TEST_CASE(equality_int_value) { dim_like d = 5; - EXPECT(d == 5); - EXPECT(5 == d); - EXPECT(d != 4); + EXPECT(d == dim_like{5}); + EXPECT(dim_like{5} == d); + EXPECT(d != dim_like{4}); } TEST_CASE(equality_dd_alternative_never_matches_marker) { dim_like d = dd{0, 4}; - EXPECT(d != 0); - EXPECT(d != -1); + EXPECT(d != dim_like{0}); + EXPECT(d != dim_like{-1}); } TEST_CASE(equality_between_alternatives) @@ -139,27 +145,11 @@ TEST_CASE(equality_between_alternatives) EXPECT(a != b); } -TEST_CASE(adl_get_and_holds_alternative) -{ - using migraphx::get; - using migraphx::holds_alternative; - - dim_like d_int = 42; - EXPECT(holds_alternative(d_int)); - EXPECT(get(d_int) == 42); - - dim_like d_dd = dd{1, 4}; - EXPECT(holds_alternative
(d_dd)); - EXPECT(get
(d_dd) == dd{1, 4}); - - EXPECT(test::throws([&] { (void)get
(d_int); })); -} - TEST_CASE(std_count_marker) { std::vector dims = {0, 0, 6, -1}; - EXPECT(std::count(dims.begin(), dims.end(), -1) == 1); - EXPECT(std::count(dims.begin(), dims.end(), 0) == 2); + EXPECT(std::count(dims.begin(), dims.end(), dim_like{-1}) == 1); + EXPECT(std::count(dims.begin(), dims.end(), dim_like{0}) == 2); } // =================================================================== @@ -208,7 +198,7 @@ TEST_CASE(serialize_int_zero) dim_like d = 0; auto rt = round_trip(d); EXPECT(rt == d); - EXPECT(std::holds_alternative(rt.value)); + EXPECT(std::holds_alternative(rt)); } TEST_CASE(serialize_int_neg_one) @@ -216,7 +206,7 @@ TEST_CASE(serialize_int_neg_one) dim_like d = -1; auto rt = round_trip(d); EXPECT(rt == d); - EXPECT(std::holds_alternative(rt.value)); + EXPECT(std::holds_alternative(rt)); } TEST_CASE(serialize_int_value) @@ -224,7 +214,7 @@ TEST_CASE(serialize_int_value) dim_like d = 42; auto rt = round_trip(d); EXPECT(rt == d); - EXPECT(std::holds_alternative(rt.value)); + EXPECT(std::holds_alternative(rt)); } TEST_CASE(serialize_dd_range) @@ -232,7 +222,7 @@ TEST_CASE(serialize_dd_range) dim_like d = dd{1, 4}; auto rt = round_trip(d); EXPECT(rt == d); - EXPECT(std::holds_alternative
(rt.value)); + EXPECT(std::holds_alternative
(rt)); } TEST_CASE(serialize_dd_symbolic) @@ -240,13 +230,11 @@ TEST_CASE(serialize_dd_symbolic) dim_like d = dd{migraphx::sym::var("n", {1, 8})}; auto rt = round_trip(d); EXPECT(rt == d); - EXPECT(std::holds_alternative
(rt.value)); + EXPECT(std::holds_alternative
(rt)); } // =================================================================== -// Backward-compat: legacy serialized models stored dims as a plain int64 -// array. Decoding such a value into vector must succeed and -// produce the int alternative for every entry. +// Backward-compat: load and save against legacy int / size_t arrays // =================================================================== TEST_CASE(from_value_legacy_int_array) @@ -258,8 +246,6 @@ TEST_CASE(from_value_legacy_int_array) TEST_CASE(from_value_size_t_array) { - // Common path at op-construction sites: make_op("...", {{"dims", lens()}}) - // where lens() is vector. The value layer routes that through uint64. std::vector lens = {4, 24, 1}; auto loaded = migraphx::from_value>(migraphx::to_value(lens)); EXPECT(loaded == std::vector{4, 24, 1}); @@ -267,9 +253,6 @@ TEST_CASE(from_value_size_t_array) TEST_CASE(to_value_int_array_byte_compat) { - // A vector holding only int alternatives must serialize to the - // same value as the equivalent vector, so models with no symbolic - // dims save byte-identical to today. std::vector dims = {0, 0, 6, -1}; std::vector legacy{0, 0, 6, -1}; EXPECT(migraphx::to_value(dims) == migraphx::to_value(legacy)); @@ -288,4 +271,36 @@ TEST_CASE(round_trip_mixed_vector) EXPECT(loaded == dims); } +// =================================================================== +// ADL visit +// =================================================================== + +TEST_CASE(visit_int) +{ + dim_like d = 42; + auto seen = visit( + [](const auto& x) -> std::string { + if constexpr(std::is_same_v, int64_t>) + return "int"; + else + return "dd"; + }, + d); + EXPECT(seen == "int"); +} + +TEST_CASE(visit_dd) +{ + dim_like d = dd{1, 4}; + auto seen = visit( + [](const auto& x) -> std::string { + if constexpr(std::is_same_v, int64_t>) + return "int"; + else + return "dd"; + }, + d); + EXPECT(seen == "dd"); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/picked_variant.cpp b/test/picked_variant.cpp new file mode 100644 index 00000000000..9f72bc79105 --- /dev/null +++ b/test/picked_variant.cpp @@ -0,0 +1,585 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include + +#include +#include +#include + +struct copy_picker +{ + template + static decltype(auto) apply(T&& x) + { + return std::forward(x); + } +}; + +using pv_t = migraphx::picked_variant; + +struct always_long +{ + template + static long apply(T&&) + { + return 999L; + } +}; + +using long_pv = migraphx::picked_variant; + +TEST_CASE(default_ctor) +{ + pv_t v; + EXPECT(v.index() == 0); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 0); +} + +TEST_CASE(in_place_type_ctor_int) +{ + pv_t v(std::in_place_type, 42); + EXPECT(v.index() == 0); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 42); +} + +TEST_CASE(in_place_type_ctor_long) +{ + pv_t v(std::in_place_type, 7L); + EXPECT(v.index() == 1); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 7L); +} + +TEST_CASE(in_place_type_ctor_string) +{ + pv_t v(std::in_place_type, "hello"); + EXPECT(v.index() == 2); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "hello"); +} + +TEST_CASE(in_place_index_ctor) +{ + pv_t v(std::in_place_index<2>, "world"); + EXPECT(v.index() == 2); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "world"); +} + +TEST_CASE(holds_alternative_negative) +{ + pv_t v(std::in_place_type, 42); + EXPECT(std::holds_alternative(v)); + EXPECT(not std::holds_alternative(v)); + EXPECT(not std::holds_alternative(v)); +} + +TEST_CASE(get_by_type) +{ + pv_t v(std::in_place_type, 123); + EXPECT(std::get(v) == 123); + EXPECT(test::throws([&] { (void)std::get(v); })); + EXPECT(test::throws([&] { (void)std::get(v); })); +} + +TEST_CASE(get_by_index) +{ + pv_t v(std::in_place_index<1>, 99L); + EXPECT(std::get<1>(v) == 99L); + EXPECT(test::throws([&] { (void)std::get<0>(v); })); +} + +TEST_CASE(get_if_pointer) +{ + pv_t v(std::in_place_type, 7); + auto* ip = std::get_if(&v); + EXPECT(ip != nullptr); + EXPECT(*ip == 7); + auto* lp = std::get_if(&v); + EXPECT(lp == nullptr); +} + +TEST_CASE(visit_returns_value) +{ + pv_t v(std::in_place_type, 5L); + auto doubled = visit( + [](const auto& x) -> long { + if constexpr(std::is_same, std::string>{}) + return x.size(); + else + return x * 2; + }, + v); + EXPECT(doubled == 10L); +} + +TEST_CASE(visit_mutates_value) +{ + pv_t v(std::in_place_type, 1); + visit( + [](auto& x) { + if constexpr(std::is_arithmetic>{}) + x += 4; + }, + v); + EXPECT(std::get(v) == 5); +} + +TEST_CASE(copy_ctor) +{ + pv_t v1(std::in_place_type, "copy"); + pv_t v2 = v1; // NOLINT(performance-unnecessary-copy-initialization) + EXPECT(std::holds_alternative(v2)); + EXPECT(std::get(v2) == "copy"); +} + +TEST_CASE(move_ctor) +{ + pv_t v1(std::in_place_type, "move"); + pv_t v2 = std::move(v1); + EXPECT(std::holds_alternative(v2)); + EXPECT(std::get(v2) == "move"); +} + +TEST_CASE(emplace_changes_alternative) +{ + pv_t v(std::in_place_type, 1); + v.emplace("now-string"); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "now-string"); + v.emplace(42); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 42); +} + +TEST_CASE(equality) +{ + pv_t a(std::in_place_type, 42); + pv_t b(std::in_place_type, 42); + pv_t c(std::in_place_type, 42L); + EXPECT(a == b); + EXPECT(a != c); +} + +TEST_CASE(equality_same_alternative_different_value) +{ + pv_t a(std::in_place_type, 1); + pv_t b(std::in_place_type, 2); + EXPECT(not(a == b)); + EXPECT(a != b); +} + +TEST_CASE(less_than_same_alternative) +{ + pv_t a(std::in_place_type, 1); + pv_t b(std::in_place_type, 2); + EXPECT(a < b); + EXPECT(a <= b); + EXPECT(b > a); + EXPECT(b >= a); + EXPECT(not(a > b)); + EXPECT(not(b < a)); +} + +TEST_CASE(less_than_different_alternative) +{ + pv_t a(std::in_place_type, 100); + pv_t b(std::in_place_type, 0L); + EXPECT(a < b); + EXPECT(a <= b); + EXPECT(b > a); + EXPECT(b >= a); +} + +TEST_CASE(less_equal_when_equal) +{ + pv_t a(std::in_place_type, 7L); + pv_t b(std::in_place_type, 7L); + EXPECT(a <= b); + EXPECT(a >= b); + EXPECT(not(a < b)); + EXPECT(not(a > b)); +} + +TEST_CASE(comparison_strings) +{ + pv_t a(std::in_place_type, "abc"); + pv_t b(std::in_place_type, "abd"); + EXPECT(a < b); + EXPECT(a != b); + EXPECT(b > a); +} + +TEST_CASE(swap_alternatives) +{ + pv_t a(std::in_place_type, 1); + pv_t b(std::in_place_type, "two"); + a.swap(b); + EXPECT(std::holds_alternative(a)); + EXPECT(std::get(a) == "two"); + EXPECT(std::holds_alternative(b)); + EXPECT(std::get(b) == 1); +} + +TEST_CASE(value_ctor_int_invokes_picker) +{ + pv_t v(42); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 42); +} + +TEST_CASE(value_ctor_long_invokes_picker) +{ + pv_t v(42L); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 42L); +} + +TEST_CASE(value_ctor_string_invokes_picker) +{ + pv_t v(std::string{"hi"}); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "hi"); +} + +TEST_CASE(value_ctor_const_char_ptr_invokes_picker) +{ + pv_t v("hi"); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "hi"); +} + +TEST_CASE(picker_can_redirect_alternative) +{ + long_pv v(42); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 999L); +} + +TEST_CASE(picker_redirects_string_to_long) +{ + long_pv v(std::string{"ignored"}); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 999L); +} + +TEST_CASE(copy_does_not_invoke_picker) +{ + long_pv v1(std::in_place_type, 7); + long_pv v2 = v1; // NOLINT(performance-unnecessary-copy-initialization) + EXPECT(std::holds_alternative(v2)); + EXPECT(std::get(v2) == 7); +} + +struct route_by_type +{ + static int apply(int x) { return x + 1; } + static long apply(long x) { return x + 100; } + static std::string apply(const char* s) { return std::string{"got:"} + s; } + static std::string apply(const std::string& s) { return "str:" + s; } +}; + +using route_pv = migraphx::picked_variant; + +TEST_CASE(picker_overload_int) +{ + route_pv v(5); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 6); +} + +TEST_CASE(picker_overload_long) +{ + route_pv v(5L); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == 105L); +} + +TEST_CASE(picker_overload_const_char) +{ + route_pv v("hi"); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "got:hi"); +} + +TEST_CASE(picker_overload_string) +{ + route_pv v(std::string{"hi"}); + EXPECT(std::holds_alternative(v)); + EXPECT(std::get(v) == "str:hi"); +} + +struct count_picker +{ + static int& counter() + { + static int n = 0; + return n; + } + template + static decltype(auto) apply(T&& x) + { + ++counter(); + return std::forward(x); + } +}; + +using counted_pv = migraphx::picked_variant; + +TEST_CASE(picker_invoked_exactly_once_per_value_ctor) +{ + count_picker::counter() = 0; + counted_pv a(1); + counted_pv b(2L); + counted_pv c(std::string{"x"}); + EXPECT(count_picker::counter() == 3); +} + +TEST_CASE(picker_not_invoked_for_default_ctor) +{ + count_picker::counter() = 0; + counted_pv v; + (void)v; + EXPECT(count_picker::counter() == 0); +} + +TEST_CASE(picker_not_invoked_for_in_place_ctor) +{ + count_picker::counter() = 0; + counted_pv v(std::in_place_type, 5L); + (void)v; + EXPECT(count_picker::counter() == 0); +} + +TEST_CASE(picker_not_invoked_for_copy) +{ + counted_pv source(std::in_place_type, 9); + count_picker::counter() = 0; + counted_pv copied = source; // NOLINT(performance-unnecessary-copy-initialization) + counted_pv moved = std::move(source); + EXPECT(count_picker::counter() == 0); + EXPECT(std::holds_alternative(copied)); + EXPECT(std::holds_alternative(moved)); +} + +struct lvalue_or_rvalue_picker +{ + static std::string apply(int&) { return "lvalue"; } + static std::string apply(const int&) { return "const-lvalue"; } + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + static std::string apply(int&&) { return "rvalue"; } +}; + +using vc_pv = migraphx::picked_variant; + +TEST_CASE(picker_receives_rvalue) +{ + vc_pv v(42); + EXPECT(std::get(v) == "rvalue"); +} + +TEST_CASE(picker_receives_lvalue) +{ + int x = 42; + vc_pv v(x); + EXPECT(std::get(v) == "lvalue"); +} + +TEST_CASE(picker_receives_const_lvalue) +{ + const int x = 42; + vc_pv v(x); + EXPECT(std::get(v) == "const-lvalue"); +} + +TEST_CASE(is_derived_from_variant) +{ + static_assert(std::is_base_of, pv_t>{}, + "picked_variant must derive from std::variant"); + pv_t v(std::in_place_type, 42); + std::variant& base_ref = v; + EXPECT(std::get(base_ref) == 42); +} + +TEST_CASE(visit_two_same_picked_variants) +{ + pv_t a(std::in_place_type, 5); + pv_t b(std::in_place_type, 7L); + auto sum = visit( + [](const auto& x, const auto& y) -> long { + if constexpr(std::is_arithmetic>{} and + std::is_arithmetic>{}) + return static_cast(x) + static_cast(y); + else + return -1L; + }, + a, + b); + EXPECT(sum == 12L); +} + +TEST_CASE(visit_two_different_picked_variants) +{ + pv_t a(std::in_place_type, 10); + long_pv b(std::in_place_type, 20L); + auto sum = visit( + [](const auto& x, const auto& y) -> long { + if constexpr(std::is_arithmetic>{} and + std::is_arithmetic>{}) + return static_cast(x) + static_cast(y); + else + return -1L; + }, + a, + b); + EXPECT(sum == 30L); +} + +template +using is_arith = std::is_arithmetic>; + +TEST_CASE(visit_picked_first_then_std_variant) +{ + pv_t a(std::in_place_type, 3); + std::variant b{4}; + auto sum = visit( + [](const auto& x, const auto& y) -> long { + if constexpr(is_arith{} and is_arith{}) + return static_cast(x) + static_cast(y); + else + return -1L; + }, + a, + b); + EXPECT(sum == 7L); +} + +TEST_CASE(visit_std_variant_first_then_picked) +{ + std::variant a{3}; + pv_t b(std::in_place_type, 4); + auto sum = visit( + [](const auto& x, const auto& y) -> long { + if constexpr(is_arith{} and is_arith{}) + return static_cast(x) + static_cast(y); + else + return -1L; + }, + a, + b); + EXPECT(sum == 7L); +} + +TEST_CASE(visit_three_picked_variants_different_types) +{ + pv_t a(std::in_place_type, 1); + long_pv b(std::in_place_type, 2L); + pv_t c(std::in_place_type, 3L); + auto sum = visit( + [](const auto& x, const auto& y, const auto& z) -> long { + if constexpr(is_arith{} and is_arith{} and + is_arith{}) + return static_cast(x) + static_cast(y) + static_cast(z); + else + return -1L; + }, + a, + b, + c); + EXPECT(sum == 6L); +} + +TEST_CASE(visit_mixed_three_variants) +{ + pv_t a(std::in_place_type, 1); + std::variant b{2L}; + long_pv c(std::in_place_type, 3L); + auto sum = visit( + [](const auto& x, const auto& y, const auto& z) -> long { + if constexpr(is_arith{} and is_arith{} and + is_arith{}) + return static_cast(x) + static_cast(y) + static_cast(z); + else + return -1L; + }, + a, + b, + c); + EXPECT(sum == 6L); +} + +TEST_CASE(visit_picks_correct_alternative_pair) +{ + pv_t a(std::in_place_type, "hi"); + pv_t b(std::in_place_type, 5); + auto result = visit( + [](auto&& x, auto&& y) -> std::string { + using x_type = std::decay_t; + using y_type = std::decay_t; + if constexpr(std::is_same{} and std::is_arithmetic{}) + return x + ":" + std::to_string(y); + else + return "no-match"; + }, + a, + b); + EXPECT(result == "hi:5"); +} + +TEST_CASE(visit_const_multi_variant) +{ + const pv_t a(std::in_place_type, 11); + const long_pv b(std::in_place_type, 22); + auto sum = visit( + [](const auto& x, const auto& y) -> long { + if constexpr(std::is_arithmetic>{} and + std::is_arithmetic>{}) + return static_cast(x) + static_cast(y); + else + return -1L; + }, + a, + b); + EXPECT(sum == 33L); +} + +TEST_CASE(visit_rvalue_multi_variant) +{ + auto sum = visit( + [](const auto& x, const auto& y) -> long { + if constexpr(std::is_arithmetic>{} and + std::is_arithmetic>{}) + return static_cast(x) + static_cast(y); + else + return -1L; + }, + pv_t(std::in_place_type, 100), + long_pv(std::in_place_type, 200L)); + EXPECT(sum == 300L); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 09ffc69755a463fe17df5e8d96f29e2f8c473699 Mon Sep 17 00:00:00 2001 From: shivadbhavsar <105248561+shivadbhavsar@users.noreply.github.com> Date: Wed, 20 May 2026 15:32:06 -0700 Subject: [PATCH 8/9] Apply suggestion from @github-actions[bot] Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/include/migraphx/dim_like.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/dim_like.hpp b/src/include/migraphx/dim_like.hpp index ca24b4648b7..3e06e31a8e5 100644 --- a/src/include/migraphx/dim_like.hpp +++ b/src/include/migraphx/dim_like.hpp @@ -57,7 +57,7 @@ using dim_like = picked_variant, which // would instantiate Picker::apply(vector<...>) and hard-fail. template {})> -inline std::ostream& operator<<(std::ostream& os, const T& d) +inline std::ostream& operator<<(std::ostream & os, const T & d) { visit([&](const auto& x) { os << x; }, d); return os; From 10407630728686e67925646c503b8c44a48dbd93 Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 28 May 2026 11:10:07 -0700 Subject: [PATCH 9/9] review comments --- src/include/migraphx/dim_like.hpp | 6 +----- src/include/migraphx/op/reshape.hpp | 5 +++++ src/include/migraphx/op/reshape_lazy.hpp | 6 ++++++ src/include/migraphx/picked_variant.hpp | 4 +++- test/dim_like_test.cpp | 4 ++-- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/dim_like.hpp b/src/include/migraphx/dim_like.hpp index 3e06e31a8e5..d3933fe2a40 100644 --- a/src/include/migraphx/dim_like.hpp +++ b/src/include/migraphx/dim_like.hpp @@ -53,11 +53,7 @@ struct dim_like_picker // A dim attribute entry that may be either a plain int64_t or a dynamic_dimension. using dim_like = picked_variant; -// Templated to hide from ADL on unrelated types: a non-template overload would -// be probed during overload resolution for things like vector, which -// would instantiate Picker::apply(vector<...>) and hard-fail. -template {})> -inline std::ostream& operator<<(std::ostream & os, const T & d) +inline std::ostream& operator<<(std::ostream& os, const dim_like& d) { visit([&](const auto& x) { os << x; }, d); return os; diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp index 7b10ae62c88..0c5d32943b5 100644 --- a/src/include/migraphx/op/reshape.hpp +++ b/src/include/migraphx/op/reshape.hpp @@ -186,6 +186,11 @@ struct reshape { check_shapes{inputs, *this, true}.has(1, 2); + if(std::any_of(dims.begin(), dims.end(), [](const auto& d) { + return std::holds_alternative(d); + })) + MIGRAPHX_THROW("Reshape: dynamic_dimension dim entries are not currently supported"); + auto n_neg_dims = std::count(dims.begin(), dims.end(), dim_like{-1}); if(n_neg_dims > 1) MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); diff --git a/src/include/migraphx/op/reshape_lazy.hpp b/src/include/migraphx/op/reshape_lazy.hpp index 40fc8612ad1..278500469a7 100644 --- a/src/include/migraphx/op/reshape_lazy.hpp +++ b/src/include/migraphx/op/reshape_lazy.hpp @@ -148,6 +148,12 @@ struct reshape_lazy shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this, true}.has(1); + if(std::any_of(dims.begin(), dims.end(), [](const auto& d) { + return std::holds_alternative(d); + })) + MIGRAPHX_THROW( + "reshape_lazy: dynamic_dimension dim entries are not currently supported"); + auto n_neg_dims = std::count(dims.begin(), dims.end(), dim_like{-1}); if(n_neg_dims > 1) MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim"); diff --git a/src/include/migraphx/picked_variant.hpp b/src/include/migraphx/picked_variant.hpp index 38736f1af17..b4baaa987af 100644 --- a/src/include/migraphx/picked_variant.hpp +++ b/src/include/migraphx/picked_variant.hpp @@ -73,7 +73,9 @@ struct picked_variant : std::variant using base_t = std::variant; using base_t::base_t; // inherit default, in_place_type, in_place_index ctors - template >{})> + template >{}), + class = decltype(Picker::apply(std::declval()))> constexpr picked_variant(T&& x) : base_t(Picker::apply(std::forward(x))) { } diff --git a/test/dim_like_test.cpp b/test/dim_like_test.cpp index 424e7d23959..0918c079620 100644 --- a/test/dim_like_test.cpp +++ b/test/dim_like_test.cpp @@ -280,7 +280,7 @@ TEST_CASE(visit_int) dim_like d = 42; auto seen = visit( [](const auto& x) -> std::string { - if constexpr(std::is_same_v, int64_t>) + if constexpr(std::is_same, int64_t>{}) return "int"; else return "dd"; @@ -294,7 +294,7 @@ TEST_CASE(visit_dd) dim_like d = dd{1, 4}; auto seen = visit( [](const auto& x) -> std::string { - if constexpr(std::is_same_v, int64_t>) + if constexpr(std::is_same, int64_t>{}) return "int"; else return "dd";