Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,61 @@ inline auto pointwise_inputs()
};
}

// Find attention blocks whose gemms have been quantized (quant_dot surrounded by
// quantizelinear/dequantizelinear) and rewrite both gemms back to plain `dot` instructions.
struct find_quant_attention
{
    // Matches the second quant_dot ("qgemm2") of the pattern
    //   quant_dot("qgemm1") -> dequantizelinear("deq1") -> [convert] -> softmax
    //     -> [convert] -> quantizelinear -> argument 0 of quant_dot("qgemm2")
    auto matcher() const
    {
        auto gemm1 =
            match::name("dequantizelinear")(match::arg(0)(match::name("quant_dot").bind("qgemm1")))
                .bind("deq1");
        auto softmax = match::softmax_input(match::skip(match::name("convert"))(gemm1));
        auto probs   = match::name("quantizelinear")(
            match::arg(0)(match::skip(match::name("convert"))(softmax)));
        return match::name("quant_dot")(match::arg(0)(probs)).bind("qgemm2");
    }

    // True when both operands of `qgemm` are produced by quantizelinear instructions, i.e. the
    // gemm can be rewritten by dequantize_gemm().
    static bool has_quantized_inputs(instruction_ref qgemm)
    {
        return qgemm->inputs().at(0)->name() == "quantizelinear" and
               qgemm->inputs().at(1)->name() == "quantizelinear";
    }

    // Removes the q/dq pairs from an attention block gemm.
    //
    // Replaces `deq` (the dequantizelinear consuming `qgemm`) with a `dot` of the unquantized
    // inputs of the two quantizelinear operands of `qgemm`. The quantization scales are not
    // read; the dot is computed directly on the original values. Converts are inserted so both
    // dot operands share the second operand's type and the result keeps the type of `deq`.
    //
    // Returns false, leaving the module untouched, when either operand of `qgemm` is not a
    // quantizelinear.
    static bool dequantize_gemm(module& m, instruction_ref qgemm, instruction_ref deq)
    {
        if(not has_quantized_inputs(qgemm))
            return false;
        auto a            = qgemm->inputs().at(0)->inputs().front();
        auto b            = qgemm->inputs().at(1)->inputs().front();
        auto compute_type = b->get_shape().type();
        if(a->get_shape().type() != compute_type)
            a = m.insert_instruction(deq, make_op("convert", {{"target_type", compute_type}}), a);
        instruction_ref dot = m.insert_instruction(deq, make_op("dot"), a, b);
        if(compute_type != deq->get_shape().type())
            dot = m.insert_instruction(
                deq, make_op("convert", {{"target_type", deq->get_shape().type()}}), dot);
        m.replace_instruction(deq, dot);
        return true;
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto& m     = mpm.get_module();
        auto qgemm1 = r.instructions["qgemm1"];
        auto deq1   = r.instructions["deq1"];
        auto qgemm2 = r.result;

        // Check both gemms before mutating anything. Rewriting gemm1 and then failing on gemm2
        // would leave a half-dequantized block that still contains a quant_dot.
        if(not has_quantized_inputs(qgemm1) or not has_quantized_inputs(qgemm2))
            return;

        // gemm2's dequantizelinear is its consumer; locate it before rewriting. The outputs are
        // copied so the iterator stays valid while the module is modified.
        const auto qgemm2_outs = qgemm2->outputs();
        auto deq2 = std::find_if(qgemm2_outs.begin(), qgemm2_outs.end(), [](const auto& o) {
            return o->name() == "dequantizelinear";
        });
        if(deq2 == qgemm2_outs.end())
            return;

        if(not dequantize_gemm(m, qgemm1, deq1))
            return;
        dequantize_gemm(m, qgemm2, *deq2);
    }
};

