diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index b0cc4d7cc05..a7c7e7a8943 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -136,6 +136,61 @@ inline auto pointwise_inputs() }; } +// find attention blocks that have been quantized and undo them +struct find_quant_attention +{ + auto matcher() const + { + auto gemm1 = + match::name("dequantizelinear")(match::arg(0)(match::name("quant_dot").bind("qgemm1"))) + .bind("deq1"); + auto softmax = match::softmax_input(match::skip(match::name("convert"))(gemm1)); + auto probs = match::name("quantizelinear")( + match::arg(0)(match::skip(match::name("convert"))(softmax))); + return match::name("quant_dot")(match::arg(0)(probs)).bind("qgemm2"); + } + + // removes the q/dq pairs from attention block gemms + static bool dequantize_gemm(module& m, instruction_ref qgemm, instruction_ref deq) + { + auto qa = qgemm->inputs().at(0); + auto qb = qgemm->inputs().at(1); + if(qa->name() != "quantizelinear" or qb->name() != "quantizelinear") + return false; + auto a = qa->inputs().front(); + auto b = qb->inputs().front(); + auto compute_type = b->get_shape().type(); + if(a->get_shape().type() != compute_type) + a = m.insert_instruction(deq, make_op("convert", {{"target_type", compute_type}}), a); + instruction_ref dot = m.insert_instruction(deq, make_op("dot"), a, b); + if(compute_type != deq->get_shape().type()) + dot = m.insert_instruction( + deq, make_op("convert", {{"target_type", deq->get_shape().type()}}), dot); + m.replace_instruction(deq, dot); + return true; + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto& m = mpm.get_module(); + auto qgemm1 = r.instructions["qgemm1"]; + auto deq1 = r.instructions["deq1"]; + auto qgemm2 = r.result; + + // gemm2's dequantizelinear is its consumer; locate it before rewriting. + auto qgemm2_outs = qgemm2->outputs(); + auto deq2 = std::find_if(qgemm2_outs.begin(), qgemm2_outs.end(), [](auto o) { + return o->name() == "dequantizelinear"; + }); + if(deq2 == qgemm2_outs.end()) + return; + + if(not dequantize_gemm(m, qgemm1, deq1)) + return; + dequantize_gemm(m, qgemm2, *deq2); + } +}; + struct find_attention { std::size_t* counter; @@ -977,6 +1032,11 @@ void fuse_attention::apply(module_pass_manager& mpm) const // Only fuse plain attention when requested if(attn_enabled) { + // remove quantization from attention blocks so they can be fused; rocMLIR currently does + // not support fp8 attention + match::find_matches(mpm, find_quant_attention{}); + mpm.run_pass(dead_code_elimination{}); + match::find_matches(mpm, find_attention{.counter = &counter}); mpm.get_module().sort(); mpm.run_pass(dead_code_elimination{}); diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 64cb4ca2d06..d462a64e548 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -2229,6 +2229,72 @@ TEST_CASE(ceil_mul_of_function) EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31) } +// Attention block with 8-bit quantization should have q/dq pairs removed to allow for fusion +TEST_CASE(fp8_quant_gemm_softmax_gemm) +{ + migraphx::shape s{migraphx::shape::half_type, {1, 12, 256, 256}}; + + auto quantize = [](migraphx::module* mm, migraphx::instruction_ref x) { + auto scale = mm->add_literal( + migraphx::literal{migraphx::shape{x->get_shape().type(), {1}}, {0.05f}}); + scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), scale); + auto zp = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}}, {0}}); + zp = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), zp); + return mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); + }; + auto dequantize = [](migraphx::module* mm, + migraphx::instruction_ref x, + migraphx::shape::type_t out_type) { + auto scale = mm->add_literal(migraphx::literal{migraphx::shape{out_type, {1}}, {0.0025f}}); + scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), scale); + return mm->add_instruction(migraphx::make_op("dequantizelinear"), x, scale); + }; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto q = mm->add_parameter("q", s); + auto k = mm->add_parameter("k", s); + auto v = mm->add_parameter("v", s); + auto kt = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k); + + auto gemm1 = + mm->add_instruction(migraphx::make_op("quant_dot"), quantize(mm, q), quantize(mm, kt)); + auto deq1 = dequantize(mm, gemm1, migraphx::shape::float_type); + + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), deq1); + rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), + rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), deq1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), + rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + + auto gemm2 = + mm->add_instruction(migraphx::make_op("quant_dot"), quantize(mm, div), quantize(mm, v)); + auto deq2 = dequantize(mm, gemm2, migraphx::shape::half_type); + mm->add_return({deq2}); + } + run_pass(p1, {.attn_enabled = true}); + + auto* mm = p1.get_main_module(); + // The attention gemms must be de-quantized (no quant_dot left) ... + EXPECT(std::none_of( + mm->begin(), mm->end(), [](const auto& ins) { return ins.name() == "quant_dot"; })); + // ... and fused into an attention group. + EXPECT(std::any_of(mm->begin(), mm->end(), [](const auto& ins) { + return ins.name() == "group" and + ins.get_operator().to_value()["tag"].template to() == "attention"; + })); +} + int main(int argc, const char* argv[]) { test::run(argc, argv);