From 960befe3aec4244e727702bdcd885e415dd652c0 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 15 May 2026 09:41:16 -0700 Subject: [PATCH 1/3] concat sym compute shape --- src/include/migraphx/op/concat.hpp | 86 +++++++------------- test/op_shape_test.cpp | 123 ++++++++++++++++++++++++++++- 2 files changed, 152 insertions(+), 57 deletions(-) diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index 31288a547a2..cd763b882c8 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -81,67 +81,41 @@ struct concat // be at least 1. check_shapes{inputs, *this, true}.same_ndims().same_type(); - if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); })) + bool all_static = + std::none_of(inputs.begin(), inputs.end(), [](const shape& s) { return s.dynamic(); }); + auto unified = shape::to_dynamic(inputs); + + const auto& dds0 = unified.front().dyn_dims(); + for(std::size_t i = 0; i < dds0.size(); ++i) { - // Static input shapes - const auto& first_shape_lens = inputs.front().lens(); - const auto& type = inputs.front().type(); - for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++) - { - if(ll != axis) - { - if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) { - return s.lens()[ll] == first_shape_lens[ll]; - })) - { - MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " + - std::to_string(ll)); - } - } - } - std::size_t new_dim_axis = 0; - for(const auto& input : inputs) - { - const auto& lens = input.lens(); - new_dim_axis += lens[axis]; - } - std::vector new_lens = first_shape_lens; - new_lens[axis] = new_dim_axis; - return shape::from_permutation(type, new_lens, find_permutation(inputs)); + if(i == axis) + continue; + if(not std::all_of(unified.begin(), unified.end(), [&](const shape& s) { + return s.dyn_dims()[i] == dds0[i]; + })) + MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " + + std::to_string(i)); } - else if(std::all_of( - inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); })) - { - // Dynamic input shapes - for(std::size_t index = 0; index < inputs[0].ndim(); index++) - { - if(index != axis) - { - if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) { - return s.dyn_dims()[index] == inputs[0].dyn_dims()[index]; - })) - MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " + - std::to_string(index)); - } - } - std::size_t new_min = 0; - std::size_t new_max = 0; - for(const auto& input : inputs) - { - auto ddim = input.dyn_dims()[axis]; - auto dim_interval = ddim.get_interval(); - new_min += dim_interval.min; - new_max += dim_interval.max; - } - auto new_dims = inputs[0].dyn_dims(); - new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max}; - return {inputs[0].type(), new_dims}; - } - else + auto new_dds = dds0; + new_dds[axis] = std::accumulate( + unified.begin() + 1, unified.end(), dds0[axis], [&](const auto& acc, const shape& s) { + return acc + s.dyn_dims()[axis]; + }); + + auto type = unified.front().type(); + if(all_static) { - MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes."); + std::vector new_lens(new_dds.size()); + std::transform(new_dds.begin(), new_dds.end(), new_lens.begin(), [](const auto& d) { + assert(d.sym_expr.is_literal()); + return d.sym_expr.eval_uint({}); + }); + return shape::from_permutation(type, new_lens, find_permutation(inputs)); } + if(unified.front().symbolic()) + return shape::from_permutation(type, new_dds, find_permutation(unified)); + return {type, new_dds}; } argument compute(const dyn_output& dyn_out, std::vector args) const diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index af0c7d5c34b..fd8fea76a82 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -6056,6 +6056,12 @@ TEST_CASE(test_dyn_concat) expect_shape(sout, migraphx::make_op("concat", {{"axis", 2}}), sx, sy); + // static + range-dynamic with compatible non-axis dims (static lifts to fixed range) + migraphx::shape sr{migraphx::shape::float_type, {{2, 2}, {1, 5}, {4, 4}}}; + migraphx::shape ss{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape sr_out{migraphx::shape::float_type, {{2, 2}, {4, 8}, {4, 4}}}; + expect_shape(sr_out, migraphx::make_op("concat", {{"axis", 1}}), sr, ss); + // axis out of range throws_shape(migraphx::make_op("concat", {{"axis", 4}}), sx, sy); @@ -6066,11 +6072,126 @@ TEST_CASE(test_dyn_concat) // non-matching dimension 2 throws_shape(migraphx::make_op("concat", {{"axis", 1}}), sx, sy); - // static and dynamic shapes together + // static input with non-axis dim that doesn't match the range-dynamic input migraphx::shape sstat{migraphx::shape::float_type, {3, 4, 1, 6}}; throws_shape(migraphx::make_op("concat", {{"axis", 2}}), sx, sstat); } +TEST_CASE(concat_sym) +{ + auto n = var("n", {1, 8}); + auto m = var("m", {1, 16}); + auto s = var("s", {1, 128}); + auto k = var("k", {1, 64}); + std::unordered_map sym_map = {{n, 3}, {m, 5}, {s, 7}, {k, 9}}; + + auto expect_matches_static = [&](const migraphx::operation& op, + const std::vector& inputs, + const migraphx::shape& sym_out) { + std::vector static_inputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), static_inputs.begin(), [&](const auto& sh) { + return sh.to_static(sym_map); + }); + EXPECT(sym_out.to_static(sym_map) == op.compute_shape(static_inputs)); + }; + + { + // axis 0 (first): distinct symbols on the concat axis. + auto op = migraphx::make_op("concat", {{"axis", 0}}); + migraphx::shape sx{migraphx::shape::float_type, {dd{n}, dd{lit(4)}}}; + migraphx::shape sy{migraphx::shape::float_type, {dd{m}, dd{lit(4)}}}; + migraphx::shape sout{migraphx::shape::float_type, {dd{n + m}, dd{lit(4)}}}; + expect_shape(sout, op, sx, sy); + expect_matches_static(op, {sx, sy}, sout); + } + { + // axis 1 (middle): non-axis sym shared, distinct sym on the concat axis. + auto op = migraphx::make_op("concat", {{"axis", 1}}); + migraphx::shape sx{migraphx::shape::float_type, {dd{n}, dd{s}, dd{lit(8)}}}; + migraphx::shape sy{migraphx::shape::float_type, {dd{n}, dd{k}, dd{lit(8)}}}; + migraphx::shape sout{migraphx::shape::float_type, {dd{n}, dd{s + k}, dd{lit(8)}}}; + expect_shape(sout, op, sx, sy); + expect_matches_static(op, {sx, sy}, sout); + } + { + // axis 3 (last) on a 4D shape; sym at axis 0 and at the concat axis. + auto op = migraphx::make_op("concat", {{"axis", 3}}); + migraphx::shape sx{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(5)}, dd{s}}}; + migraphx::shape sy{migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(5)}, dd{k}}}; + migraphx::shape sout{migraphx::shape::float_type, + {dd{n}, dd{lit(3)}, dd{lit(5)}, dd{s + k}}}; + expect_shape(sout, op, sx, sy); + expect_matches_static(op, {sx, sy}, sout); + } +} + +TEST_CASE(concat_sym_same_var) +{ + // Same symbol on the concat axis across both inputs -> 2*s. + auto s = var("s", {1, 64}); + migraphx::shape sx{migraphx::shape::float_type, {dd{lit(4)}, dd{s}, dd{lit(8)}}}; + migraphx::shape sout{migraphx::shape::float_type, {dd{lit(4)}, dd{s + s}, dd{lit(8)}}}; + auto op = migraphx::make_op("concat", {{"axis", 1}}); + expect_shape(sout, op, sx, sx); + + std::unordered_map sym_map = {{s, 7}}; + auto static_in = sx.to_static(sym_map); + EXPECT(sout.to_static(sym_map) == op.compute_shape({static_in, static_in})); +} + +TEST_CASE(concat_sym_static_mix) +{ + // Static + symbolic: static input lifts to sym-lits; output stays symbolic. + auto n = var("n", {1, 16}); + migraphx::shape sx{migraphx::shape::float_type, {dd{n}, dd{lit(4)}, dd{lit(5)}}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4, 5}}; + migraphx::shape sout{migraphx::shape::float_type, {dd{n + lit(3)}, dd{lit(4)}, dd{lit(5)}}}; + auto op = migraphx::make_op("concat", {{"axis", 0}}); + expect_shape(sout, op, sx, sy); + + std::unordered_map sym_map = {{n, 6}}; + EXPECT(sout.to_static(sym_map) == op.compute_shape({sx.to_static(sym_map), sy})); +} + +TEST_CASE(concat_sym_three_inputs) +{ + // Three inputs with three distinct symbols on the concat axis. + auto a = var("a", {1, 8}); + auto b = var("b", {1, 16}); + auto c = var("c", {1, 32}); + migraphx::shape s_a{migraphx::shape::float_type, {dd{lit(2)}, dd{a}, dd{lit(4)}}}; + migraphx::shape s_b{migraphx::shape::float_type, {dd{lit(2)}, dd{b}, dd{lit(4)}}}; + migraphx::shape s_c{migraphx::shape::float_type, {dd{lit(2)}, dd{c}, dd{lit(4)}}}; + migraphx::shape sout{migraphx::shape::float_type, {dd{lit(2)}, dd{a + b + c}, dd{lit(4)}}}; + auto op = migraphx::make_op("concat", {{"axis", 1}}); + expect_shape(sout, op, s_a, s_b, s_c); + + std::unordered_map sym_map = {{a, 3}, {b, 5}, {c, 7}}; + EXPECT( + sout.to_static(sym_map) == + op.compute_shape({s_a.to_static(sym_map), s_b.to_static(sym_map), s_c.to_static(sym_map)})); +} + +TEST_CASE(concat_sym_non_axis_mismatch_throws) +{ + // Non-axis dim has different symbols -> throws. + auto n = var("n", {1, 8}); + auto m = var("m", {1, 16}); + migraphx::shape sx{migraphx::shape::float_type, {dd{n}, dd{lit(4)}, dd{lit(5)}}}; + migraphx::shape sy{migraphx::shape::float_type, {dd{m}, dd{lit(4)}, dd{lit(5)}}}; + throws_shape(migraphx::make_op("concat", {{"axis", 1}}), sx, sy); +} + +TEST_CASE(concat_sym_with_range) +{ + // Symbolic + range-dynamic: sym side materialized to range; output is range. + auto n = var("n", {1, 8}); + migraphx::shape sx{migraphx::shape::float_type, {dd{lit(2)}, dd{n}, dd{lit(4)}}}; + migraphx::shape sy{migraphx::shape::float_type, {{2, 2}, {1, 5}, {4, 4}}}; + migraphx::shape sout{migraphx::shape::float_type, {{2, 2}, {2, 13}, {4, 4}}}; + expect_shape(sout, migraphx::make_op("concat", {{"axis", 1}}), sx, sy); +} + TEST_CASE(test_binary_nonpacked) { auto sx = migraphx::shape(migraphx::shape::float_type, {4, 3}, {1, 8}); From 6dbace1a5df91828cb720de2ea11ee2aa98fd02a Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 19 May 2026 12:51:05 -0700 Subject: [PATCH 2/3] slice support --- src/include/migraphx/op/slice.hpp | 61 +++++++++--------- test/op_shape_test.cpp | 102 ++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 32 deletions(-) diff --git a/src/include/migraphx/op/slice.hpp b/src/include/migraphx/op/slice.hpp index 47294a70358..5d7781b90c5 100644 --- a/src/include/migraphx/op/slice.hpp +++ b/src/include/migraphx/op/slice.hpp @@ -251,42 +251,39 @@ struct slice shape normalize_compute_shape(std::vector inputs) const { check_shapes{inputs, *this, true}.has(1, 2, 3, 4); - if(inputs.size() == 1) + if(inputs.size() != 1) + return compute_two_or_more(inputs); + + auto input_shape = inputs[0]; + auto set_attributes = get_set_attributes(); + if(set_attributes != all_set) + MIGRAPHX_THROW("SLICE 1_arg: Invalid 1 input and attributes configuration"); + + // TODO: support slicing non-fixed symbolic dims (output dim would be + // a sym::expr derived from starts/ends and the symbolic axis bound). + if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) { + return not input_shape.dyn_dims()[axis].is_fixed(); + })) { - auto input_shape = inputs[0]; - auto set_attributes = get_set_attributes(); - if(set_attributes != all_set) - { - MIGRAPHX_THROW("SLICE 1_arg: Invalid 1 input and attributes configuration"); - } - // NOTE: make sure to update how normalization works here if this type of slicing is - // changed to be allowed - if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return not input_shape.dyn_dims()[axis].is_fixed(); - })) - { - MIGRAPHX_THROW( - "SLICE 1_arg: slicing is not allowed on non-fixed dynamic input axis "); - } - if(input_shape.dynamic()) - { - return shape{ - input_shape.type(), - lens_calc(input_shape.min_lens(), this->starts, this->ends, this->axes), - lens_calc(input_shape.max_lens(), this->starts, this->ends, this->axes), - {}}; - } - else - { - return shape{input_shape.type(), - lens_calc(input_shape.lens(), this->starts, this->ends, this->axes), - input_shape.strides()}; - } + MIGRAPHX_THROW("SLICE 1_arg: slicing is not allowed on non-fixed dynamic input axis "); } - else + + auto new_lens = lens_calc(input_shape.max_lens(), this->starts, this->ends, this->axes); + + if(not input_shape.dynamic()) + return shape{input_shape.type(), new_lens, input_shape.strides()}; + + auto dds = input_shape.dyn_dims(); + for(auto axis : this->axes) { - return compute_two_or_more(inputs); + dds[axis] = input_shape.symbolic() + ? shape::dynamic_dimension{sym::lit(new_lens[axis])} + : shape::dynamic_dimension{new_lens[axis], new_lens[axis]}; } + + if(input_shape.symbolic()) + return shape{input_shape.type(), dds, input_shape.dyn_strides()}; + return shape{input_shape.type(), dds}; } /** diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index fd8fea76a82..5b5b4120cb2 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -5062,6 +5062,108 @@ TEST_CASE(slice_dyn_shape5) input); } +TEST_CASE(slice_dyn_preserves_optimals) +{ + migraphx::shape input{migraphx::shape::int32_type, {dd{2, 4, {3}}, dd{7, 7}, dd{2, 5, {3, 4}}}}; + migraphx::shape expected{migraphx::shape::int32_type, + {dd{2, 4, {3}}, dd{3, 3}, dd{2, 5, {3, 4}}}}; + expect_shape(expected, + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), + input); +} + +TEST_CASE(slice_sym) +{ + auto n = var("n", {1, 8}); + auto m = var("m", {1, 16}); + auto k = var("k", {1, 64}); + std::unordered_map sym_map = {{n, 3}, {m, 5}, {k, 7}}; + + auto expect_matches_static = [&](const migraphx::operation& op, + const migraphx::shape& sin, + const migraphx::shape& sym_out) { + EXPECT(sym_out.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); + }; + + { + // Slice axis 0 (first); sym at axis 1. + auto op = migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{lit(5)}, dd{n}, dd{lit(4)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, {dd{lit(2)}, dd{n}, dd{lit(4)}}, sin.dyn_strides()}; + expect_shape(sout, op, sin); + expect_matches_static(op, sin, sout); + } + { + // Slice axis 1 (middle); syms at axes 0 and 2. + auto op = migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {6}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{n}, dd{lit(8)}, dd{m}}}; + migraphx::shape sout{ + migraphx::shape::float_type, {dd{n}, dd{lit(4)}, dd{m}}, sin.dyn_strides()}; + expect_shape(sout, op, sin); + expect_matches_static(op, sin, sout); + } + { + // Slice axis 3 (last) on a 4D shape; syms at axes 0, 1, 2. + auto op = migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {3}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{n}, dd{m}, dd{k}, dd{lit(10)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, {dd{n}, dd{m}, dd{k}, dd{lit(3)}}, sin.dyn_strides()}; + expect_shape(sout, op, sin); + expect_matches_static(op, sin, sout); + } +} + +TEST_CASE(slice_sym_multiple_axes) +{ + // Slice axes 0 and 2 at once; sym at axis 1 is untouched. + auto n = var("n", {1, 8}); + std::unordered_map sym_map = {{n, 4}}; + + auto op = migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {1, 2}}, {"ends", {4, 5}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{lit(6)}, dd{n}, dd{lit(8)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, {dd{lit(3)}, dd{n}, dd{lit(3)}}, sin.dyn_strides()}; + expect_shape(sout, op, sin); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + +TEST_CASE(slice_sym_fixed_bound_var) +{ + // var("k", {3, 3}) is fixed (collapsed bound), so slicing the axis is allowed. + auto k = var("k", {3, 3}); + auto n = var("n", {1, 8}); + std::unordered_map sym_map = {{n, 5}}; + + auto op = migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{n}, dd{k}, dd{lit(4)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, {dd{n}, dd{lit(2)}, dd{lit(4)}}, sin.dyn_strides()}; + expect_shape(sout, op, sin); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + +TEST_CASE(slice_sym_non_fixed_throws) +{ + // Slicing on a non-fixed symbolic axis is rejected (same contract as range). + auto n = var("n", {1, 8}); + migraphx::shape sin{migraphx::shape::float_type, {dd{lit(4)}, dd{n}, dd{lit(8)}}}; + throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}), sin); +} + +TEST_CASE(slice_sym_nonstandard_layout) +{ + // Non-standard symbolic input: the slice must preserve the permutation + auto n = var("n", {1, 8}); + std::unordered_map sym_map = {{n, 6}}; + + auto sin = migraphx::shape::from_permutation( + migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(5)}, dd{lit(7)}}, {0, 2, 3, 1}); + auto op = migraphx::make_op("slice", {{"axes", {3}}, {"starts", {1}}, {"ends", {6}}}); + auto sout = op.compute_shape({sin}); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + TEST_CASE(test_scan_slice1) { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; From cfd22ad57f017779e15860c02517562f542cc7ea Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 19 May 2026 13:58:36 -0700 Subject: [PATCH 3/3] update step --- src/include/migraphx/op/step.hpp | 31 ++++---- test/op_shape_test.cpp | 129 +++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 14 deletions(-) diff --git a/src/include/migraphx/op/step.hpp b/src/include/migraphx/op/step.hpp index 660c0b2744f..4cbd9d47712 100644 --- a/src/include/migraphx/op/step.hpp +++ b/src/include/migraphx/op/step.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -55,10 +56,8 @@ struct step std::string name() const { return "step"; } shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1); + check_shapes{inputs, *this, true}.has(1); const auto& input = inputs.at(0); - auto in_lens = input.lens(); - auto t = input.type(); if(axes.size() != steps.size()) { @@ -67,27 +66,31 @@ struct step "}."); } - if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= in_lens.size(); })) + if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= input.ndim(); })) { MIGRAPHX_THROW("STEP: axis value is out of range!"); } - auto lens = in_lens; - auto strides = input.strides(); + auto unified = shape::to_dynamic({input}).front(); + auto dds = unified.dyn_dims(); + auto dstrides = unified.symbolic() ? unified.dyn_strides() : std::vector{}; for(auto i : range(axes.size())) { - auto axis = axes[i]; - auto step = steps[i]; - lens[axis] = (in_lens[axis] + step - 1) / step; - strides[axis] *= step; + auto s = static_cast(steps[i]); + dds[axes[i]] = (dds[axes[i]] + (s - 1)) / s; + if(unified.symbolic()) + dstrides[axes[i]] = dstrides[axes[i]] * sym::lit(s); } - - return {t, lens, strides}; + if(not input.dynamic()) + return shape{input.type(), dds, dstrides}.to_static(); + if(unified.symbolic()) + return shape{input.type(), dds, dstrides}; + return shape{input.type(), dds}; } - argument compute(shape output_shape, std::vector args) const + argument compute(const dyn_output& dyn_out, std::vector args) const { - return args[0].reshape(output_shape); + return args[0].reshape(dyn_out.computed_shape); } std::vector output_alias(const std::vector&) const { return {0}; } diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 5b5b4120cb2..138513aaf70 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -6043,6 +6043,135 @@ TEST_CASE(step_test) } } +TEST_CASE(step_sym) +{ + auto n = var("n", {1, 8}); + auto m = var("m", {1, 16}); + auto k = var("k", {1, 64}); + std::unordered_map sym_map = {{n, 3}, {m, 5}, {k, 7}}; + + auto expect_matches_static = [&](const migraphx::operation& op, + const migraphx::shape& sin, + const migraphx::shape& sym_out) { + EXPECT(sym_out.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); + }; + + { + // Step axis 0 (first); sym at axis 1. + auto op = migraphx::make_op("step", {{"axes", {0}}, {"steps", {2}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{lit(5)}, dd{n}, dd{lit(4)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, + {dd{lit(3)}, dd{n}, dd{lit(4)}}, + {sin.dyn_strides()[0] * lit(2), sin.dyn_strides()[1], sin.dyn_strides()[2]}}; + expect_shape(sout, op, sin); + expect_matches_static(op, sin, sout); + } + { + // Step axis 1 (middle); syms at axes 0 and 2. + auto op = migraphx::make_op("step", {{"axes", {1}}, {"steps", {3}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{n}, dd{lit(8)}, dd{m}}}; + migraphx::shape sout{ + migraphx::shape::float_type, + {dd{n}, dd{lit(3)}, dd{m}}, + {sin.dyn_strides()[0], sin.dyn_strides()[1] * lit(3), sin.dyn_strides()[2]}}; + expect_shape(sout, op, sin); + expect_matches_static(op, sin, sout); + } + { + // Step axis 3 (last) on a 4D shape; syms at axes 0, 1, 2. + auto op = migraphx::make_op("step", {{"axes", {3}}, {"steps", {2}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{n}, dd{m}, dd{k}, dd{lit(10)}}}; + migraphx::shape sout{migraphx::shape::float_type, + {dd{n}, dd{m}, dd{k}, dd{lit(5)}}, + {sin.dyn_strides()[0], + sin.dyn_strides()[1], + sin.dyn_strides()[2], + sin.dyn_strides()[3] * lit(2)}}; + expect_shape(sout, op, sin); + expect_matches_static(op, sin, sout); + } +} + +TEST_CASE(step_sym_on_sym_axis) +{ + // Stepping a symbolic axis: output dim is (n+1)/2 symbolically. + auto n = var("n", {2, 32}); + std::unordered_map sym_map = {{n, 9}}; + + auto op = migraphx::make_op("step", {{"axes", {1}}, {"steps", {2}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{lit(2)}, dd{n}, dd{lit(4)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, + {dd{lit(2)}, dd{(n + lit(1)) / lit(2)}, dd{lit(4)}}, + {sin.dyn_strides()[0], sin.dyn_strides()[1] * lit(2), sin.dyn_strides()[2]}}; + expect_shape(sout, op, sin); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + +TEST_CASE(step_sym_multiple_axes) +{ + // Step two axes at once; one literal, one symbolic; verify both dims and the + // stride scaling at each stepped axis. + auto m = var("m", {1, 16}); + std::unordered_map sym_map = {{m, 7}}; + + auto op = migraphx::make_op("step", {{"axes", {0, 2}}, {"steps", {2, 3}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{lit(6)}, dd{m}, dd{lit(9)}}}; + migraphx::shape sout{ + migraphx::shape::float_type, + {dd{lit(3)}, dd{m}, dd{lit(3)}}, + {sin.dyn_strides()[0] * lit(2), sin.dyn_strides()[1], sin.dyn_strides()[2] * lit(3)}}; + expect_shape(sout, op, sin); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + +TEST_CASE(step_sym_fully_symbolic) +{ + // Every axis symbolic; step the middle axis. Output dim is (m+1)/2 symbolically. + auto n = var("n", {1, 8}); + auto m = var("m", {2, 16}); + auto k = var("k", {1, 64}); + std::unordered_map sym_map = {{n, 4}, {m, 11}, {k, 5}}; + + auto op = migraphx::make_op("step", {{"axes", {1}}, {"steps", {2}}}); + migraphx::shape sin{migraphx::shape::float_type, {dd{n}, dd{m}, dd{k}}}; + migraphx::shape sout{ + migraphx::shape::float_type, + {dd{n}, dd{(m + lit(1)) / lit(2)}, dd{k}}, + {sin.dyn_strides()[0], sin.dyn_strides()[1] * lit(2), sin.dyn_strides()[2]}}; + expect_shape(sout, op, sin); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + +TEST_CASE(step_sym_nonstandard_layout) +{ + // Non-standard symbolic input via from_permutation: step must preserve the + // permutation and scale only the stepped axis's stride. + auto n = var("n", {1, 8}); + std::unordered_map sym_map = {{n, 6}}; + + auto sin = migraphx::shape::from_permutation( + migraphx::shape::float_type, {dd{n}, dd{lit(3)}, dd{lit(5)}, dd{lit(8)}}, {0, 2, 3, 1}); + auto op = migraphx::make_op("step", {{"axes", {3}}, {"steps", {2}}}); + migraphx::shape sout{migraphx::shape::float_type, + {dd{n}, dd{lit(3)}, dd{lit(5)}, dd{lit(4)}}, + {sin.dyn_strides()[0], + sin.dyn_strides()[1], + sin.dyn_strides()[2], + sin.dyn_strides()[3] * lit(2)}}; + expect_shape(sout, op, sin); + EXPECT(sout.to_static(sym_map) == op.compute_shape({sin.to_static(sym_map)})); +} + +TEST_CASE(step_dyn) +{ + // Range-dynamic input: ceil-divide bounds; fixed-step shifts the optimal. + migraphx::shape input{migraphx::shape::float_type, {dd{2, 2}, dd{2, 8, {4}}, dd{4, 4}}}; + migraphx::shape expected{migraphx::shape::float_type, {dd{2, 2}, dd{1, 4, {2}}, dd{4, 4}}}; + expect_shape(expected, migraphx::make_op("step", {{"axes", {1}}, {"steps", {2}}}), input); +} + TEST_CASE(unary_scalar_input) { migraphx::shape ss{migraphx::shape::half_type};