struct find_attention
{
std::size_t* counter;
Expand Down Expand Up @@ -977,6 +1032,11 @@ void fuse_attention::apply(module_pass_manager& mpm) const
// Only fuse plain attention when requested
if(attn_enabled)
{
// remove quantization from attention blocks so they can be fused; rocMLIR currently does
// not support fp8 attention
match::find_matches(mpm, find_quant_attention{});
mpm.run_pass(dead_code_elimination{});

match::find_matches(mpm, find_attention{.counter = &counter});
mpm.get_module().sort();
mpm.run_pass(dead_code_elimination{});
Expand Down
66 changes: 66 additions & 0 deletions test/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2229,6 +2229,72 @@ TEST_CASE(ceil_mul_of_function)
EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31)
}

// Attention block with 8-bit quantization should have q/dq pairs removed to allow for fusion
TEST_CASE(fp8_quant_gemm_softmax_gemm)
{
    migraphx::shape s{migraphx::shape::half_type, {1, 12, 256, 256}};

    // Adds quantizelinear(x) with a broadcast scalar scale (same type as x) and a broadcast
    // fp8e4m3fn zero point of 0.
    auto quantize = [](migraphx::module* mm, migraphx::instruction_ref x) {
        auto scale = mm->add_literal(
            migraphx::literal{migraphx::shape{x->get_shape().type(), {1}}, {0.05f}});
        scale = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), scale);
        auto zp = mm->add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}}, {0}});
        zp = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), zp);
        return mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
    };
    // Adds dequantizelinear(x) with a broadcast scalar scale of type `out_type`.
    auto dequantize = [](migraphx::module* mm,
                         migraphx::instruction_ref x,
                         migraphx::shape::type_t out_type) {
        auto scale = mm->add_literal(migraphx::literal{migraphx::shape{out_type, {1}}, {0.0025f}});
        scale      = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), scale);
        return mm->add_instruction(migraphx::make_op("dequantizelinear"), x, scale);
    };

    migraphx::program p1;
    {
        auto* mm = p1.get_main_module();
        auto q   = mm->add_parameter("q", s);
        auto k   = mm->add_parameter("k", s);
        auto v   = mm->add_parameter("v", s);
        auto kt =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k);

        // The quantize() calls insert instructions, so they are sequenced as separate
        // statements. As sibling function arguments their evaluation order is unspecified and
        // the instruction order of the graph would depend on the compiler.
        auto quant_q  = quantize(mm, q);
        auto quant_kt = quantize(mm, kt);
        auto gemm1    = mm->add_instruction(migraphx::make_op("quant_dot"), quant_q, quant_kt);
        auto deq1     = dequantize(mm, gemm1, migraphx::shape::float_type);

        // softmax over the last axis, spelled out as reduce_max/sub/exp/reduce_sum/div
        auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), deq1);
        rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
                                   rmax);
        auto sub  = mm->add_instruction(migraphx::make_op("sub"), deq1, rmax);
        auto exp  = mm->add_instruction(migraphx::make_op("exp"), sub);
        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp);
        rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
                                   rsum);
        auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum);

        auto quant_probs = quantize(mm, div);
        auto quant_v     = quantize(mm, v);
        auto gemm2 = mm->add_instruction(migraphx::make_op("quant_dot"), quant_probs, quant_v);
        auto deq2  = dequantize(mm, gemm2, migraphx::shape::half_type);
        mm->add_return({deq2});
    }
    run_pass(p1, {.attn_enabled = true});

    auto* mm = p1.get_main_module();
    // The attention gemms must be de-quantized (no quant_dot left) ...
    EXPECT(std::none_of(
        mm->begin(), mm->end(), [](const auto& ins) { return ins.name() == "quant_dot"; }));
    // ... and fused into an attention group.
    EXPECT(std::any_of(mm->begin(), mm->end(), [](const auto& ins) {
        return ins.name() == "group" and
               ins.get_operator().to_value()["tag"].template to<std::string>() == "attention";
    }));
}

int main(int argc, const char* argv[])
{
test::run(argc, argv);
Expand Down
Loading