Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 30 additions & 56 deletions src/include/migraphx/op/concat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,67 +81,41 @@ struct concat
// be at least 1.
check_shapes{inputs, *this, true}.same_ndims().same_type();

if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
bool all_static =
std::none_of(inputs.begin(), inputs.end(), [](const shape& s) { return s.dynamic(); });
auto unified = shape::to_dynamic(inputs);

const auto& dds0 = unified.front().dyn_dims();
for(std::size_t i = 0; i < dds0.size(); ++i)
{
// Static input shapes
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++)
{
if(ll != axis)
{
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[ll] == first_shape_lens[ll];
}))
{
MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " +
std::to_string(ll));
}
}
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens = first_shape_lens;
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
if(i == axis)
continue;
if(not std::all_of(unified.begin(), unified.end(), [&](const shape& s) {
return s.dyn_dims()[i] == dds0[i];
}))
MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
std::to_string(i));
}
else if(std::all_of(
inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
{
// Dynamic input shapes
for(std::size_t index = 0; index < inputs[0].ndim(); index++)
{
if(index != axis)
{
if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
}))
MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
std::to_string(index));
}
}
std::size_t new_min = 0;
std::size_t new_max = 0;
for(const auto& input : inputs)
{
auto ddim = input.dyn_dims()[axis];
auto dim_interval = ddim.get_interval();
new_min += dim_interval.min;
new_max += dim_interval.max;
}

auto new_dims = inputs[0].dyn_dims();
new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
return {inputs[0].type(), new_dims};
}
else
auto new_dds = dds0;
new_dds[axis] = std::accumulate(
unified.begin() + 1, unified.end(), dds0[axis], [&](const auto& acc, const shape& s) {
return acc + s.dyn_dims()[axis];
});

auto type = unified.front().type();
if(all_static)
{
MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
std::vector<std::size_t> new_lens(new_dds.size());
std::transform(new_dds.begin(), new_dds.end(), new_lens.begin(), [](const auto& d) {
assert(d.sym_expr.is_literal());
return d.sym_expr.eval_uint({});
});
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
if(unified.front().symbolic())
return shape::from_permutation(type, new_dds, find_permutation(unified));
return {type, new_dds};
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
Expand Down
61 changes: 29 additions & 32 deletions src/include/migraphx/op/slice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,42 +251,39 @@ struct slice
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1, 2, 3, 4);
if(inputs.size() == 1)
if(inputs.size() != 1)
return compute_two_or_more(inputs);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet handling the 2+ slice input versions, right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya, not handling true runtime computed shapes using symbolics yet.


auto input_shape = inputs[0];
auto set_attributes = get_set_attributes();
if(set_attributes != all_set)
MIGRAPHX_THROW("SLICE 1_arg: Invalid 1 input and attributes configuration");

// TODO: support slicing non-fixed symbolic dims (output dim would be
// a sym::expr derived from starts/ends and the symbolic axis bound).
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{
auto input_shape = inputs[0];
auto set_attributes = get_set_attributes();
if(set_attributes != all_set)
{
MIGRAPHX_THROW("SLICE 1_arg: Invalid 1 input and attributes configuration");
}
// NOTE: make sure to update how normalization works here if this type of slicing is
// changed to be allowed
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{
MIGRAPHX_THROW(
"SLICE 1_arg: slicing is not allowed on non-fixed dynamic input axis ");
}
if(input_shape.dynamic())
{
return shape{
input_shape.type(),
lens_calc(input_shape.min_lens(), this->starts, this->ends, this->axes),
lens_calc(input_shape.max_lens(), this->starts, this->ends, this->axes),
{}};
}
else
{
return shape{input_shape.type(),
lens_calc(input_shape.lens(), this->starts, this->ends, this->axes),
input_shape.strides()};
}
MIGRAPHX_THROW("SLICE 1_arg: slicing is not allowed on non-fixed dynamic input axis ");
}
else

auto new_lens = lens_calc(input_shape.max_lens(), this->starts, this->ends, this->axes);

if(not input_shape.dynamic())
return shape{input_shape.type(), new_lens, input_shape.strides()};

auto dds = input_shape.dyn_dims();
for(auto axis : this->axes)
{
return compute_two_or_more(inputs);
dds[axis] = input_shape.symbolic()
? shape::dynamic_dimension{sym::lit(new_lens[axis])}
: shape::dynamic_dimension{new_lens[axis], new_lens[axis]};
}

if(input_shape.symbolic())
return shape{input_shape.type(), dds, input_shape.dyn_strides()};
return shape{input_shape.type(), dds};
}

/**
Expand Down
31 changes: 17 additions & 14 deletions src/include/migraphx/op/step.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>

Expand Down Expand Up @@ -55,10 +56,8 @@ struct step
std::string name() const { return "step"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
const auto& input = inputs.at(0);
auto in_lens = input.lens();
auto t = input.type();

if(axes.size() != steps.size())
{
Expand All @@ -67,27 +66,31 @@ struct step
"}.");
}

if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= in_lens.size(); }))
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= input.ndim(); }))
{
MIGRAPHX_THROW("STEP: axis value is out of range!");
}

auto lens = in_lens;
auto strides = input.strides();
auto unified = shape::to_dynamic({input}).front();
auto dds = unified.dyn_dims();
auto dstrides = unified.symbolic() ? unified.dyn_strides() : std::vector<sym::expr>{};
for(auto i : range(axes.size()))
{
auto axis = axes[i];
auto step = steps[i];
lens[axis] = (in_lens[axis] + step - 1) / step;
strides[axis] *= step;
auto s = static_cast<std::size_t>(steps[i]);
dds[axes[i]] = (dds[axes[i]] + (s - 1)) / s;
if(unified.symbolic())
dstrides[axes[i]] = dstrides[axes[i]] * sym::lit(s);
}

return {t, lens, strides};
if(not input.dynamic())
return shape{input.type(), dds, dstrides}.to_static();
if(unified.symbolic())
return shape{input.type(), dds, dstrides};
return shape{input.type(), dds};
}

argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}

std::vector<std::size_t> output_alias(const std::vector<shape>&) const { return {0}; }
Expand Down
Loading
Loading