From 33bb509ec3112830c3eba34ed19ff0971029615f Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 1 Apr 2026 21:28:01 +0000 Subject: [PATCH 1/3] first experiment --- finegrained-fp8/bench_fused_moe.py | 377 +++++++++ .../torch-ext/finegrained_fp8/__init__.py | 12 + .../torch-ext/finegrained_fp8/atomic.py | 634 ++++++++++++++ .../torch-ext/finegrained_fp8/fused.py | 775 ++++++++++++++++++ .../torch-ext/finegrained_fp8/moe.py | 189 +++++ 5 files changed, 1987 insertions(+) create mode 100644 finegrained-fp8/bench_fused_moe.py create mode 100644 finegrained-fp8/torch-ext/finegrained_fp8/atomic.py create mode 100644 finegrained-fp8/torch-ext/finegrained_fp8/fused.py create mode 100644 finegrained-fp8/torch-ext/finegrained_fp8/moe.py diff --git a/finegrained-fp8/bench_fused_moe.py b/finegrained-fp8/bench_fused_moe.py new file mode 100644 index 00000000..fb307498 --- /dev/null +++ b/finegrained-fp8/bench_fused_moe.py @@ -0,0 +1,377 @@ +"""Benchmark all MoE dispatch methods: correctness matrix + performance sweep + plot. + +Usage: + python bench_fused_moe.py +""" + +import sys + +sys.path.append("torch-ext") + +import matplotlib +import matplotlib.pyplot as plt +import torch +import triton +from rich.console import Console +from rich.table import Table +from finegrained_fp8 import ( + moe_grouped, + moe_batched, + moe_grouped_fused, + moe_batched_fused, + moe_grouped_atomic, + moe_batched_atomic, +) + +matplotlib.use("Agg") + +FP8_DTYPE = torch.float8_e4m3fn +FP8_MAX = torch.finfo(FP8_DTYPE).max + +console = Console() + +METHODS = { + "grouped": moe_grouped, + # "batched": moe_batched, + "grouped_fused": moe_grouped_fused, + # "batched_fused": moe_batched_fused, + "grouped_atomic": moe_grouped_atomic, + # "batched_atomic": moe_batched_atomic, +} + + +def quantize_weights_block(W, block_n=128, block_k=128): + E, N, K = W.shape + Wq = torch.empty_like(W, dtype=FP8_DTYPE) + Bs = torch.empty( + E, N // block_n, K // block_k, dtype=torch.float32, device=W.device + ) + for e in range(E): + for ni in range(N // block_n): + for ki in range(K // block_k): + block = W[ + e, + ni * block_n : (ni + 1) * block_n, + ki * block_k : (ki + 1) * block_k, + ] + amax = block.abs().amax().clamp(min=1e-12) + scale = FP8_MAX / amax + Wq[ + e, + ni * block_n : (ni + 1) * block_n, + ki * block_k : (ki + 1) * block_k, + ] = (block * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) + Bs[e, ni, ki] = 1.0 / scale + return Wq, Bs + + +def diff_emoji(d): + if d == 0: + return "✅", "exact", "green" + elif d < 1: + return "🟢", f"{d:.4f}", "green" + elif d < 100: + return "🟡", f"{d:.1f}", "yellow" + elif d < 1000: + return "🟠", f"{d:.1f}", "dark_orange" + else: + return "🔴", f"{d:.1f}", "red" + + +def main(): + device = "cuda" + block_size = [128, 128] + + # Qwen3-30B-A3B-Instruct dimensions + model_name = "Qwen3-30B-A3B" + E, N_inter, K, top_k = 128, 768, 2048, 8 + + torch.manual_seed(42) + W_gu = torch.randn(E, 2 * N_inter, K, dtype=torch.float32, device=device) + gate_up_proj, gate_up_proj_scale_inv = quantize_weights_block(W_gu) + W_down = torch.randn(E, K, N_inter, dtype=torch.float32, device=device) + down_proj, down_proj_scale_inv = quantize_weights_block(W_down) + + names = list(METHODS.keys()) + + # ═══════════════════════════════════════════════════════════════════════════ + # Correctness + # ═══════════════════════════════════════════════════════════════════════════ + console.rule("[bold]Correctness[/bold]") + + num_tokens = 32 + hidden = torch.randn(num_tokens, K, device=device, dtype=torch.bfloat16) + top_k_idx = torch.randint(0, E, (num_tokens, top_k), device=device) + top_k_wts = torch.randn(num_tokens, top_k, device=device, dtype=torch.bfloat16) + + args = ( + hidden, + top_k_idx, + top_k_wts, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + ) + + # For correctness, fused methods use simulate_unfused=True to match unfused precision + def call_method(name, fn): + if name not in ("grouped", "batched"): + return fn(*args, simulate_unfused=True) + return fn(*args) + + outputs = {} + with torch.no_grad(): + for name, fn in METHODS.items(): + outputs[name] = call_method(name, fn) + + outputs2 = {} + with torch.no_grad(): + for name, fn in METHODS.items(): + outputs2[name] = call_method(name, fn) + + # Self-consistency table + det_table = Table(title="Self-consistency (determinism)") + det_table.add_column("Method", style="cyan") + det_table.add_column("Max diff", justify="right") + det_table.add_column("Status", justify="center") + + for name in names: + d = (outputs[name].float() - outputs2[name].float()).abs().max().item() + if d == 0: + det_table.add_row(name, "0.00", "[green]✅ deterministic[/green]") + else: + det_table.add_row( + name, f"{d:.2f}", f"[yellow]⚠️ non-deterministic ({d:.1f})[/yellow]" + ) + + console.print(det_table) + + # Parity matrix + parity_table = Table(title="Parity matrix (max abs diff)") + parity_table.add_column("", style="cyan") + for name in names: + parity_table.add_column(name[:14], justify="center") + + for n1 in names: + cells = [] + for n2 in names: + d = (outputs[n1].float() - outputs[n2].float()).abs().max().item() + emoji, text, color = diff_emoji(d) + cells.append(f"[{color}]{emoji} {text}[/{color}]") + parity_table.add_row(n1, *cells) + + console.print(parity_table) + + # ═══════════════════════════════════════════════════════════════════════════ + # Benchmark + # ═══════════════════════════════════════════════════════════════════════════ + console.rule( + f"[bold]Benchmark — {model_name} (E={E}, N_inter={N_inter}, K={K}, top_k={top_k})[/bold]" + ) + + def run_sweep(bench_fn_factory, title, mode="eager"): + """Run adaptive token sweep, dropping methods that plateau (<10% TFLOPS gain).""" + table = Table(title=title) + table.add_column("Tokens", justify="right", style="bold") + for name in names: + table.add_column(name[:14], justify="right") + table.add_column("Winner", justify="center", style="bold green") + + all_res = [] + prev_tflops = {} # per-method previous TFLOPS + active = set(names) # methods still progressing + num_tokens = 1 + + while num_tokens <= 131072 and active: # stop when all methods plateau + hidden = torch.randn(num_tokens, K, device=device, dtype=torch.bfloat16) + top_k_idx = torch.randint(0, E, (num_tokens, top_k), device=device) + top_k_wts = torch.randn( + num_tokens, top_k, device=device, dtype=torch.bfloat16 + ) + + args = ( + hidden, + top_k_idx, + top_k_wts, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + ) + + results = {} + for name, fn in METHODS.items(): + + def _bench(fn=fn): + return fn(*args) + + try: + results[name] = bench_fn_factory(_bench) + except Exception: + results[name] = float("inf") + + all_res.append((num_tokens, results)) + + # Compute per-method TFLOPS and check progress + S = num_tokens * top_k + flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + + best_name = min(results, key=results.get) + cells = [] + for name in names: + ms = results[name] + if ms == float("inf"): + cells.append("[red]n/a[/red]") + continue + tflops = flops / (ms * 1e-3) / 1e12 + + # Check if this method is still progressing + if name in prev_tflops and num_tokens >= 1024: + increase = (tflops - prev_tflops[name]) / max( + prev_tflops[name], 1e-12 + ) + if increase < 0.10: + active.discard(name) + + prev_tflops[name] = tflops + + if name not in active: + cells.append(f"[dim]{ms:.3f}[/dim]") + elif name == best_name: + cells.append(f"[bold green]{ms:.3f}[/bold green]") + else: + cells.append(f"{ms:.3f}") + + table.add_row( + str(num_tokens), + *cells, + best_name if results[best_name] != float("inf") else "n/a", + ) + + num_tokens *= 2 + + console.print(table) + return all_res + + all_results = run_sweep( + lambda fn: triton.testing.do_bench(fn, warmup=10, rep=50), + "Latency — Eager (ms)", + ) + + all_results_cg = run_sweep( + lambda fn: triton.testing.do_bench_cudagraph(fn), + "Latency — CUDA Graphs (ms)", + ) + + # ═══════════════════════════════════════════════════════════════════════════ + # Save CSV + # ═══════════════════════════════════════════════════════════════════════════ + import csv + + csv_path = "moe_results.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["tokens", "method", "mode", "time_ms", "tflops"]) + for mode, results_list in [ + ("eager", all_results), + ("cudagraph", all_results_cg), + ]: + for num_tokens, results in results_list: + S = num_tokens * top_k + flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + for name, ms in results.items(): + if ms == float("inf"): + continue + tflops = flops / (ms * 1e-3) / 1e12 + writer.writerow( + [num_tokens, name, mode, f"{ms:.3f}", f"{tflops:.2f}"] + ) + console.print(f"Results saved to [bold]{csv_path}[/bold]") + + # ═══════════════════════════════════════════════════════════════════════════ + # Plot + # ═══════════════════════════════════════════════════════════════════════════ + colors = { + "grouped": "#e74c3c", + "batched": "#c0392b", + "grouped_fused": "#2ecc71", + "batched_fused": "#27ae60", + "grouped_atomic": "#3498db", + "batched_atomic": "#2980b9", + } + linestyles = { + "grouped": "--", + "batched": "--", + "grouped_fused": "-", + "batched_fused": "-", + "grouped_atomic": ":", + "batched_atomic": ":", + } + + def compute_tflops(results_list): + data = {} + for name in names: + xs, ys = [], [] + for num_tokens, results in results_list: + S = num_tokens * top_k + flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + ms = results.get(name, float("inf")) + if ms == float("inf"): + continue + tflops = flops / (ms * 1e-3) / 1e12 + xs.append(num_tokens) + ys.append(tflops) + data[name] = (xs, ys) + return data + + tflops_eager = compute_tflops(all_results) + tflops_cg = compute_tflops(all_results_cg) + + fig, axes = plt.subplots(2, 2, figsize=(18, 14)) + + for col, (yscale, title_suffix) in enumerate( + [("linear", "Linear"), ("log", "Log")] + ): + for row, (tflops_data, mode) in enumerate( + [(tflops_eager, "Eager"), (tflops_cg, "CUDA Graphs")] + ): + ax = axes[row, col] + for name in names: + xs, ys = tflops_data[name] + if xs: + ax.plot( + xs, + ys, + linestyles.get(name, "-"), + marker="o", + markersize=4, + color=colors.get(name, "gray"), + label=name, + linewidth=2, + ) + ax.set_xlabel("Tokens", fontsize=12) + ax.set_ylabel("TFLOPS", fontsize=12) + ax.set_title( + f"{mode} — TFLOPS ({title_suffix})", fontsize=13, fontweight="bold" + ) + ax.set_xscale("log", base=2) + ax.set_yscale(yscale) + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + fig.suptitle( + f"FP8 MoE Expert Dispatch — {model_name} (E={E}, N_inter={N_inter}, K={K}, top_k={top_k})", + fontsize=14, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig("moe_tflops.png", dpi=150) + console.print("\nPlot saved to [bold]moe_tflops.png[/bold]") + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py b/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py index a4171be8..24ffef0a 100644 --- a/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py +++ b/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py @@ -14,6 +14,9 @@ w8a8_block_fp8_matmul, w8a8_tensor_fp8_matmul, ) +from .moe import moe_grouped, moe_batched +from .fused import moe_grouped_fused, moe_batched_fused +from .atomic import moe_grouped_atomic, moe_batched_atomic __all__ = [ "fp8_act_quant", @@ -29,4 +32,13 @@ "w8a8_fp8_matmul_grouped", "w8a8_block_fp8_matmul_grouped", "w8a8_tensor_fp8_matmul_grouped", + # End-to-end MoE (unfused) + "moe_grouped", + "moe_batched", + # End-to-end MoE (fused, deterministic) + "moe_grouped_fused", + "moe_batched_fused", + # End-to-end MoE (fused, atomic) + "moe_grouped_atomic", + "moe_batched_atomic", ] diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py new file mode 100644 index 00000000..36fcbf82 --- /dev/null +++ b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py @@ -0,0 +1,634 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Atomic fused MoE: gate_up + SiLU + down in one pass with atomic split-K. + +The intermediate tensor NEVER hits HBM — gate_up, SiLU, FP8 quantization, +and down projection all happen in registers. Output is accumulated via +atomic_add (split-K across intermediate N-tiles). + +Non-deterministic due to atomic_add ordering. Faster than the two-kernel +deterministic path at small token counts (decode), but degrades at high +token counts due to atomic contention. +""" + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +from .utils import device_context + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], + reset_to_zero=["Out"], +) +@triton.jit +def moe_atomic_kernel( + # Unsorted inputs + A, # (num_tokens, K) raw BF16/FP16 — NOT sorted + Perm, # (S,) int64 — sorted_pos → original flat index + SampleWeights, # (S,) routing weights in sorted order + # Expert weights + W_gu, # (E, 2*N_inter, K) FP8 gate_up weights + W_down, # (E, hidden, N_inter) FP8 down weights + Ws_gu, # gate_up scales + Ws_down, # down scales + # Output + Out, # (S, hidden) — accumulated via atomic_add + # Expert scheduling + Offsets, + TileOffsets, + # Shapes + N_inter, + K, + hidden, + num_top_k, + # Strides — A, W_gu, Ws_gu, W_down, Ws_down, Out + stride_am, + stride_ak, + stride_be_gu, + stride_bk_gu, + stride_bn_gu, + stride_bs_e_gu, + stride_bs_k_gu, + stride_bs_n_gu, + stride_be_down, + stride_bk_down, + stride_bn_down, + stride_bs_e_down, + stride_bs_k_down, + stride_bs_n_down, + stride_om, + stride_oh, + # Constexprs + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_H_TILES: tl.constexpr, + NUM_EXPERTS_BIT_LENGTH: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Single fused MoE kernel: gather + gate_up + SiLU + down in one pass. + + Grid: (M-tiles, N-tiles). Each program: + 1. Gathers A from unsorted hidden_states via Perm + 2. Gate+up GEMM → SiLU → FP8 quant (intermediate in registers) + 3. Loops over H-tiles: down projection → atomic_add to output + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) + if pid_m >= total_tiles: + return + + # Binary search for expert + lo = 0 + hi = NUM_EXPERTS + for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH): + mid = (lo + hi) >> 1 + mid_val = tl.load(TileOffsets + mid) + is_left = mid_val <= pid_m + lo = tl.where(is_left, mid + 1, lo) + hi = tl.where(is_left, hi, mid) + expert_id = lo.to(tl.int64) + + prev_eid = tl.maximum(expert_id - 1, 0) + expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) + expert_end = tl.load(Offsets + expert_id) + M_expert = expert_end - expert_start + + expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid)) + local_tile = pid_m - expert_tile_start + m_off = local_tile * BLOCK_SIZE_M + + offs_am = m_off + tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_am < M_expert + sorted_indices = expert_start + offs_am + + # Gather from unsorted A via Perm + perm_vals = tl.load(Perm + sorted_indices, mask=row_mask, other=0) + original_tokens = perm_vals // num_top_k + + # ── Gate + Up projection ── + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + original_tokens[:, None] * stride_am + offs_k[None, :] * stride_ak + b_base = W_gu + expert_id * stride_be_gu + offs_k[:, None] * stride_bk_gu + b_gate_ptrs = b_base + offs_bn[None, :] * stride_bn_gu + b_up_ptrs = b_base + (N_inter + offs_bn)[None, :] * stride_bn_gu + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e_gu + bs_gate_ptrs = bs_base + pid_n * stride_bs_n_gu + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n_gu + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + b_gate = tl.load(b_gate_ptrs) + b_up = tl.load(b_up_ptrs) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k_gu) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k_gu) + + a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0 + a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv) + + acc_gate += tl.dot(a, b_gate) * a_s[:, None] * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s[:, None] * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptrs += BLOCK_SIZE_K * stride_bk_gu + b_up_ptrs += BLOCK_SIZE_K * stride_bk_gu + + # ── SiLU(gate) * up ── + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + # ── Quantize intermediate to FP8 ── + inter_s = tl.max(tl.abs(intermediate), axis=1) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s[:, None], 1e-12)).to(tl.float8e4nv) + + # ── Down projection + atomic accumulate ── + offs_h = tl.arange(0, BLOCK_SIZE_H) + for h in range(0, NUM_H_TILES): + h_offs = h * BLOCK_SIZE_H + offs_h + w_down_ptrs = ( + W_down + + expert_id * stride_be_down + + offs_bn[:, None] * stride_bk_down + + h_offs[None, :] * stride_bn_down + ) + w_down = tl.load(w_down_ptrs) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e_down + + h * stride_bs_n_down + + pid_n * stride_bs_k_down + ) + + partial = tl.dot(inter_fp8, w_down) * inter_s[:, None] * ws_down + + out_ptrs = ( + Out + sorted_indices[:, None] * stride_om + h_offs[None, :] * stride_oh + ) + tl.atomic_add( + out_ptrs, + partial.to(Out.dtype.element_ty), + mask=row_mask[:, None], + sem="relaxed", + ) + + +# ── Wrapper ────────────────────────────────────────────────────────────────── + + +@triton_op("finegrained_fp8::moe_grouped_atomic", mutates_args=()) +def _moe_grouped_atomic( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Single-kernel fused MoE expert layer: non-deterministic (atomic split-K). + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → ONE kernel (gather + gate_up + SiLU + down, atomic split-K) → routing + unsort + reduce + Non-deterministic: atomic_add across intermediate N-tiles. + """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_experts = gate_up_proj.size(0) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.size(1) + intermediate_dim = down_proj.size(2) + block_n, block_k = block_size + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + S = expert_ids.size(0) + + # Sort by expert for grouped processing + _, perm = expert_ids.sort(stable=True) + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + + # Compute offsets for grouped processing + histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() + tokens_per_expert = torch.histc( + histc_input, bins=num_experts, min=0, max=num_experts - 1 + ) + offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + + # Tile setup + BLOCK_SIZE_M = min( + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 128 + ) + tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32) + max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + num_experts + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + + # fp32 output for atomic accumulation + proj_out = torch.zeros(S, hidden_dim, device=device, dtype=torch.float32) + + grid = (max_M_tiles, num_N_tiles) + with device_context(device): + wrap_triton(moe_atomic_kernel)[grid]( + hidden_states, + perm, + sample_weights_g, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + proj_out, + offsets, + tile_offsets, + intermediate_dim, + hidden_states.shape[1], + hidden_dim, + num_top_k, + hidden_states.stride(0), + hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + proj_out.stride(0), + proj_out.stride(1), + NUM_EXPERTS=num_experts, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_H_TILES=triton.cdiv(hidden_dim, block_n), + NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Apply routing weights + unsort + reduce + proj_out = proj_out.to(hidden_states.dtype) + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(S, device=device) + final_hidden_states = ( + weighted_out[inv_perm].view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + ) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_grouped_atomic( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Single-kernel fused MoE expert layer: non-deterministic (atomic split-K). + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → ONE kernel (gather + gate_up + SiLU + down, atomic split-K) → routing + unsort + reduce + Non-deterministic: atomic_add across intermediate N-tiles. Intermediate stays in registers. + """ + return torch.ops.finegrained_fp8.moe_grouped_atomic( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) + + +# ── Batched atomic: gate_up + SiLU + down per token (atomic split-K) ──────── + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], + reset_to_zero=["Out"], +) +@triton.jit +def moe_batched_atomic_kernel( + A, # (S, K) raw BF16/FP16 activations + W_gu, # (E, 2*N_inter, K) FP8 gate_up weights + W_down, # (E, hidden, N_inter) FP8 down weights + Out, # (S, hidden) output — accumulated via atomic_add + Ws_gu, # gate_up scales + Ws_down, # down scales + ExpertIds, # (S,) expert index per token + SampleWeights, # (S,) routing weights + # Shapes + N_inter, + K, + hidden, + # Strides — A + stride_am, + stride_ak, + # Strides — W_gu, Ws_gu + stride_be_gu, + stride_bk_gu, + stride_bn_gu, + stride_bs_e_gu, + stride_bs_k_gu, + stride_bs_n_gu, + # Strides — W_down, Ws_down + stride_be_down, + stride_bk_down, + stride_bn_down, + stride_bs_e_down, + stride_bs_k_down, + stride_bs_n_down, + # Strides — Out + stride_om, + stride_oh, + # Constexprs + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_H_TILES: tl.constexpr, + NUM_EXPERTS_BIT_LENGTH: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Batched atomic fused MoE kernel: gate_up + SiLU + down per token. + + Grid: (S, N-tiles). Each program handles one (token, N-tile): + 1. Gate+up GEMM → SiLU → FP8 quant (intermediate in registers) + 2. Loops over H-tiles: down projection → atomic_add to output + """ + batch_id = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) + offs_m = tl.arange(0, BLOCK_SIZE_M) + + # ── Gate + Up projection ── + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + batch_id * stride_am + offs_k[None, :] * stride_ak + b_base = W_gu + expert_id * stride_be_gu + offs_k[:, None] * stride_bk_gu + b_gate_ptrs = b_base + offs_bn[None, :] * stride_bn_gu + b_up_ptrs = b_base + (N_inter + offs_bn)[None, :] * stride_bn_gu + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e_gu + bs_gate_ptrs = bs_base + pid_n * stride_bs_n_gu + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n_gu + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_raw = tl.load(a_ptrs + offs_m[:, None] * 0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw)) / 448.0 + a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv) + + b_gate = tl.load(b_gate_ptrs) + b_up = tl.load(b_up_ptrs) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k_gu) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k_gu) + + acc_gate += tl.dot(a, b_gate) * a_s * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptrs += BLOCK_SIZE_K * stride_bk_gu + b_up_ptrs += BLOCK_SIZE_K * stride_bk_gu + + # ── SiLU(gate) * up ── + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + # ── Quantize intermediate to FP8 ── + inter_s = tl.max(tl.abs(intermediate)) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s, 1e-12)).to(tl.float8e4nv) + + # ── Down projection + atomic accumulate ── + offs_h = tl.arange(0, BLOCK_SIZE_H) + for h in range(0, NUM_H_TILES): + h_offs = h * BLOCK_SIZE_H + offs_h + w_down_ptrs = ( + W_down + + expert_id * stride_be_down + + offs_bn[:, None] * stride_bk_down + + h_offs[None, :] * stride_bn_down + ) + w_down = tl.load(w_down_ptrs) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e_down + + h * stride_bs_n_down + + pid_n * stride_bs_k_down + ) + + partial = tl.dot(inter_fp8, w_down) * inter_s * ws_down + + out_ptrs = ( + Out + + batch_id * stride_om + + h_offs[None, :] * stride_oh + + offs_m[:, None] * 0 + ) + # Only write row 0 — all rows are identical (batched: 1 token per program) + row_mask = offs_m[:, None] == 0 + tl.atomic_add( + out_ptrs, partial.to(Out.dtype.element_ty), mask=row_mask, sem="relaxed" + ) + + +@triton_op("finegrained_fp8::moe_batched_atomic", mutates_args=()) +def _moe_batched_atomic( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Batched atomic fused MoE expert layer: non-deterministic (atomic split-K). + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → ONE kernel (gate_up + SiLU + down per token, atomic split-K) → routing + reduce + Non-deterministic: atomic_add across intermediate N-tiles. Intermediate stays in registers. + """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.shape[1] + intermediate_dim = down_proj.shape[2] + block_n, block_k = block_size + num_experts = gate_up_proj.shape[0] + + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) + sample_weights = top_k_weights.reshape(-1) + expert_ids = top_k_index.reshape(-1) + + selected_hidden_states = hidden_states[token_idx] + S = expert_ids.size(0) + + BLOCK_SIZE_M = min( + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 128 + ) + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + + Out = torch.zeros(S, hidden_dim, device=device, dtype=torch.float32) + grid = (S, num_N_tiles) + + with device_context(device): + wrap_triton(moe_batched_atomic_kernel)[grid]( + selected_hidden_states, + gate_up_proj, + down_proj, + Out, + gate_up_proj_scale_inv, + down_proj_scale_inv, + expert_ids, + sample_weights, + intermediate_dim, + hidden_states.shape[1], + hidden_dim, + selected_hidden_states.stride(0), + selected_hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + Out.stride(0), + Out.stride(1), + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_H_TILES=triton.cdiv(hidden_dim, block_n), + NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Apply routing weights + reduce + Out = Out.to(hidden_states.dtype) + weighted_out = Out * sample_weights.to(Out.dtype).unsqueeze(-1) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum( + dim=1 + ) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_batched_atomic( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Batched atomic fused MoE expert layer: non-deterministic (atomic split-K). + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → ONE kernel (gate_up + SiLU + down per token, atomic split-K) → routing + reduce + Non-deterministic: atomic_add across intermediate N-tiles. Intermediate stays in registers. + """ + return torch.ops.finegrained_fp8.moe_batched_atomic( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/fused.py b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py new file mode 100644 index 00000000..bf3c0c83 --- /dev/null +++ b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py @@ -0,0 +1,775 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused MoE: two-kernel deterministic approach. + +Kernel 1 (M×N grid): gather + gate_up + SiLU + FP8 quant → fp8 intermediate buffer +Kernel 2 (M×H grid): fp8 intermediate → down projection → output + +Both kernels are deterministic (no atomic_add). The fp8 intermediate buffer +is ~half the size of the bf16 intermediate in the unfused path. + +The full pipeline: + 1. histc + cumsum (expert offsets) + 2. sort (expert grouping) + 3. Kernel 1: gather → gate_up → SiLU → FP8 quant → (S, N_inter) fp8 + scales + 4. Kernel 2: fp8 intermediate → down_proj → (S, hidden) + 5. routing weights + unsort + top_k reduce +""" + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +from .utils import device_context + + +# ── Kernel 1: gather + gate_up + SiLU + FP8 quant ────────────────────────── + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], +) +@triton.jit +def fused_gate_up_silu_kernel( + # Unsorted inputs + A, # (num_tokens, K) raw BF16/FP16 — NOT sorted + Perm, # (S,) int64 — sorted_pos → original flat index + # Expert weights + W_gu, # (E, 2*N_inter, K) FP8 gate_up weights + Ws_gu, # gate_up scales + # Outputs + Inter, # (S, N_inter) FP8 intermediate + Inter_s, # (S,) fp32 per-row scales + # Expert scheduling + Offsets, + TileOffsets, + # Shapes + N_inter, + K, + num_top_k, + # Strides — A, W_gu, Ws_gu, Inter + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_im, + stride_in, + # Constexprs + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_EXPERTS_BIT_LENGTH: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Kernel 1: gather A from unsorted → gate_up GEMM → SiLU → FP8 quant. + + Grid: (M-tiles, N-tiles). Each program writes one (BLOCK_M, BLOCK_N) + tile of the fp8 intermediate + per-row scales. No atomics. + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) + if pid_m >= total_tiles: + return + + # Binary search for expert + lo = 0 + hi = NUM_EXPERTS + for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH): + mid = (lo + hi) >> 1 + mid_val = tl.load(TileOffsets + mid) + is_left = mid_val <= pid_m + lo = tl.where(is_left, mid + 1, lo) + hi = tl.where(is_left, hi, mid) + expert_id = lo.to(tl.int64) + + prev_eid = tl.maximum(expert_id - 1, 0) + expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) + expert_end = tl.load(Offsets + expert_id) + M_expert = expert_end - expert_start + + expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid)) + local_tile = pid_m - expert_tile_start + m_off = local_tile * BLOCK_SIZE_M + + offs_am = m_off + tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_am < M_expert + sorted_indices = expert_start + offs_am + + # Gather from unsorted A via Perm + perm_vals = tl.load(Perm + sorted_indices, mask=row_mask, other=0) + original_tokens = perm_vals // num_top_k + + # Gate + Up projection + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + original_tokens[:, None] * stride_am + offs_k[None, :] * stride_ak + b_base = W_gu + expert_id * stride_be + offs_k[:, None] * stride_bk + b_gate_ptrs = b_base + offs_bn[None, :] * stride_bn + b_up_ptrs = b_base + (N_inter + offs_bn)[None, :] * stride_bn + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e + bs_gate_ptrs = bs_base + pid_n * stride_bs_n + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + b_gate = tl.load(b_gate_ptrs) + b_up = tl.load(b_up_ptrs) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k) + + a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0 + a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv) + + acc_gate += tl.dot(a, b_gate) * a_s[:, None] * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s[:, None] * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptrs += BLOCK_SIZE_K * stride_bk + b_up_ptrs += BLOCK_SIZE_K * stride_bk + + # SiLU(gate) * up + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + # FP8 quantize — per-row scale across this N-tile + inter_s = tl.max(tl.abs(intermediate), axis=1) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s[:, None], 1e-12)).to(tl.float8e4nv) + + # Store fp8 intermediate tile + inter_ptrs = ( + Inter + sorted_indices[:, None] * stride_im + offs_bn[None, :] * stride_in + ) + tl.store(inter_ptrs, inter_fp8, mask=row_mask[:, None]) + + # Store per-row scale (one per row per N-tile) + # Layout: Inter_s[sorted_idx, pid_n] + scale_ptrs = Inter_s + sorted_indices * tl.cdiv(N_inter, BLOCK_SIZE_N) + pid_n + tl.store(scale_ptrs, inter_s, mask=row_mask) + + +# ── Kernel 2: down projection from fp8 intermediate ───────────────────────── + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "hidden", "BLOCK_SIZE_M"], +) +@triton.jit +def fused_down_proj_kernel( + # Inputs + Inter, # (S, N_inter) FP8 intermediate + Inter_s, # (S, num_N_tiles) fp32 per-row-per-N-tile scales + W_down, # (E, hidden, N_inter) FP8 down weights + Ws_down, # down scales + SampleWeights, # (S,) routing weights in sorted order + Perm, # (S,) int64 — sorted_pos → original flat index + # Output + Out, # (num_tokens * top_k, hidden) output in original flat order + # Expert scheduling + Offsets, + TileOffsets, + # Shapes + N_inter, + hidden, + # Strides — Inter, W_down, Ws_down, Out + stride_im, + stride_in, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_om, + stride_oh, + # Constexprs + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_N_TILES: tl.constexpr, + NUM_EXPERTS_BIT_LENGTH: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Kernel 2: fp8 intermediate → down_proj → output. + + Grid: (M-tiles, H-tiles). Each program reads (BLOCK_M, N_inter) fp8 + intermediate, does the down projection tiled over N, stores (BLOCK_M, BLOCK_H). + No atomics — deterministic. + """ + pid_m = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) + if pid_m >= total_tiles: + return + + # Binary search for expert + lo = 0 + hi = NUM_EXPERTS + for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH): + mid = (lo + hi) >> 1 + mid_val = tl.load(TileOffsets + mid) + is_left = mid_val <= pid_m + lo = tl.where(is_left, mid + 1, lo) + hi = tl.where(is_left, hi, mid) + expert_id = lo.to(tl.int64) + + prev_eid = tl.maximum(expert_id - 1, 0) + expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) + expert_end = tl.load(Offsets + expert_id) + M_expert = expert_end - expert_start + + expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid)) + local_tile = pid_m - expert_tile_start + m_off = local_tile * BLOCK_SIZE_M + + offs_am = m_off + tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_am < M_expert + sorted_indices = expert_start + offs_am + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_n = tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) + + for n_tile in range(0, NUM_N_TILES): + n_offs = n_tile * BLOCK_SIZE_N + offs_n + + # Load fp8 intermediate tile + inter_ptrs = ( + Inter + sorted_indices[:, None] * stride_im + n_offs[None, :] * stride_in + ) + inter_fp8 = tl.load(inter_ptrs, mask=row_mask[:, None], other=0.0) + + # Load per-row scale for this N-tile + scale_ptrs = Inter_s + sorted_indices * NUM_N_TILES + n_tile + inter_s = tl.load(scale_ptrs, mask=row_mask, other=0.0) + + # Load down weights + w_down_ptrs = ( + W_down + + expert_id * stride_be + + n_offs[:, None] * stride_bk + + offs_h[None, :] * stride_bn + ) + w_down = tl.load(w_down_ptrs) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e + + pid_h * stride_bs_n + + n_tile * stride_bs_k + ) + + acc += tl.dot(inter_fp8, w_down) * inter_s[:, None] * ws_down + + # Apply routing weights and scatter to original flat order via Perm + if SIMULATE_UNFUSED: + acc = acc.to(tl.bfloat16).to(tl.float32) + routing_w = tl.load(SampleWeights + sorted_indices, mask=row_mask, other=0.0) + acc = acc * routing_w[:, None] + original_flat_idx = tl.load(Perm + sorted_indices, mask=row_mask, other=0) + out_ptrs = ( + Out + original_flat_idx[:, None] * stride_om + offs_h[None, :] * stride_oh + ) + tl.store(out_ptrs, acc.to(Out.dtype.element_ty), mask=row_mask[:, None]) + + +# ── Wrapper ────────────────────────────────────────────────────────────────── + + +@triton_op("finegrained_fp8::moe_grouped_fused", mutates_args=()) +def _moe_grouped_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Two-kernel fused MoE expert layer: deterministic, no atomics. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → fused_gate_up_silu_kernel → fused_down_proj_kernel → routing + unsort + reduce + Deterministic: no atomic_add, intermediate goes through HBM as fp8. + """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_experts = gate_up_proj.size(0) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.size(1) + intermediate_dim = down_proj.size(2) + block_n, block_k = block_size + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + S = expert_ids.size(0) + + # Sort by expert for grouped processing + _, perm = expert_ids.sort(stable=True) + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + + # Compute offsets for grouped processing + histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() + tokens_per_expert = torch.histc( + histc_input, bins=num_experts, min=0, max=num_experts - 1 + ) + offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + + # Tile setup + BLOCK_SIZE_M = min( + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 128 + ) + tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32) + max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + num_experts + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + num_H_tiles = triton.cdiv(hidden_dim, block_n) + + # Temp buffer: fp8 intermediate + per-row-per-N-tile scales + inter_fp8 = torch.empty( + S, intermediate_dim, device=device, dtype=torch.float8_e4m3fn + ) + inter_scales = torch.empty(S, num_N_tiles, device=device, dtype=torch.float32) + + # --- Kernel 1: gate_up + SiLU + FP8 quant --- + grid1 = (max_M_tiles, num_N_tiles) + with device_context(device): + wrap_triton(fused_gate_up_silu_kernel)[grid1]( + hidden_states, + perm, + gate_up_proj, + gate_up_proj_scale_inv, + inter_fp8, + inter_scales, + offsets, + tile_offsets, + intermediate_dim, + hidden_states.shape[1], + num_top_k, + hidden_states.stride(0), + hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + inter_fp8.stride(0), + inter_fp8.stride(1), + NUM_EXPERTS=num_experts, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), + SIMULATE_UNFUSED=simulate_unfused, + ) + + # --- Kernel 2: down projection + routing weights + scatter to original order --- + proj_out = torch.empty(S, hidden_dim, device=device, dtype=hidden_states.dtype) + grid2 = (max_M_tiles, num_H_tiles) + with device_context(device): + wrap_triton(fused_down_proj_kernel)[grid2]( + inter_fp8, + inter_scales, + down_proj, + down_proj_scale_inv, + sample_weights_g, + perm, + proj_out, + offsets, + tile_offsets, + intermediate_dim, + hidden_dim, + inter_fp8.stride(0), + inter_fp8.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + proj_out.stride(0), + proj_out.stride(1), + NUM_EXPERTS=num_experts, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_N_TILES=num_N_tiles, + NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Output already in original flat order — just reduce across top_k + final_hidden_states = proj_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_grouped_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Two-kernel fused MoE expert layer: deterministic, no atomics. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → fused_gate_up_silu_kernel → fused_down_proj_kernel → routing + unsort + reduce + Deterministic: no atomic_add, intermediate goes through HBM as fp8. + """ + return torch.ops.finegrained_fp8.moe_grouped_fused( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) + + +# ── Batched fused: gate_up + SiLU + down (no sorting, no atomics) ─────────── +# +# Each program handles one (token, H-tile) and loops over N-tiles sequentially. +# The intermediate stays entirely in registers. No sorting needed — expert +# lookup is per-token via ExpertIds. + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], +) +@triton.jit +def moe_batched_fused_kernel( + A, # (S, K) raw BF16/FP16 activations + W_gu, # (E, 2*N_inter, K) FP8 gate_up weights + W_down, # (E, hidden, N_inter) FP8 down weights + Out, # (S, hidden) output + Ws_gu, # gate_up scales + Ws_down, # down scales + ExpertIds, # (S,) expert index per token + SampleWeights, # (S,) routing weights + # Shapes + N_inter, + K, + hidden, + # Strides — A + stride_am, + stride_ak, + # Strides — W_gu, Ws_gu + stride_be_gu, + stride_bk_gu, + stride_bn_gu, + stride_bs_e_gu, + stride_bs_k_gu, + stride_bs_n_gu, + # Strides — W_down, Ws_down + stride_be_down, + stride_bk_down, + stride_bn_down, + stride_bs_e_down, + stride_bs_k_down, + stride_bs_n_down, + # Strides — Out + stride_om, + stride_oh, + # Constexprs + NUM_N_TILES: tl.constexpr, + NUM_K_TILES: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Batched fused MoE kernel: gate_up + SiLU + down in one kernel, no atomics. + + Grid: (S, H-tiles). Each program handles one (token, H-tile) and loops + over N-tiles. The intermediate stays entirely in registers. + """ + batch_id = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_m = tl.arange(0, BLOCK_SIZE_M) + + acc_down = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) + + for n_inter in range(0, NUM_N_TILES): + offs_n = n_inter * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # ── Gate + Up projection for this N-tile ── + a_ptrs = A + batch_id * stride_am + offs_k[None, :] * stride_ak + b_base = W_gu + expert_id * stride_be_gu + offs_k[:, None] * stride_bk_gu + b_gate_ptrs = b_base + offs_n[None, :] * stride_bn_gu + b_up_ptrs = b_base + (N_inter + offs_n)[None, :] * stride_bn_gu + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e_gu + bs_gate_ptr = bs_base + n_inter * stride_bs_n_gu + bs_up_ptr = bs_base + (n_scale_blocks + n_inter) * stride_bs_n_gu + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, NUM_K_TILES): + a_raw = tl.load(a_ptrs + offs_m[:, None] * 0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw)) / 448.0 + a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv) + + b_gate = tl.load(b_gate_ptrs) + b_up = tl.load(b_up_ptrs) + bs_gate = tl.load(bs_gate_ptr + k * stride_bs_k_gu) + bs_up = tl.load(bs_up_ptr + k * stride_bs_k_gu) + + acc_gate += tl.dot(a, b_gate) * a_s * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptrs += BLOCK_SIZE_K * stride_bk_gu + b_up_ptrs += BLOCK_SIZE_K * stride_bk_gu + + # ── SiLU(gate) * up ── + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + # ── Quantize intermediate to FP8 ── + inter_s = tl.max(tl.abs(intermediate)) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s, 1e-12)).to(tl.float8e4nv) + + # ── Partial down projection ── + w_down_ptrs = ( + W_down + + expert_id * stride_be_down + + offs_n[:, None] * stride_bk_down + + offs_h[None, :] * stride_bn_down + ) + w_down = tl.load(w_down_ptrs) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e_down + + pid_h * stride_bs_n_down + + n_inter * stride_bs_k_down + ) + + acc_down += tl.dot(inter_fp8, w_down) * inter_s * ws_down + + # ── Apply routing weight and store ── + if SIMULATE_UNFUSED: + acc_down = acc_down.to(tl.bfloat16).to(tl.float32) + routing_w = tl.load(SampleWeights + batch_id) + acc_down = acc_down * routing_w + + if Out.dtype.element_ty == tl.bfloat16: + c = acc_down.to(tl.bfloat16) + elif Out.dtype.element_ty == tl.float16: + c = acc_down.to(tl.float16) + else: + c = acc_down.to(tl.float32) + + c_ptrs = ( + Out + batch_id * stride_om + offs_h[None, :] * stride_oh + offs_m[:, None] * 0 + ) + tl.store(c_ptrs, c) + + +@triton_op("finegrained_fp8::moe_batched_fused", mutates_args=()) +def _moe_batched_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Batched fused MoE expert layer: deterministic, no sorting, no atomics. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → ONE kernel (gate_up + SiLU + down per token) → routing + reduce + Deterministic: each token processed independently. Intermediate stays in registers. + """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.shape[1] + intermediate_dim = down_proj.shape[2] + block_n, block_k = block_size + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) + sample_weights = top_k_weights.reshape(-1) + expert_ids = top_k_index.reshape(-1) + + selected_hidden_states = hidden_states[token_idx] + S = expert_ids.size(0) + + BLOCK_SIZE_M = min( + max( + triton.next_power_of_2( + (S + gate_up_proj.shape[0] - 1) // gate_up_proj.shape[0] + ), + 16, + ), + 128, + ) + Out = selected_hidden_states.new_empty(S, hidden_dim) + num_H_tiles = triton.cdiv(hidden_dim, block_n) + grid = (S, num_H_tiles) + + with device_context(device): + wrap_triton(moe_batched_fused_kernel)[grid]( + selected_hidden_states, + gate_up_proj, + down_proj, + Out, + gate_up_proj_scale_inv, + down_proj_scale_inv, + expert_ids, + sample_weights, + intermediate_dim, + hidden_states.shape[1], + hidden_dim, + selected_hidden_states.stride(0), + selected_hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + Out.stride(0), + Out.stride(1), + NUM_N_TILES=triton.cdiv(intermediate_dim, block_n), + NUM_K_TILES=triton.cdiv(hidden_states.shape[1], block_k), + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Routing weights already applied in kernel — just reduce + final_hidden_states = Out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_batched_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Batched fused MoE expert layer: deterministic, no sorting, no atomics. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → ONE kernel (gate_up + SiLU + down per token) → routing + reduce + Deterministic: each token processed independently. Intermediate stays in registers. + """ + return torch.ops.finegrained_fp8.moe_batched_fused( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/moe.py b/finegrained-fp8/torch-ext/finegrained_fp8/moe.py new file mode 100644 index 00000000..aa53f000 --- /dev/null +++ b/finegrained-fp8/torch-ext/finegrained_fp8/moe.py @@ -0,0 +1,189 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end MoE expert layer using unfused grouped GEMM primitives. + +Pipeline: + 1. Sort tokens by expert (histc + cumsum + argsort) + 2. Gate+up grouped GEMM → bf16 intermediate + 3. SiLU(gate) * up → bf16 intermediate + 4. Down grouped GEMM → bf16 output + 5. Apply routing weights + unsort + top_k reduce + +All operations are deterministic (no atomic_add). This serves as the +baseline implementation; see fused.py for the fused variant that +merges steps 2-3 into a single kernel. +""" + +import torch +import torch.nn.functional as F + +from .batched import w8a8_block_fp8_matmul_batched +from .grouped import w8a8_block_fp8_matmul_grouped + + +def moe_grouped( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], +) -> torch.Tensor: + """End-to-end MoE expert layer using unfused grouped GEMM primitives. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → gate_up GEMM → SiLU → down GEMM → routing + unsort + reduce + Deterministic: no atomic_add, all operations have fixed execution order. + """ + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_experts = gate_up_proj.size(0) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + selected_hidden_states_g = hidden_states[token_idx[perm]] + + # Compute offsets for grouped processing. + # histc instead of bincount avoids cuda-graph issues; + # CPU requires float input, CUDA requires int input (deterministic mode). + histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() + tokens_per_expert = torch.histc( + histc_input, bins=num_experts, min=0, max=num_experts - 1 + ) + offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + + # --- Gate+up projection per expert (FP8 grouped) --- + proj_out = w8a8_block_fp8_matmul_grouped( + selected_hidden_states_g, + gate_up_proj, + gate_up_proj_scale_inv, + offsets, + tokens_per_expert, + block_size, + ) # (S, 2 * intermediate_dim) + + # Apply SiLU gating + gate, up = proj_out.chunk(2, dim=-1) + proj_out = F.silu(gate) * up # (S, intermediate_dim) + + # --- Down projection per expert (FP8 grouped) --- + proj_out = w8a8_block_fp8_matmul_grouped( + proj_out, down_proj, down_proj_scale_inv, offsets, tokens_per_expert, block_size + ) # (S, hidden_dim) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze( + -1 + ) # (S, hidden_dim) + + # Restore original order + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum( + dim=1 + ) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_batched( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], +) -> torch.Tensor: + """End-to-end MoE expert layer using unfused batched GEMM primitives. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → batched_mm(gate_up) → SiLU → batched_mm(down) → routing + reduce + Deterministic: no sorting, no atomic_add. Each token processed independently. + """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Get current hidden states for selected samples (no sorting needed) + selected_hidden_states = hidden_states[token_idx] + + # --- Gate+up projection per expert (FP8 batched) --- + proj_out = w8a8_block_fp8_matmul_batched( + selected_hidden_states, + gate_up_proj, + gate_up_proj_scale_inv, + expert_ids, + block_size, + ) # (S, 2 * intermediate_dim) + + # Apply SiLU gating + gate, up = proj_out.chunk(2, dim=-1) + proj_out = F.silu(gate) * up # (S, intermediate_dim) + + # --- Down projection per expert (FP8 batched) --- + proj_out = w8a8_block_fp8_matmul_batched( + proj_out, down_proj, down_proj_scale_inv, expert_ids, block_size + ) # (S, hidden_dim) + + # Apply routing weights + weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze( + -1 + ) # (S, hidden_dim) + + # Accumulate results using deterministic reshape+sum instead of index_add_ + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum( + dim=1 + ) + + return final_hidden_states.to(hidden_states.dtype) From 8718947c336a0e5ad234f44b6c66d9ebb8e33b2e Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 1 Apr 2026 21:35:07 +0000 Subject: [PATCH 2/3] more configs --- finegrained-fp8/torch-ext/finegrained_fp8/atomic.py | 8 ++++---- finegrained-fp8/torch-ext/finegrained_fp8/fused.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py index 36fcbf82..ddc73334 100644 --- a/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py +++ b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py @@ -34,8 +34,8 @@ @triton.autotune( configs=[ triton.Config({}, num_warps=w, num_stages=s) - for w in [4, 8, 16] - for s in [2, 3, 4] + for w in [2, 4, 8, 16] + for s in [2, 3, 4, 5] ], key=["N_inter", "K", "BLOCK_SIZE_M"], reset_to_zero=["Out"], @@ -363,8 +363,8 @@ def moe_grouped_atomic( @triton.autotune( configs=[ triton.Config({}, num_warps=w, num_stages=s) - for w in [4, 8, 16] - for s in [2, 3, 4] + for w in [2, 4, 8, 16] + for s in [2, 3, 4, 5] ], key=["N_inter", "K", "BLOCK_SIZE_M"], reset_to_zero=["Out"], diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/fused.py b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py index bf3c0c83..28197fc9 100644 --- a/finegrained-fp8/torch-ext/finegrained_fp8/fused.py +++ b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py @@ -42,8 +42,8 @@ @triton.autotune( configs=[ triton.Config({}, num_warps=w, num_stages=s) - for w in [4, 8, 16] - for s in [2, 3, 4] + for w in [2, 4, 8, 16] + for s in [2, 3, 4, 5] ], key=["N_inter", "K", "BLOCK_SIZE_M"], ) @@ -191,8 +191,8 @@ def fused_gate_up_silu_kernel( @triton.autotune( configs=[ triton.Config({}, num_warps=w, num_stages=s) - for w in [4, 8, 16] - for s in [2, 3, 4] + for w in [2, 4, 8, 16] + for s in [2, 3, 4, 5] ], key=["N_inter", "hidden", "BLOCK_SIZE_M"], ) From 2e639af1e618f36d48c8da75981c857e2f00d30c Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Tue, 7 Apr 2026 09:26:45 +0000 Subject: [PATCH 3/3] final fused kernels --- finegrained-fp8/bench_fused_moe.py | 125 +++-- .../torch-ext/finegrained_fp8/atomic.py | 97 ++-- .../torch-ext/finegrained_fp8/fused.py | 465 +++++++++++------- 3 files changed, 410 insertions(+), 277 deletions(-) diff --git a/finegrained-fp8/bench_fused_moe.py b/finegrained-fp8/bench_fused_moe.py index fb307498..83c15a51 100644 --- a/finegrained-fp8/bench_fused_moe.py +++ b/finegrained-fp8/bench_fused_moe.py @@ -1,19 +1,23 @@ -"""Benchmark all MoE dispatch methods: correctness matrix + performance sweep + plot. +"""Benchmark MoE dispatch methods: correctness matrix + performance sweep + plot. Usage: - python bench_fused_moe.py + python bench_fused_moe.py --grouped # grouped variants + python bench_fused_moe.py --batched # batched variants + python bench_fused_moe.py --all # all variants """ +import argparse import sys sys.path.append("torch-ext") -import matplotlib -import matplotlib.pyplot as plt import torch import triton -from rich.console import Console +import matplotlib from rich.table import Table +from rich.console import Console +import matplotlib.pyplot as plt + from finegrained_fp8 import ( moe_grouped, moe_batched, @@ -30,15 +34,20 @@ console = Console() -METHODS = { +GROUPED_METHODS = { "grouped": moe_grouped, - # "batched": moe_batched, "grouped_fused": moe_grouped_fused, - # "batched_fused": moe_batched_fused, "grouped_atomic": moe_grouped_atomic, - # "batched_atomic": moe_batched_atomic, } +BATCHED_METHODS = { + "batched": moe_batched, + "batched_fused": moe_batched_fused, + "batched_atomic": moe_batched_atomic, +} + +ALL_METHODS = {**GROUPED_METHODS, **BATCHED_METHODS} + def quantize_weights_block(W, block_n=128, block_k=128): E, N, K = W.shape @@ -79,6 +88,23 @@ def diff_emoji(d): def main(): + parser = argparse.ArgumentParser(description="Benchmark MoE dispatch methods") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--grouped", action="store_true", help="Benchmark grouped variants") + group.add_argument("--batched", action="store_true", help="Benchmark batched variants") + group.add_argument("--all", action="store_true", help="Benchmark all variants") + cli_args = parser.parse_args() + + if cli_args.all: + METHODS = ALL_METHODS + variant = "all" + elif cli_args.batched: + METHODS = BATCHED_METHODS + variant = "batched" + else: + METHODS = GROUPED_METHODS + variant = "grouped" + device = "cuda" block_size = [128, 128] @@ -172,7 +198,7 @@ def call_method(name, fn): ) def run_sweep(bench_fn_factory, title, mode="eager"): - """Run adaptive token sweep, dropping methods that plateau (<10% TFLOPS gain).""" + """Run adaptive token sweep, stopping when TFLOPS plateau (<5% gain).""" table = Table(title=title) table.add_column("Tokens", justify="right", style="bold") for name in names: @@ -201,16 +227,19 @@ def run_sweep(bench_fn_factory, title, mode="eager"): down_proj_scale_inv, block_size, ) - results = {} for name, fn in METHODS.items(): + msg = f" [dim]benching {name} @ {num_tokens} tokens...[/dim]" + console.print(msg + " " * 40, end="\r") + @torch.no_grad() def _bench(fn=fn): return fn(*args) try: results[name] = bench_fn_factory(_bench) - except Exception: + except Exception as e: + console.print(f" [red]{name} @ {num_tokens} tokens: {e}[/red]") results[name] = float("inf") all_res.append((num_tokens, results)) @@ -218,8 +247,9 @@ def _bench(fn=fn): # Compute per-method TFLOPS and check progress S = num_tokens * top_k flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + valid = {n: ms for n, ms in results.items() if ms != float("inf")} + best_name = min(valid, key=valid.get) if valid else "n/a" - best_name = min(results, key=results.get) cells = [] for name in names: ms = results[name] @@ -233,7 +263,7 @@ def _bench(fn=fn): increase = (tflops - prev_tflops[name]) / max( prev_tflops[name], 1e-12 ) - if increase < 0.10: + if increase < 0.05: active.discard(name) prev_tflops[name] = tflops @@ -271,7 +301,7 @@ def _bench(fn=fn): # ═══════════════════════════════════════════════════════════════════════════ import csv - csv_path = "moe_results.csv" + csv_path = f"moe_results_{variant}.csv" with open(csv_path, "w", newline="") as f: writer = csv.writer(f) writer.writerow(["tokens", "method", "mode", "time_ms", "tflops"]) @@ -296,18 +326,18 @@ def _bench(fn=fn): # ═══════════════════════════════════════════════════════════════════════════ colors = { "grouped": "#e74c3c", - "batched": "#c0392b", "grouped_fused": "#2ecc71", - "batched_fused": "#27ae60", "grouped_atomic": "#3498db", + "batched": "#c0392b", + "batched_fused": "#27ae60", "batched_atomic": "#2980b9", } linestyles = { "grouped": "--", - "batched": "--", "grouped_fused": "-", - "batched_fused": "-", "grouped_atomic": ":", + "batched": "--", + "batched_fused": "-", "batched_atomic": ":", } @@ -330,37 +360,31 @@ def compute_tflops(results_list): tflops_eager = compute_tflops(all_results) tflops_cg = compute_tflops(all_results_cg) - fig, axes = plt.subplots(2, 2, figsize=(18, 14)) + fig, axes = plt.subplots(1, 2, figsize=(18, 7)) - for col, (yscale, title_suffix) in enumerate( - [("linear", "Linear"), ("log", "Log")] + for col, (tflops_data, mode) in enumerate( + [(tflops_eager, "Eager"), (tflops_cg, "CUDA Graphs")] ): - for row, (tflops_data, mode) in enumerate( - [(tflops_eager, "Eager"), (tflops_cg, "CUDA Graphs")] - ): - ax = axes[row, col] - for name in names: - xs, ys = tflops_data[name] - if xs: - ax.plot( - xs, - ys, - linestyles.get(name, "-"), - marker="o", - markersize=4, - color=colors.get(name, "gray"), - label=name, - linewidth=2, - ) - ax.set_xlabel("Tokens", fontsize=12) - ax.set_ylabel("TFLOPS", fontsize=12) - ax.set_title( - f"{mode} — TFLOPS ({title_suffix})", fontsize=13, fontweight="bold" - ) - ax.set_xscale("log", base=2) - ax.set_yscale(yscale) - ax.legend(fontsize=8, loc="best") - ax.grid(True, alpha=0.3) + ax = axes[col] + for name in names: + xs, ys = tflops_data[name] + if xs: + ax.plot( + xs, + ys, + linestyles.get(name, "-"), + marker="o", + markersize=4, + color=colors.get(name, "gray"), + label=name, + linewidth=2, + ) + ax.set_xlabel("Tokens", fontsize=12) + ax.set_ylabel("TFLOPS", fontsize=12) + ax.set_title(f"{mode} — TFLOPS", fontsize=13, fontweight="bold") + ax.set_xscale("log", base=2) + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) fig.suptitle( f"FP8 MoE Expert Dispatch — {model_name} (E={E}, N_inter={N_inter}, K={K}, top_k={top_k})", @@ -368,8 +392,9 @@ def compute_tflops(results_list): fontweight="bold", ) fig.tight_layout() - fig.savefig("moe_tflops.png", dpi=150) - console.print("\nPlot saved to [bold]moe_tflops.png[/bold]") + plot_path = f"moe_tflops_{variant}.png" + fig.savefig(plot_path, dpi=150) + console.print(f"\nPlot saved to [bold]{plot_path}[/bold]") plt.close(fig) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py index ddc73334..4284d66d 100644 --- a/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py +++ b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py @@ -33,9 +33,10 @@ @triton.autotune( configs=[ - triton.Config({}, num_warps=w, num_stages=s) + triton.Config({"GROUP_SIZE_M": g}, num_warps=w, num_stages=s) for w in [2, 4, 8, 16] for s in [2, 3, 4, 5] + for g in [1, 8] ], key=["N_inter", "K", "BLOCK_SIZE_M"], reset_to_zero=["Out"], @@ -56,11 +57,13 @@ def moe_atomic_kernel( # Expert scheduling Offsets, TileOffsets, + TileToExpert, # (max_M_tiles,) int32 — tile → expert lookup # Shapes N_inter, K, hidden, num_top_k, + num_M_tiles, # Strides — A, W_gu, Ws_gu, W_down, Ws_down, Out stride_am, stride_ak, @@ -85,8 +88,8 @@ def moe_atomic_kernel( BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, NUM_H_TILES: tl.constexpr, - NUM_EXPERTS_BIT_LENGTH: tl.constexpr, SIMULATE_UNFUSED: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, ): """Single fused MoE kernel: gather + gate_up + SiLU + down in one pass. @@ -97,21 +100,15 @@ def moe_atomic_kernel( """ pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) + num_N_tiles = tl.cdiv(N_inter, BLOCK_SIZE_N) + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_M_tiles, num_N_tiles, GROUP_SIZE_M) total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) if pid_m >= total_tiles: return - # Binary search for expert - lo = 0 - hi = NUM_EXPERTS - for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH): - mid = (lo + hi) >> 1 - mid_val = tl.load(TileOffsets + mid) - is_left = mid_val <= pid_m - lo = tl.where(is_left, mid + 1, lo) - hi = tl.where(is_left, hi, mid) - expert_id = lo.to(tl.int64) + # O(1) tile → expert lookup + expert_id = tl.load(TileToExpert + pid_m).to(tl.int64) prev_eid = tl.maximum(expert_id - 1, 0) expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) @@ -130,14 +127,26 @@ def moe_atomic_kernel( perm_vals = tl.load(Perm + sorted_indices, mask=row_mask, other=0) original_tokens = perm_vals // num_top_k - # ── Gate + Up projection ── - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # ── Gate + Up projection — block pointers for weight loads ── offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A + original_tokens[:, None] * stride_am + offs_k[None, :] * stride_ak - b_base = W_gu + expert_id * stride_be_gu + offs_k[:, None] * stride_bk_gu - b_gate_ptrs = b_base + offs_bn[None, :] * stride_bn_gu - b_up_ptrs = b_base + (N_inter + offs_bn)[None, :] * stride_bn_gu + + b_gate_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be_gu, + shape=(K, N_inter * 2), + strides=(stride_bk_gu, stride_bn_gu), + offsets=(0, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + b_up_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be_gu, + shape=(K, N_inter * 2), + strides=(stride_bk_gu, stride_bn_gu), + offsets=(0, N_inter + pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) n_scale_blocks = N_inter // BLOCK_SIZE_N bs_base = Ws_gu + expert_id * stride_bs_e_gu @@ -148,8 +157,8 @@ def moe_atomic_kernel( acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - b_gate = tl.load(b_gate_ptrs) - b_up = tl.load(b_up_ptrs) + b_gate = tl.load(b_gate_ptr) + b_up = tl.load(b_up_ptr) bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k_gu) bs_up = tl.load(bs_up_ptrs + k * stride_bs_k_gu) @@ -161,8 +170,8 @@ def moe_atomic_kernel( acc_up += tl.dot(a, b_up) * a_s[:, None] * bs_up[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak - b_gate_ptrs += BLOCK_SIZE_K * stride_bk_gu - b_up_ptrs += BLOCK_SIZE_K * stride_bk_gu + b_gate_ptr = tl.advance(b_gate_ptr, (BLOCK_SIZE_K, 0)) + b_up_ptr = tl.advance(b_up_ptr, (BLOCK_SIZE_K, 0)) # ── SiLU(gate) * up ── if SIMULATE_UNFUSED: @@ -179,17 +188,19 @@ def moe_atomic_kernel( inter_s = tl.max(tl.abs(intermediate), axis=1) / 448.0 inter_fp8 = (intermediate / tl.maximum(inter_s[:, None], 1e-12)).to(tl.float8e4nv) - # ── Down projection + atomic accumulate ── - offs_h = tl.arange(0, BLOCK_SIZE_H) + # ── Down projection + atomic accumulate — block pointer for weights ── + w_down_ptr = tl.make_block_ptr( + base=W_down + expert_id * stride_be_down, + shape=(N_inter, hidden), + strides=(stride_bk_down, stride_bn_down), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_H), + order=(0, 1), + ) + for h in range(0, NUM_H_TILES): - h_offs = h * BLOCK_SIZE_H + offs_h - w_down_ptrs = ( - W_down - + expert_id * stride_be_down - + offs_bn[:, None] * stride_bk_down - + h_offs[None, :] * stride_bn_down - ) - w_down = tl.load(w_down_ptrs) + h_offs = h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + w_down = tl.load(w_down_ptr) ws_down = tl.load( Ws_down + expert_id * stride_bs_e_down @@ -208,6 +219,7 @@ def moe_atomic_kernel( mask=row_mask[:, None], sem="relaxed", ) + w_down_ptr = tl.advance(w_down_ptr, (0, BLOCK_SIZE_H)) # ── Wrapper ────────────────────────────────────────────────────────────────── @@ -258,13 +270,21 @@ def _moe_grouped_atomic( ) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) - # Tile setup + # Tile setup — BLOCK_SIZE_M capped at 64 for better SM utilization with many experts BLOCK_SIZE_M = min( - max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 128 + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 64 ) tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32) max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + num_experts + + # Tile-to-expert lookup via bucketize (CUDA-graph safe, replaces binary search) + tile_to_expert = torch.bucketize( + torch.arange(max_M_tiles, device=device, dtype=torch.int32), + tile_offsets, + right=True, + ) + num_N_tiles = triton.cdiv(intermediate_dim, block_n) # fp32 output for atomic accumulation @@ -283,10 +303,12 @@ def _moe_grouped_atomic( proj_out, offsets, tile_offsets, + tile_to_expert, intermediate_dim, hidden_states.shape[1], hidden_dim, num_top_k, + max_M_tiles, hidden_states.stride(0), hidden_states.stride(1), gate_up_proj.stride(0), @@ -309,7 +331,6 @@ def _moe_grouped_atomic( BLOCK_SIZE_H=block_n, BLOCK_SIZE_M=BLOCK_SIZE_M, NUM_H_TILES=triton.cdiv(hidden_dim, block_n), - NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), SIMULATE_UNFUSED=simulate_unfused, ) @@ -363,8 +384,8 @@ def moe_grouped_atomic( @triton.autotune( configs=[ triton.Config({}, num_warps=w, num_stages=s) - for w in [2, 4, 8, 16] - for s in [2, 3, 4, 5] + for w in [4, 8, 16] + for s in [2, 3, 4] ], key=["N_inter", "K", "BLOCK_SIZE_M"], reset_to_zero=["Out"], @@ -547,7 +568,7 @@ def _moe_batched_atomic( S = expert_ids.size(0) BLOCK_SIZE_M = min( - max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 128 + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 64 ) num_N_tiles = triton.cdiv(intermediate_dim, block_n) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/fused.py b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py index 28197fc9..0ba1da9f 100644 --- a/finegrained-fp8/torch-ext/finegrained_fp8/fused.py +++ b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py @@ -37,13 +37,17 @@ # ── Kernel 1: gather + gate_up + SiLU + FP8 quant ────────────────────────── +# +# Optimizations: block pointers for weight loads, L2 swizzle (GROUP_SIZE_M), +# tile-to-expert via bucketize (replaces binary search), BLOCK_SIZE_M capped at 64. @triton.autotune( configs=[ - triton.Config({}, num_warps=w, num_stages=s) + triton.Config({"GROUP_SIZE_M": g}, num_warps=w, num_stages=s) for w in [2, 4, 8, 16] for s in [2, 3, 4, 5] + for g in [1, 8] ], key=["N_inter", "K", "BLOCK_SIZE_M"], ) @@ -61,10 +65,12 @@ def fused_gate_up_silu_kernel( # Expert scheduling Offsets, TileOffsets, + TileToExpert, # (max_M_tiles,) int32 — tile → expert lookup # Shapes N_inter, K, num_top_k, + num_M_tiles, # Strides — A, W_gu, Ws_gu, Inter stride_am, stride_ak, @@ -81,8 +87,8 @@ def fused_gate_up_silu_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - NUM_EXPERTS_BIT_LENGTH: tl.constexpr, SIMULATE_UNFUSED: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, ): """Kernel 1: gather A from unsorted → gate_up GEMM → SiLU → FP8 quant. @@ -91,21 +97,15 @@ def fused_gate_up_silu_kernel( """ pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) + num_N_tiles = tl.cdiv(N_inter, BLOCK_SIZE_N) + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_M_tiles, num_N_tiles, GROUP_SIZE_M) total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) if pid_m >= total_tiles: return - # Binary search for expert - lo = 0 - hi = NUM_EXPERTS - for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH): - mid = (lo + hi) >> 1 - mid_val = tl.load(TileOffsets + mid) - is_left = mid_val <= pid_m - lo = tl.where(is_left, mid + 1, lo) - hi = tl.where(is_left, hi, mid) - expert_id = lo.to(tl.int64) + # O(1) tile → expert lookup + expert_id = tl.load(TileToExpert + pid_m).to(tl.int64) prev_eid = tl.maximum(expert_id - 1, 0) expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) @@ -129,9 +129,23 @@ def fused_gate_up_silu_kernel( offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = A + original_tokens[:, None] * stride_am + offs_k[None, :] * stride_ak - b_base = W_gu + expert_id * stride_be + offs_k[:, None] * stride_bk - b_gate_ptrs = b_base + offs_bn[None, :] * stride_bn - b_up_ptrs = b_base + (N_inter + offs_bn)[None, :] * stride_bn + + b_gate_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + b_up_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, N_inter + pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) n_scale_blocks = N_inter // BLOCK_SIZE_N bs_base = Ws_gu + expert_id * stride_bs_e @@ -142,8 +156,8 @@ def fused_gate_up_silu_kernel( acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - b_gate = tl.load(b_gate_ptrs) - b_up = tl.load(b_up_ptrs) + b_gate = tl.load(b_gate_ptr) + b_up = tl.load(b_up_ptr) bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k) bs_up = tl.load(bs_up_ptrs + k * stride_bs_k) @@ -155,8 +169,8 @@ def fused_gate_up_silu_kernel( acc_up += tl.dot(a, b_up) * a_s[:, None] * bs_up[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak - b_gate_ptrs += BLOCK_SIZE_K * stride_bk - b_up_ptrs += BLOCK_SIZE_K * stride_bk + b_gate_ptr = tl.advance(b_gate_ptr, (BLOCK_SIZE_K, 0)) + b_up_ptr = tl.advance(b_up_ptr, (BLOCK_SIZE_K, 0)) # SiLU(gate) * up if SIMULATE_UNFUSED: @@ -180,7 +194,6 @@ def fused_gate_up_silu_kernel( tl.store(inter_ptrs, inter_fp8, mask=row_mask[:, None]) # Store per-row scale (one per row per N-tile) - # Layout: Inter_s[sorted_idx, pid_n] scale_ptrs = Inter_s + sorted_indices * tl.cdiv(N_inter, BLOCK_SIZE_N) + pid_n tl.store(scale_ptrs, inter_s, mask=row_mask) @@ -190,9 +203,10 @@ def fused_gate_up_silu_kernel( @triton.autotune( configs=[ - triton.Config({}, num_warps=w, num_stages=s) + triton.Config({"GROUP_SIZE_M": g}, num_warps=w, num_stages=s) for w in [2, 4, 8, 16] for s in [2, 3, 4, 5] + for g in [1, 8] ], key=["N_inter", "hidden", "BLOCK_SIZE_M"], ) @@ -210,9 +224,11 @@ def fused_down_proj_kernel( # Expert scheduling Offsets, TileOffsets, + TileToExpert, # (max_M_tiles,) int32 — tile → expert lookup # Shapes N_inter, hidden, + num_M_tiles, # Strides — Inter, W_down, Ws_down, Out stride_im, stride_in, @@ -230,8 +246,8 @@ def fused_down_proj_kernel( BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, NUM_N_TILES: tl.constexpr, - NUM_EXPERTS_BIT_LENGTH: tl.constexpr, SIMULATE_UNFUSED: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, ): """Kernel 2: fp8 intermediate → down_proj → output. @@ -241,21 +257,15 @@ def fused_down_proj_kernel( """ pid_m = tl.program_id(axis=0) pid_h = tl.program_id(axis=1) + num_H_tiles = tl.cdiv(hidden, BLOCK_SIZE_H) + pid_m, pid_h = tl.swizzle2d(pid_m, pid_h, num_M_tiles, num_H_tiles, GROUP_SIZE_M) total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) if pid_m >= total_tiles: return - # Binary search for expert - lo = 0 - hi = NUM_EXPERTS - for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH): - mid = (lo + hi) >> 1 - mid_val = tl.load(TileOffsets + mid) - is_left = mid_val <= pid_m - lo = tl.where(is_left, mid + 1, lo) - hi = tl.where(is_left, hi, mid) - expert_id = lo.to(tl.int64) + # O(1) tile → expert lookup + expert_id = tl.load(TileToExpert + pid_m).to(tl.int64) prev_eid = tl.maximum(expert_id - 1, 0) expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) @@ -271,12 +281,21 @@ def fused_down_proj_kernel( sorted_indices = expert_start + offs_am offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - offs_n = tl.arange(0, BLOCK_SIZE_N) + + # Block pointer for down weights + w_down_ptr = tl.make_block_ptr( + base=W_down + expert_id * stride_be, + shape=(N_inter, hidden), + strides=(stride_bk, stride_bn), + offsets=(0, pid_h * BLOCK_SIZE_H), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_H), + order=(0, 1), + ) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) for n_tile in range(0, NUM_N_TILES): - n_offs = n_tile * BLOCK_SIZE_N + offs_n + n_offs = n_tile * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) # Load fp8 intermediate tile inter_ptrs = ( @@ -288,14 +307,8 @@ def fused_down_proj_kernel( scale_ptrs = Inter_s + sorted_indices * NUM_N_TILES + n_tile inter_s = tl.load(scale_ptrs, mask=row_mask, other=0.0) - # Load down weights - w_down_ptrs = ( - W_down - + expert_id * stride_be - + n_offs[:, None] * stride_bk - + offs_h[None, :] * stride_bn - ) - w_down = tl.load(w_down_ptrs) + # Load down weights via block pointer + w_down = tl.load(w_down_ptr) ws_down = tl.load( Ws_down + expert_id * stride_bs_e @@ -304,6 +317,7 @@ def fused_down_proj_kernel( ) acc += tl.dot(inter_fp8, w_down) * inter_s[:, None] * ws_down + w_down_ptr = tl.advance(w_down_ptr, (BLOCK_SIZE_N, 0)) # Apply routing weights and scatter to original flat order via Perm if SIMULATE_UNFUSED: @@ -365,13 +379,19 @@ def _moe_grouped_fused( ) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) - # Tile setup + # Tile setup — BLOCK_SIZE_M capped at 64 for better SM utilization with many experts BLOCK_SIZE_M = min( - max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 128 + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 64 ) tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32) max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + num_experts + tile_to_expert = torch.bucketize( + torch.arange(max_M_tiles, device=device, dtype=torch.int32), + tile_offsets, + right=True, + ) + num_N_tiles = triton.cdiv(intermediate_dim, block_n) num_H_tiles = triton.cdiv(hidden_dim, block_n) @@ -393,9 +413,11 @@ def _moe_grouped_fused( inter_scales, offsets, tile_offsets, + tile_to_expert, intermediate_dim, hidden_states.shape[1], num_top_k, + max_M_tiles, hidden_states.stride(0), hidden_states.stride(1), gate_up_proj.stride(0), @@ -410,7 +432,6 @@ def _moe_grouped_fused( BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, BLOCK_SIZE_M=BLOCK_SIZE_M, - NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), SIMULATE_UNFUSED=simulate_unfused, ) @@ -428,8 +449,10 @@ def _moe_grouped_fused( proj_out, offsets, tile_offsets, + tile_to_expert, intermediate_dim, hidden_dim, + max_M_tiles, inter_fp8.stride(0), inter_fp8.stride(1), down_proj.stride(0), @@ -445,7 +468,6 @@ def _moe_grouped_fused( BLOCK_SIZE_H=block_n, BLOCK_SIZE_M=BLOCK_SIZE_M, NUM_N_TILES=num_N_tiles, - NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), SIMULATE_UNFUSED=simulate_unfused, ) @@ -487,11 +509,12 @@ def moe_grouped_fused( ) -# ── Batched fused: gate_up + SiLU + down (no sorting, no atomics) ─────────── +# ── Batched fused: two-kernel approach (no sorting, no atomics) ────────────── # -# Each program handles one (token, H-tile) and loops over N-tiles sequentially. -# The intermediate stays entirely in registers. No sorting needed — expert -# lookup is per-token via ExpertIds. +# Same two-kernel architecture as grouped fused but with per-token dispatch: +# Kernel 1: (S, N-tiles) — gate_up + SiLU + FP8 quant → intermediate buffer +# Kernel 2: (S, H-tiles) — fp8 intermediate → down proj → output +# No sorting needed — expert lookup is per-token via ExpertIds. @triton.autotune( @@ -503,142 +526,196 @@ def moe_grouped_fused( key=["N_inter", "K", "BLOCK_SIZE_M"], ) @triton.jit -def moe_batched_fused_kernel( - A, # (S, K) raw BF16/FP16 activations - W_gu, # (E, 2*N_inter, K) FP8 gate_up weights - W_down, # (E, hidden, N_inter) FP8 down weights - Out, # (S, hidden) output - Ws_gu, # gate_up scales - Ws_down, # down scales - ExpertIds, # (S,) expert index per token - SampleWeights, # (S,) routing weights - # Shapes +def batched_gate_up_silu_kernel( + A, + W_gu, + Ws_gu, + Inter, + Inter_s, + ExpertIds, N_inter, K, - hidden, - # Strides — A stride_am, stride_ak, - # Strides — W_gu, Ws_gu - stride_be_gu, - stride_bk_gu, - stride_bn_gu, - stride_bs_e_gu, - stride_bs_k_gu, - stride_bs_n_gu, - # Strides — W_down, Ws_down - stride_be_down, - stride_bk_down, - stride_bn_down, - stride_bs_e_down, - stride_bs_k_down, - stride_bs_n_down, - # Strides — Out + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_im, + stride_in, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Batched kernel 1: per-token gate_up + SiLU + FP8 quant. Grid: (S, N-tiles).""" + batch_id = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + a_ptrs = A + batch_id * stride_am + offs_k[None, :] * stride_ak + + b_gate_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + b_up_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, N_inter + pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e + bs_gate_ptrs = bs_base + pid_n * stride_bs_n + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_raw = tl.load(a_ptrs + offs_m[:, None] * 0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw)) / 448.0 + a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv) + + b_gate = tl.load(b_gate_ptr) + b_up = tl.load(b_up_ptr) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k) + + acc_gate += tl.dot(a, b_gate) * a_s * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptr = tl.advance(b_gate_ptr, (BLOCK_SIZE_K, 0)) + b_up_ptr = tl.advance(b_up_ptr, (BLOCK_SIZE_K, 0)) + + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + inter_s = tl.max(tl.abs(intermediate)) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s, 1e-12)).to(tl.float8e4nv) + + inter_ptrs = ( + Inter + + batch_id * stride_im + + offs_bn[None, :] * stride_in + + offs_m[:, None] * 0 + ) + tl.store(inter_ptrs, inter_fp8) + + num_N_tiles = tl.cdiv(N_inter, BLOCK_SIZE_N) + tl.store(Inter_s + batch_id * num_N_tiles + pid_n, inter_s) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "hidden", "BLOCK_SIZE_M"], +) +@triton.jit +def batched_down_proj_kernel( + Inter, + Inter_s, + W_down, + Ws_down, + ExpertIds, + SampleWeights, + Out, + N_inter, + hidden, + stride_im, + stride_in, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, stride_om, stride_oh, - # Constexprs - NUM_N_TILES: tl.constexpr, - NUM_K_TILES: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + NUM_N_TILES: tl.constexpr, SIMULATE_UNFUSED: tl.constexpr, ): - """Batched fused MoE kernel: gate_up + SiLU + down in one kernel, no atomics. - - Grid: (S, H-tiles). Each program handles one (token, H-tile) and loops - over N-tiles. The intermediate stays entirely in registers. - """ + """Batched kernel 2: fp8 intermediate → down proj → output. Grid: (S, H-tiles).""" batch_id = tl.program_id(axis=0) pid_h = tl.program_id(axis=1) expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) - - offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + + w_down_ptr = tl.make_block_ptr( + base=W_down + expert_id * stride_be, + shape=(N_inter, hidden), + strides=(stride_bk, stride_bn), + offsets=(0, pid_h * BLOCK_SIZE_H), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_H), + order=(0, 1), + ) - acc_down = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) - - for n_inter in range(0, NUM_N_TILES): - offs_n = n_inter * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - # ── Gate + Up projection for this N-tile ── - a_ptrs = A + batch_id * stride_am + offs_k[None, :] * stride_ak - b_base = W_gu + expert_id * stride_be_gu + offs_k[:, None] * stride_bk_gu - b_gate_ptrs = b_base + offs_n[None, :] * stride_bn_gu - b_up_ptrs = b_base + (N_inter + offs_n)[None, :] * stride_bn_gu - - n_scale_blocks = N_inter // BLOCK_SIZE_N - bs_base = Ws_gu + expert_id * stride_bs_e_gu - bs_gate_ptr = bs_base + n_inter * stride_bs_n_gu - bs_up_ptr = bs_base + (n_scale_blocks + n_inter) * stride_bs_n_gu - - acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, NUM_K_TILES): - a_raw = tl.load(a_ptrs + offs_m[:, None] * 0).to(tl.float32) - a_s = tl.max(tl.abs(a_raw)) / 448.0 - a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv) - - b_gate = tl.load(b_gate_ptrs) - b_up = tl.load(b_up_ptrs) - bs_gate = tl.load(bs_gate_ptr + k * stride_bs_k_gu) - bs_up = tl.load(bs_up_ptr + k * stride_bs_k_gu) - - acc_gate += tl.dot(a, b_gate) * a_s * bs_gate[None, :] - acc_up += tl.dot(a, b_up) * a_s * bs_up[None, :] - - a_ptrs += BLOCK_SIZE_K * stride_ak - b_gate_ptrs += BLOCK_SIZE_K * stride_bk_gu - b_up_ptrs += BLOCK_SIZE_K * stride_bk_gu - - # ── SiLU(gate) * up ── - if SIMULATE_UNFUSED: - acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) - acc_up = acc_up.to(tl.bfloat16).to(tl.float32) - intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( - tl.float32 - ) * acc_up - intermediate = intermediate.to(tl.bfloat16).to(tl.float32) - else: - intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up - - # ── Quantize intermediate to FP8 ── - inter_s = tl.max(tl.abs(intermediate)) / 448.0 - inter_fp8 = (intermediate / tl.maximum(inter_s, 1e-12)).to(tl.float8e4nv) - - # ── Partial down projection ── - w_down_ptrs = ( - W_down - + expert_id * stride_be_down - + offs_n[:, None] * stride_bk_down - + offs_h[None, :] * stride_bn_down + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) + + for n_tile in range(0, NUM_N_TILES): + n_offs = n_tile * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + inter_ptrs = ( + Inter + + batch_id * stride_im + + n_offs[None, :] * stride_in + + offs_m[:, None] * 0 ) - w_down = tl.load(w_down_ptrs) + inter_fp8 = tl.load(inter_ptrs) + + inter_s = tl.load(Inter_s + batch_id * NUM_N_TILES + n_tile) + + w_down = tl.load(w_down_ptr) ws_down = tl.load( Ws_down - + expert_id * stride_bs_e_down - + pid_h * stride_bs_n_down - + n_inter * stride_bs_k_down + + expert_id * stride_bs_e + + pid_h * stride_bs_n + + n_tile * stride_bs_k ) - acc_down += tl.dot(inter_fp8, w_down) * inter_s * ws_down + acc += tl.dot(inter_fp8, w_down) * inter_s * ws_down + w_down_ptr = tl.advance(w_down_ptr, (BLOCK_SIZE_N, 0)) - # ── Apply routing weight and store ── if SIMULATE_UNFUSED: - acc_down = acc_down.to(tl.bfloat16).to(tl.float32) + acc = acc.to(tl.bfloat16).to(tl.float32) routing_w = tl.load(SampleWeights + batch_id) - acc_down = acc_down * routing_w + acc = acc * routing_w if Out.dtype.element_ty == tl.bfloat16: - c = acc_down.to(tl.bfloat16) + c = acc.to(tl.bfloat16) elif Out.dtype.element_ty == tl.float16: - c = acc_down.to(tl.float16) + c = acc.to(tl.float16) else: - c = acc_down.to(tl.float32) + c = acc.to(tl.float32) c_ptrs = ( Out + batch_id * stride_om + offs_h[None, :] * stride_oh + offs_m[:, None] * 0 @@ -658,13 +735,9 @@ def _moe_batched_fused( block_size: list[int], simulate_unfused: bool = False, ) -> torch.Tensor: - """Batched fused MoE expert layer: deterministic, no sorting, no atomics. + """Two-kernel batched fused MoE: deterministic, no sorting, no atomics. - Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) - Output: (num_tokens, hidden) — accumulated across top_k experts - - Pipeline: expand → ONE kernel (gate_up + SiLU + down per token) → routing + reduce - Deterministic: each token processed independently. Intermediate stays in registers. + Pipeline: expand → kernel 1 (gate_up + SiLU + FP8 quant) → kernel 2 (down proj + routing) → reduce """ device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -673,7 +746,6 @@ def _moe_batched_fused( intermediate_dim = down_proj.shape[2] block_n, block_k = block_size - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) token_idx = ( torch.arange(num_tokens, device=device) .unsqueeze(1) @@ -682,10 +754,18 @@ def _moe_batched_fused( ) sample_weights = top_k_weights.reshape(-1) expert_ids = top_k_index.reshape(-1) - selected_hidden_states = hidden_states[token_idx] S = expert_ids.size(0) + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + num_H_tiles = triton.cdiv(hidden_dim, block_n) + + inter_fp8 = torch.empty( + S, intermediate_dim, device=device, dtype=torch.float8_e4m3fn + ) + inter_scales = torch.empty(S, num_N_tiles, device=device, dtype=torch.float32) + + # Kernel 1: gate_up + SiLU + FP8 quant — grid (S, N-tiles) BLOCK_SIZE_M = min( max( triton.next_power_of_2( @@ -693,25 +773,19 @@ def _moe_batched_fused( ), 16, ), - 128, + 64, ) - Out = selected_hidden_states.new_empty(S, hidden_dim) - num_H_tiles = triton.cdiv(hidden_dim, block_n) - grid = (S, num_H_tiles) - + grid1 = (S, num_N_tiles) with device_context(device): - wrap_triton(moe_batched_fused_kernel)[grid]( + wrap_triton(batched_gate_up_silu_kernel)[grid1]( selected_hidden_states, gate_up_proj, - down_proj, - Out, gate_up_proj_scale_inv, - down_proj_scale_inv, + inter_fp8, + inter_scales, expert_ids, - sample_weights, intermediate_dim, hidden_states.shape[1], - hidden_dim, selected_hidden_states.stride(0), selected_hidden_states.stride(1), gate_up_proj.stride(0), @@ -720,6 +794,30 @@ def _moe_batched_fused( gate_up_proj_scale_inv.stride(0), gate_up_proj_scale_inv.stride(2), gate_up_proj_scale_inv.stride(1), + inter_fp8.stride(0), + inter_fp8.stride(1), + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_M=BLOCK_SIZE_M, + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Kernel 2: down proj + routing — grid (S, H-tiles) + Out = selected_hidden_states.new_empty(S, hidden_dim) + grid2 = (S, num_H_tiles) + with device_context(device): + wrap_triton(batched_down_proj_kernel)[grid2]( + inter_fp8, + inter_scales, + down_proj, + down_proj_scale_inv, + expert_ids, + sample_weights, + Out, + intermediate_dim, + hidden_dim, + inter_fp8.stride(0), + inter_fp8.stride(1), down_proj.stride(0), down_proj.stride(2), down_proj.stride(1), @@ -728,18 +826,14 @@ def _moe_batched_fused( down_proj_scale_inv.stride(1), Out.stride(0), Out.stride(1), - NUM_N_TILES=triton.cdiv(intermediate_dim, block_n), - NUM_K_TILES=triton.cdiv(hidden_states.shape[1], block_k), BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, BLOCK_SIZE_H=block_n, BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_N_TILES=num_N_tiles, SIMULATE_UNFUSED=simulate_unfused, ) - # Routing weights already applied in kernel — just reduce final_hidden_states = Out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - return final_hidden_states.to(hidden_states.dtype) @@ -754,14 +848,7 @@ def moe_batched_fused( block_size: list[int], simulate_unfused: bool = False, ) -> torch.Tensor: - """Batched fused MoE expert layer: deterministic, no sorting, no atomics. - - Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) - Output: (num_tokens, hidden) — accumulated across top_k experts - - Pipeline: expand → ONE kernel (gate_up + SiLU + down per token) → routing + reduce - Deterministic: each token processed independently. Intermediate stays in registers. - """ + """Two-kernel batched fused MoE: deterministic, no sorting, no atomics.""" return torch.ops.finegrained_fp8.moe_batched_fused( hidden_states, top_k_index,