diff --git a/samples/AttentionFMHA.py b/samples/AttentionFMHA.py
index 48ec984..1ba77de 100644
--- a/samples/AttentionFMHA.py
+++ b/samples/AttentionFMHA.py
@@ -31,7 +31,8 @@ def fmha_kernel(Q, K, V, Out,
                 qk_scale: float,
                 input_pos: int,
-                TILE_D: ConstInt,  # TILE_D = hidden_size
+                Dqk: ConstInt,  # Head dimension of Q and K
+                Dv: ConstInt,  # Head dimension of V
                 H: ConstInt,
                 TILE_M: ConstInt,
                 TILE_N: ConstInt,
@@ -64,12 +65,12 @@ def fmha_kernel(Q, K, V, Out,
     # Initialize online softmax accumulators in float32 for stability
     m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
     l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
-    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
+    acc = ct.full((TILE_M, Dv), 0.0, dtype=np.float32)
 
     # Load query tile for this batch, head, and M-chunk
     q = ct.load(
-        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
-    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]
+        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, Dqk)
+    ).reshape((TILE_M, Dqk))  # [TILE_M, Dqk]
 
     # loop over k, v and update accumulator
     m_end = input_pos + (bid_x + 1) * TILE_M
@@ -88,11 +89,11 @@ def fmha_kernel(Q, K, V, Out,
     for j in range(0, Tc):
         # --- Compute QK product ---
         k = ct.load(
-            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, Dqk, TILE_N),
             order=(0, 1, 3, 2), latency=2,
         )
-        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
+        k = k.reshape((Dqk, TILE_N))  # [Dqk, TILE_N]
 
         qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
         qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]
 
@@ -125,16 +126,16 @@ def fmha_kernel(Q, K, V, Out,
 
         # --- Compute PV product ---
         v = ct.load(
-            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, Dv),
             latency=4,
-        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
+        ).reshape((TILE_N, Dv))  # [TILE_N, Dv]
         p = p.astype(Q.dtype)
         acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
         m_i = m_ij  # [TILE_M, 1]
 
     # --- Final Normalization and Store ---
     acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
-    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
+    acc = acc.reshape((1, 1, TILE_M, Dv)).astype(Out.dtype)
     ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@@ -202,6 +203,7 @@ def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                 qk_scale,
                 input_pos,
                 D_k,
+                D_v,
                 Heads,
                 tile_m,
                 tile_n,
@@ -273,12 +275,18 @@ def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
 
 
 def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
-               is_causal: bool, enable_gqa: bool) -> torch.Tensor:
-    backend = SDPBackend.CUDNN_ATTENTION \
-        if (Q.shape[2] == K.shape[2]) \
-        else SDPBackend.FLASH_ATTENTION
-    with sdpa_kernel(backend):
-        ret = scaled_dot_product_attention(Q, K, V,
+               is_causal: bool, enable_gqa: bool,
+               use_backend_selection_rule: bool = False) -> torch.Tensor:
+    if use_backend_selection_rule:
+        backend = SDPBackend.CUDNN_ATTENTION \
+            if (Q.shape[2] == K.shape[2]) \
+            else SDPBackend.FLASH_ATTENTION
+        with sdpa_kernel(backend):
+            ret = scaled_dot_product_attention(Q, K, V,
+                                               is_causal=is_causal,
+                                               enable_gqa=enable_gqa)
+    else:
+        ret = scaled_dot_product_attention(Q, K, V,
                                            is_causal=is_causal,
                                            enable_gqa=enable_gqa)
     return ret
@@ -296,13 +304,14 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
 
     # --- User Configuration ---
     BATCH_SIZE = 2
-    NUM_HEADS = 8
+    NUM_HEADS = 32
     SEQ_LEN_Q = 128
-    SEQ_LEN_KV = 128
-    D_K = 64
+    SEQ_LEN_KV = 256
+    D_K = 128
     D_V = 64
-    QUERY_GROUP_SIZE = 1
+    QUERY_GROUP_SIZE = 8
+    enable_gqa = QUERY_GROUP_SIZE > 1
 
     DTYPE = torch.float16
 
@@ -336,7 +345,7 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
               dtype:{output_fmha_cutile_non_causal.dtype}""")
     if args.correctness_check:
         ref_fmha = torch_fmha(Q_input, K_input, V_input,
-                              is_causal=False, enable_gqa=False)
+                              is_causal=False, enable_gqa=enable_gqa)
         torch.testing.assert_close(output_fmha_cutile_non_causal, ref_fmha, atol=1e-3, rtol=1e-3)
         print("Correctness check passed")
     else:
@@ -354,7 +363,7 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
              dtype: {output_fmha_cutile_causal.dtype}""")
     if args.correctness_check:
         ref_fmha = torch_fmha(Q_input, K_input, V_input,
-                              is_causal=True, enable_gqa=False)
+                              is_causal=True, enable_gqa=enable_gqa)
         torch.testing.assert_close(output_fmha_cutile_causal, ref_fmha, atol=1e-3, rtol=1e-3)
         print("Correctness check passed")
     else:
@@ -394,7 +403,7 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
              dtype: {output_fmha_cutile_autotune_causal.dtype}""")
     print(f"Tuned config: {tuned_config}")
     if args.correctness_check:
-        ref_fmha = torch_fmha(Q_input, K_input, V_input, is_causal=True, enable_gqa=False)
+        ref_fmha = torch_fmha(Q_input, K_input, V_input, is_causal=True, enable_gqa=enable_gqa)
         torch.testing.assert_close(
             output_fmha_cutile_autotune_causal, ref_fmha, atol=1e-2, rtol=5e-2
         )
diff --git a/samples/templates/AttentionFMHA.py b/samples/templates/AttentionFMHA.py
index 3f67ce2..18930e6 100644
--- a/samples/templates/AttentionFMHA.py
+++ b/samples/templates/AttentionFMHA.py
@@ -83,6 +83,7 @@ def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                 qk_scale,
                 input_pos,
                 D_k,
+                D_v,
                 Heads,
                 tile_m,
                 tile_n,
diff --git a/test/bench_attention.py b/test/bench_attention.py
index 9350be1..cea4cf8 100644
--- a/test/bench_attention.py
+++ b/test/bench_attention.py
@@ -74,10 +74,11 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark):
         rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
     )
 
-    B, H, L, D = q.shape
-    # first gemm mma(q, k): 2 * B * H * L * L * D
-    # second gemm mma(p, v): 2 * B * H * L * L * D
-    flop_count = 4 * B * H * L * L * D
+    B, H, L, Dqk = q.shape
+    _, _, _, Dv = v.shape
+    # first gemm mma(q, k): 2 * B * H * L * L * Dqk
+    # second gemm mma(p, v): 2 * B * H * L * L * Dv
+    flop_count = 2 * B * H * L * L * (Dqk + Dv)
     if is_causal:
         flop_count /= 2
 
@@ -88,9 +89,10 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark):
 
 
 def cutile_fmha(q, k, v, o, is_causal, enable_gqa):
-    b, qh, q_len, d = q.shape
+    b, qh, q_len, dqk = q.shape
     _, kh, k_len, _ = k.shape
-    qk_scale = 1 / sqrt(d)
+    _, _, _, dv = v.shape
+    qk_scale = 1 / sqrt(dqk)
     TILE_M, TILE_N = (256, 128) if is_causal else (64, 128)
     query_group_size = qh // kh
     grid = (ceil(q_len / TILE_M), b * qh, 1)
@@ -100,7 +102,7 @@ def cutile_fmha(q, k, v, o, is_causal, enable_gqa):
                (q, k, v, o,
                 qk_scale, input_pos,
-                d, qh,
+                dqk, dv, qh,
                 TILE_M, TILE_N,
                 query_group_size,
                 is_causal, EVEN_K))
diff --git a/test/kernels/attention.py b/test/kernels/attention.py
index fd431b8..5d8cd41 100644
--- a/test/kernels/attention.py
+++ b/test/kernels/attention.py
@@ -21,7 +21,8 @@ def fmha_kernel(Q, K, V, Out,
                 qk_scale: float,
                 input_pos: int,
-                TILE_D: ConstInt,  # TILE_D = hidden_size
+                Dqk: ConstInt,  # Head dimension of Q and K
+                Dv: ConstInt,  # Head dimension of V
                 H: ConstInt,
                 TILE_M: ConstInt,
                 TILE_N: ConstInt,
@@ -54,12 +55,12 @@ def fmha_kernel(Q, K, V, Out,
     # Initialize online softmax accumulators in float32 for stability
     m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
     l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
-    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
+    acc = ct.full((TILE_M, Dv), 0.0, dtype=np.float32)
 
     # Load query tile for this batch, head, and M-chunk
     q = ct.load(
-        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
-    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]
+        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, Dqk)
+    ).reshape((TILE_M, Dqk))  # [TILE_M, Dqk]
 
     # loop over k, v and update accumulator
     m_end = input_pos + (bid_x + 1) * TILE_M
@@ -78,11 +79,11 @@ def fmha_kernel(Q, K, V, Out,
     for j in range(0, Tc):
         # --- Compute QK product ---
         k = ct.load(
-            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, Dqk, TILE_N),
             order=(0, 1, 3, 2), latency=2,
         )
-        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
+        k = k.reshape((Dqk, TILE_N))  # [Dqk, TILE_N]
 
         qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
         qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]
 
@@ -115,14 +116,14 @@ def fmha_kernel(Q, K, V, Out,
 
         # --- Compute PV product ---
         v = ct.load(
-            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, Dv),
             latency=4,
-        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
+        ).reshape((TILE_N, Dv))  # [TILE_N, Dv]
         p = p.astype(Q.dtype)
         acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
         m_i = m_ij  # [TILE_M, 1]
 
     # --- Final Normalization and Store ---
     acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
-    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
+    acc = acc.reshape((1, 1, TILE_M, Dv)).astype(Out.dtype)
     ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
diff --git a/test/test_attention.py b/test/test_attention.py
index 2d2ad91..ed02cc9 100644
--- a/test/test_attention.py
+++ b/test/test_attention.py
@@ -15,21 +15,21 @@
 @pytest.mark.parametrize("k_heads", [8])
 @pytest.mark.parametrize("q_len", [1, 15, 32])
 @pytest.mark.parametrize("k_len", [32, 63])
-@pytest.mark.parametrize("hidden_size", [32])
+@pytest.mark.parametrize("head_dim", [32])
 @pytest.mark.parametrize("tile_size", [(8, 16)])
 @pytest.mark.parametrize("is_causal", [True, False])
 @pytest.mark.parametrize("use_input_pos", [True, False])
 def test_flash_attention(batch_size, q_heads, k_heads,
                          q_len, k_len,
-                         hidden_size, tile_size, is_causal,
+                         head_dim, tile_size, is_causal,
                          use_input_pos, float_dtype):
     query_group_size = q_heads // k_heads
     TILE_M, TILE_N = tile_size
-    qk_scale = 1 / math.sqrt(hidden_size)
-    q = torch.randn((batch_size, q_heads, q_len, hidden_size), dtype=float_dtype, device='cuda')
-    k = torch.randn((batch_size, k_heads, k_len, hidden_size), dtype=float_dtype, device='cuda')
-    v = torch.randn((batch_size, k_heads, k_len, hidden_size), dtype=float_dtype, device='cuda')
+    qk_scale = 1 / math.sqrt(head_dim)
+    q = torch.randn((batch_size, q_heads, q_len, head_dim), dtype=float_dtype, device='cuda')
+    k = torch.randn((batch_size, k_heads, k_len, head_dim), dtype=float_dtype, device='cuda')
+    v = torch.randn((batch_size, k_heads, k_len, head_dim), dtype=float_dtype, device='cuda')
     o = torch.zeros_like(q)
     grid = (math.ceil(q_len / TILE_M), batch_size * q_heads, 1)
     if use_input_pos:
@@ -43,7 +43,7 @@ def test_flash_attention(batch_size, q_heads, k_heads,
                (q, k, v, o,
                 qk_scale, input_pos,
-                hidden_size, q_heads,
+                head_dim, head_dim, q_heads,
                 TILE_M, TILE_N,
                 query_group_size,
                 is_causal, EVEN_K))
     if is_causal:
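
Note: the snippet below is not part of the patch; it is a minimal, self-contained sketch of the shape contract these changes target: separate Q/K and V head dimensions (Dqk != Dv) combined with grouped KV heads. It checks the expected output shape against PyTorch's scaled_dot_product_attention; the sizes mirror the sample's new defaults, the variable names are illustrative, and it assumes PyTorch >= 2.5 (for the enable_gqa argument).

import torch
from torch.nn.functional import scaled_dot_product_attention

# Sizes mirror the updated sample config: 32 query heads, 4 KV heads
# (QUERY_GROUP_SIZE = 32 // 4 = 8), Dqk = 128 for Q/K, Dv = 64 for V.
B, HQ, HKV = 2, 32, 4
LQ, LKV, DQK, DV = 128, 256, 128, 64

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

q = torch.randn(B, HQ, LQ, DQK, device=device, dtype=dtype)
k = torch.randn(B, HKV, LKV, DQK, device=device, dtype=dtype)
v = torch.randn(B, HKV, LKV, DV, device=device, dtype=dtype)

# Each KV head is shared by a group of 8 query heads; the output head dim follows V.
out = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
assert out.shape == (B, HQ, LQ, DV)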