diff --git a/src/op/builder/include/migraphx/op/builder/kit.hpp b/src/op/builder/include/migraphx/op/builder/kit.hpp new file mode 100644 index 00000000000..ad6c65a20d6 --- /dev/null +++ b/src/op/builder/include/migraphx/op/builder/kit.hpp @@ -0,0 +1,91 @@ +#ifndef MIGRAPHX_GUARD_BUILDER_KIT_HPP +#define MIGRAPHX_GUARD_BUILDER_KIT_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { +namespace builder { + +struct register_kit_action +{ + template + static void apply() + { + T{}.apply(); + } +}; + +template +struct kit : auto_register +{ + void apply() const {} + + std::string derived_prefix() const { return static_cast(*this).prefix(); } + + op_builder_if from_op(const std::string& op_name) const + { + return op_builder_if{[=](module& m, + instruction_ref ins, + const std::vector& args, + const std::vector& module_args, + const value& options) -> std::vector { + auto opd = make_op(op_name, options); + return {m.insert_instruction(ins, opd, args, module_args)}; + }, + [=] { return make_op(op_name).to_value(); }}; + } + + op_builder_if from_builder(const std::string& op_builder) const + { + return get_op_builder_if(op_builder); + } + + op_builder_if with_common(op_builder_if obi, common_options coptions = {}) const + { + return op_builder_if{[=](module& m, + instruction_ref ins, + const std::vector& args, + const std::vector& module_args, + const value& options) { + auto cargs = insert_common_args(m, ins, args, coptions); + return obi.bld_func(m, ins, cargs, module_args, options); + }, + [=] { return obi.to_val_func(); }}; + } + + void ops(const std::initializer_list& op_names) const + { + for(const auto& name : op_names) + { + register_builder(derived_prefix() + name, from_op(name)); + } + } + + void common_ops(const std::initializer_list& op_names, + common_options coptions = {}) const + { + for(const auto& name : op_names) + { + register_builder(derived_prefix() + name, with_common(from_op(name), coptions)); + } + } + + void builders(const std::initializer_list& builder_names) const + { + for(const auto& name : builder_names) + { + register_builder(derived_prefix() + name, from_builder(name)); + } + } +}; + +} // namespace builder +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_BUILDER_KIT_HPP diff --git a/src/op/builder/include/migraphx/op/builder/op_builder.hpp b/src/op/builder/include/migraphx/op/builder/op_builder.hpp index 4fdedeefdf5..f6202241501 100644 --- a/src/op/builder/include/migraphx/op/builder/op_builder.hpp +++ b/src/op/builder/include/migraphx/op/builder/op_builder.hpp @@ -53,6 +53,8 @@ struct op_builder_if MIGRAPHX_EXPORT void register_builder(const std::string& name, op_builder_if opb_if); +const op_builder_if& get_op_builder_if(const std::string& name); + template auto invoke_builder(const std::string& /*name*/, module& m, diff --git a/src/op/builder/op_builder.cpp b/src/op/builder/op_builder.cpp index d199be07f9a..ed15b1ccc87 100644 --- a/src/op/builder/op_builder.cpp +++ b/src/op/builder/op_builder.cpp @@ -48,6 +48,13 @@ void register_builder(const std::string& name, op_builder_if opb_if) builder_map()[name] = std::move(opb_if); } +const op_builder_if& get_op_builder_if(const std::string& name) +{ + if(has_op_builder(name)) + return builder_map().at(name); + MIGRAPHX_THROW("GET_OP_BUILDER_IF: OpBuilder not found: " + name); +} + value get_op_builder_value(const std::string& name) { if(has_op_builder(name)) diff --git a/src/op/builder/torch_kit.cpp b/src/op/builder/torch_kit.cpp new file mode 100644 index 00000000000..149d035fcc3 --- /dev/null +++ b/src/op/builder/torch_kit.cpp @@ -0,0 +1,109 @@ + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { +namespace builder { + +struct torch_lstm : op_builder +{ + std::size_t hidden_size = 1; + std::vector actv_funcs{}; + rnn_direction direction = rnn_direction::forward; + float clip = 0.0f; + int input_forget = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.hidden_size, "hidden_size"), + f(self.actv_funcs, "actv_func"), + f(self.direction, "direction"), + f(self.clip, "clip"), + f(self.input_forget, "input_forget")); + } + + static std::vector names() { return {"tm::lstm"}; } + + std::vector + insert(module& m, instruction_ref ins, const std::vector& args) const + { + auto self = *this; + if(self.actv_funcs.empty()) + { + self.actv_funcs = {make_op("sigmoid"), make_op("tanh"), make_op("tanh")}; + if(self.direction == rnn_direction::bidirectional) + { + self.actv_funcs.insert(self.actv_funcs.end(), + {make_op("sigmoid"), make_op("tanh"), make_op("tanh")}); + } + } + auto hidden_states = + m.insert_instruction(ins, make_op("lstm", migraphx::to_value(self)), args); + auto last_hs = m.insert_instruction(ins, make_op("rnn_last_hs_output"), hidden_states); + auto last_cell = m.insert_instruction(ins, make_op("rnn_last_cell_output"), hidden_states); + return {hidden_states, last_hs, last_cell}; + } +}; + +struct torch_kit : kit +{ + std::string prefix() const { return "tm::"; } + void apply() const + { + this->common_ops({ + "ceil", "convert", "cos", "cosh", "div", "dot", "elu", "equal", + "erf", "exp", "floor", "fmod", "greater", "isinf", "isnan", "leaky_relu", + "less", "log", "log2", "logical_and", "max", "min", "mul", "mul", + "neg", "not", "pow", "recip", "relu", "rsqrt", "sigmoid", "sign", + "sin", "sinh", "sqrt", "sub", "tan", "tanh", + }); + this->common_ops({"where"}, {.common_type = false}); + + this->ops({ + "argmax", + "argmin", + "broadcast", + "concat", + "contiguous", + "contiguous", + "convolution", + "deconvolution", + "dequantizelinear", + "gather", + "gathernd", + "get_tuple_elem", + "multibroadcast", + "multibroadcast", + "pad", + "pooling", + "prefix_scan_sum", + "quantizelinear", + "reduce_all", + "reduce_any", + "reduce_max", + "reduce_mean", + "reduce_min", + "reduce_prod", + "reduce_sum", + "reshape", + "scatter_none", + "scatter_none", + "slice", + "softmax", + "squeeze", + "step", + "topk", + "transpose", + "undefined", + "unsqueeze", + }); + } +}; + +} // namespace builder +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx