Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions src/op/builder/include/migraphx/op/builder/kit.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#ifndef MIGRAPHX_GUARD_BUILDER_KIT_HPP
#define MIGRAPHX_GUARD_BUILDER_KIT_HPP

#include <migraphx/config.hpp>
#include <migraphx/op/builder/op_builder.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
namespace builder {

struct register_kit_action
{
template <class T>
static void apply()
{
T{}.apply();
}
};

template <class T>
struct kit : auto_register<register_kit_action, T>
{
void apply() const {}

std::string derived_prefix() const { return static_cast<const T&>(*this).prefix(); }

op_builder_if from_op(const std::string& op_name) const
{
return op_builder_if{[=](module& m,
instruction_ref ins,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args,
const value& options) -> std::vector<instruction_ref> {
auto opd = make_op(op_name, options);
return {m.insert_instruction(ins, opd, args, module_args)};
},
[=] { return make_op(op_name).to_value(); }};
}

op_builder_if from_builder(const std::string& op_builder) const
{
return get_op_builder_if(op_builder);
}

op_builder_if with_common(op_builder_if obi, common_options coptions = {}) const

Check warning on line 48 in src/op/builder/include/migraphx/op/builder/kit.hpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'obi' of type 'op_builder_if' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]
{
return op_builder_if{[=](module& m,
instruction_ref ins,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& module_args,
const value& options) {
auto cargs = insert_common_args(m, ins, args, coptions);
return obi.bld_func(m, ins, cargs, module_args, options);
},
[=] { return obi.to_val_func(); }};
}

void ops(const std::initializer_list<std::string>& op_names) const
{
for(const auto& name : op_names)
{
register_builder(derived_prefix() + name, from_op(name));
}
}

void common_ops(const std::initializer_list<std::string>& op_names,
common_options coptions = {}) const
{
for(const auto& name : op_names)
{
register_builder(derived_prefix() + name, with_common(from_op(name), coptions));
}
}

void builders(const std::initializer_list<std::string>& builder_names) const
{
for(const auto& name : builder_names)
{
register_builder(derived_prefix() + name, from_builder(name));
}
}
};

} // namespace builder
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_BUILDER_KIT_HPP
2 changes: 2 additions & 0 deletions src/op/builder/include/migraphx/op/builder/op_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ struct op_builder_if

MIGRAPHX_EXPORT void register_builder(const std::string& name, op_builder_if opb_if);

const op_builder_if& get_op_builder_if(const std::string& name);

template <class T>
auto invoke_builder(const std::string& /*name*/,
module& m,
Expand Down
7 changes: 7 additions & 0 deletions src/op/builder/op_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ void register_builder(const std::string& name, op_builder_if opb_if)
builder_map()[name] = std::move(opb_if);
}

const op_builder_if& get_op_builder_if(const std::string& name)
{
if(has_op_builder(name))
return builder_map().at(name);
MIGRAPHX_THROW("GET_OP_BUILDER_IF: OpBuilder not found: " + name);
}

value get_op_builder_value(const std::string& name)
{
if(has_op_builder(name))
Expand Down
109 changes: 109 additions & 0 deletions src/op/builder/torch_kit.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@

#include <migraphx/op/builder/kit.hpp>
#include <migraphx/op/common.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
namespace builder {

struct torch_lstm : op_builder<torch_lstm>
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{};
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int input_forget = 0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.input_forget, "input_forget"));
}

static std::vector<std::string> names() { return {"tm::lstm"}; }

std::vector<instruction_ref>
insert(module& m, instruction_ref ins, const std::vector<instruction_ref>& args) const
{
auto self = *this;
if(self.actv_funcs.empty())
{
self.actv_funcs = {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};
if(self.direction == rnn_direction::bidirectional)
{
self.actv_funcs.insert(self.actv_funcs.end(),
{make_op("sigmoid"), make_op("tanh"), make_op("tanh")});
}
}
auto hidden_states =
m.insert_instruction(ins, make_op("lstm", migraphx::to_value(self)), args);
auto last_hs = m.insert_instruction(ins, make_op("rnn_last_hs_output"), hidden_states);
auto last_cell = m.insert_instruction(ins, make_op("rnn_last_cell_output"), hidden_states);
return {hidden_states, last_hs, last_cell};
}
};

struct torch_kit : kit<torch_kit>
{
std::string prefix() const { return "tm::"; }
void apply() const
{
this->common_ops({
"ceil", "convert", "cos", "cosh", "div", "dot", "elu", "equal",
"erf", "exp", "floor", "fmod", "greater", "isinf", "isnan", "leaky_relu",
"less", "log", "log2", "logical_and", "max", "min", "mul", "mul",
"neg", "not", "pow", "recip", "relu", "rsqrt", "sigmoid", "sign",
"sin", "sinh", "sqrt", "sub", "tan", "tanh",
});
this->common_ops({"where"}, {.common_type = false});

this->ops({
"argmax",
"argmin",
"broadcast",
"concat",
"contiguous",
"contiguous",
"convolution",
"deconvolution",
"dequantizelinear",
"gather",
"gathernd",
"get_tuple_elem",
"multibroadcast",
"multibroadcast",
"pad",
"pooling",
"prefix_scan_sum",
"quantizelinear",
"reduce_all",
"reduce_any",
"reduce_max",
"reduce_mean",
"reduce_min",
"reduce_prod",
"reduce_sum",
"reshape",
"scatter_none",
"scatter_none",
"slice",
"softmax",
"squeeze",
"step",
"topk",
"transpose",
"undefined",
"unsqueeze",
});
}
};

} // namespace builder
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Loading