diff --git a/finegrained-fp8/bench_fused_moe.py b/finegrained-fp8/bench_fused_moe.py new file mode 100644 index 00000000..83c15a51 --- /dev/null +++ b/finegrained-fp8/bench_fused_moe.py @@ -0,0 +1,402 @@ +"""Benchmark MoE dispatch methods: correctness matrix + performance sweep + plot. + +Usage: + python bench_fused_moe.py --grouped # grouped variants + python bench_fused_moe.py --batched # batched variants + python bench_fused_moe.py --all # all variants +""" + +import argparse +import sys + +sys.path.append("torch-ext") + +import torch +import triton +import matplotlib +from rich.table import Table +from rich.console import Console +import matplotlib.pyplot as plt + +from finegrained_fp8 import ( + moe_grouped, + moe_batched, + moe_grouped_fused, + moe_batched_fused, + moe_grouped_atomic, + moe_batched_atomic, +) + +matplotlib.use("Agg") + +FP8_DTYPE = torch.float8_e4m3fn +FP8_MAX = torch.finfo(FP8_DTYPE).max + +console = Console() + +GROUPED_METHODS = { + "grouped": moe_grouped, + "grouped_fused": moe_grouped_fused, + "grouped_atomic": moe_grouped_atomic, +} + +BATCHED_METHODS = { + "batched": moe_batched, + "batched_fused": moe_batched_fused, + "batched_atomic": moe_batched_atomic, +} + +ALL_METHODS = {**GROUPED_METHODS, **BATCHED_METHODS} + + +def quantize_weights_block(W, block_n=128, block_k=128): + E, N, K = W.shape + Wq = torch.empty_like(W, dtype=FP8_DTYPE) + Bs = torch.empty( + E, N // block_n, K // block_k, dtype=torch.float32, device=W.device + ) + for e in range(E): + for ni in range(N // block_n): + for ki in range(K // block_k): + block = W[ + e, + ni * block_n : (ni + 1) * block_n, + ki * block_k : (ki + 1) * block_k, + ] + amax = block.abs().amax().clamp(min=1e-12) + scale = FP8_MAX / amax + Wq[ + e, + ni * block_n : (ni + 1) * block_n, + ki * block_k : (ki + 1) * block_k, + ] = (block * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) + Bs[e, ni, ki] = 1.0 / scale + return Wq, Bs + + +def diff_emoji(d): + if d == 0: + return "✅", "exact", "green" + 
elif d < 1: + return "🟢", f"{d:.4f}", "green" + elif d < 100: + return "🟡", f"{d:.1f}", "yellow" + elif d < 1000: + return "🟠", f"{d:.1f}", "dark_orange" + else: + return "🔴", f"{d:.1f}", "red" + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark MoE dispatch methods") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--grouped", action="store_true", help="Benchmark grouped variants") + group.add_argument("--batched", action="store_true", help="Benchmark batched variants") + group.add_argument("--all", action="store_true", help="Benchmark all variants") + cli_args = parser.parse_args() + + if cli_args.all: + METHODS = ALL_METHODS + variant = "all" + elif cli_args.batched: + METHODS = BATCHED_METHODS + variant = "batched" + else: + METHODS = GROUPED_METHODS + variant = "grouped" + + device = "cuda" + block_size = [128, 128] + + # Qwen3-30B-A3B-Instruct dimensions + model_name = "Qwen3-30B-A3B" + E, N_inter, K, top_k = 128, 768, 2048, 8 + + torch.manual_seed(42) + W_gu = torch.randn(E, 2 * N_inter, K, dtype=torch.float32, device=device) + gate_up_proj, gate_up_proj_scale_inv = quantize_weights_block(W_gu) + W_down = torch.randn(E, K, N_inter, dtype=torch.float32, device=device) + down_proj, down_proj_scale_inv = quantize_weights_block(W_down) + + names = list(METHODS.keys()) + + # ═══════════════════════════════════════════════════════════════════════════ + # Correctness + # ═══════════════════════════════════════════════════════════════════════════ + console.rule("[bold]Correctness[/bold]") + + num_tokens = 32 + hidden = torch.randn(num_tokens, K, device=device, dtype=torch.bfloat16) + top_k_idx = torch.randint(0, E, (num_tokens, top_k), device=device) + top_k_wts = torch.randn(num_tokens, top_k, device=device, dtype=torch.bfloat16) + + args = ( + hidden, + top_k_idx, + top_k_wts, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + ) + + # For correctness, fused methods 
use simulate_unfused=True to match unfused precision + def call_method(name, fn): + if name not in ("grouped", "batched"): + return fn(*args, simulate_unfused=True) + return fn(*args) + + outputs = {} + with torch.no_grad(): + for name, fn in METHODS.items(): + outputs[name] = call_method(name, fn) + + outputs2 = {} + with torch.no_grad(): + for name, fn in METHODS.items(): + outputs2[name] = call_method(name, fn) + + # Self-consistency table + det_table = Table(title="Self-consistency (determinism)") + det_table.add_column("Method", style="cyan") + det_table.add_column("Max diff", justify="right") + det_table.add_column("Status", justify="center") + + for name in names: + d = (outputs[name].float() - outputs2[name].float()).abs().max().item() + if d == 0: + det_table.add_row(name, "0.00", "[green]✅ deterministic[/green]") + else: + det_table.add_row( + name, f"{d:.2f}", f"[yellow]⚠️ non-deterministic ({d:.1f})[/yellow]" + ) + + console.print(det_table) + + # Parity matrix + parity_table = Table(title="Parity matrix (max abs diff)") + parity_table.add_column("", style="cyan") + for name in names: + parity_table.add_column(name[:14], justify="center") + + for n1 in names: + cells = [] + for n2 in names: + d = (outputs[n1].float() - outputs[n2].float()).abs().max().item() + emoji, text, color = diff_emoji(d) + cells.append(f"[{color}]{emoji} {text}[/{color}]") + parity_table.add_row(n1, *cells) + + console.print(parity_table) + + # ═══════════════════════════════════════════════════════════════════════════ + # Benchmark + # ═══════════════════════════════════════════════════════════════════════════ + console.rule( + f"[bold]Benchmark — {model_name} (E={E}, N_inter={N_inter}, K={K}, top_k={top_k})[/bold]" + ) + + def run_sweep(bench_fn_factory, title, mode="eager"): + """Run adaptive token sweep, stopping when TFLOPS plateau (<5% gain).""" + table = Table(title=title) + table.add_column("Tokens", justify="right", style="bold") + for name in names: + 
table.add_column(name[:14], justify="right") + table.add_column("Winner", justify="center", style="bold green") + + all_res = [] + prev_tflops = {} # per-method previous TFLOPS + active = set(names) # methods still progressing + num_tokens = 1 + + while num_tokens <= 131072 and active: # stop when all methods plateau + hidden = torch.randn(num_tokens, K, device=device, dtype=torch.bfloat16) + top_k_idx = torch.randint(0, E, (num_tokens, top_k), device=device) + top_k_wts = torch.randn( + num_tokens, top_k, device=device, dtype=torch.bfloat16 + ) + + args = ( + hidden, + top_k_idx, + top_k_wts, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + ) + results = {} + for name, fn in METHODS.items(): + msg = f" [dim]benching {name} @ {num_tokens} tokens...[/dim]" + console.print(msg + " " * 40, end="\r") + + @torch.no_grad() + def _bench(fn=fn): + return fn(*args) + + try: + results[name] = bench_fn_factory(_bench) + except Exception as e: + console.print(f" [red]{name} @ {num_tokens} tokens: {e}[/red]") + results[name] = float("inf") + + all_res.append((num_tokens, results)) + + # Compute per-method TFLOPS and check progress + S = num_tokens * top_k + flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + valid = {n: ms for n, ms in results.items() if ms != float("inf")} + best_name = min(valid, key=valid.get) if valid else "n/a" + + cells = [] + for name in names: + ms = results[name] + if ms == float("inf"): + cells.append("[red]n/a[/red]") + continue + tflops = flops / (ms * 1e-3) / 1e12 + + # Check if this method is still progressing + if name in prev_tflops and num_tokens >= 1024: + increase = (tflops - prev_tflops[name]) / max( + prev_tflops[name], 1e-12 + ) + if increase < 0.05: + active.discard(name) + + prev_tflops[name] = tflops + + if name not in active: + cells.append(f"[dim]{ms:.3f}[/dim]") + elif name == best_name: + cells.append(f"[bold green]{ms:.3f}[/bold green]") + else: + cells.append(f"{ms:.3f}") + + 
table.add_row( + str(num_tokens), + *cells, + best_name if results[best_name] != float("inf") else "n/a", + ) + + num_tokens *= 2 + + console.print(table) + return all_res + + all_results = run_sweep( + lambda fn: triton.testing.do_bench(fn, warmup=10, rep=50), + "Latency — Eager (ms)", + ) + + all_results_cg = run_sweep( + lambda fn: triton.testing.do_bench_cudagraph(fn), + "Latency — CUDA Graphs (ms)", + ) + + # ═══════════════════════════════════════════════════════════════════════════ + # Save CSV + # ═══════════════════════════════════════════════════════════════════════════ + import csv + + csv_path = f"moe_results_{variant}.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["tokens", "method", "mode", "time_ms", "tflops"]) + for mode, results_list in [ + ("eager", all_results), + ("cudagraph", all_results_cg), + ]: + for num_tokens, results in results_list: + S = num_tokens * top_k + flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + for name, ms in results.items(): + if ms == float("inf"): + continue + tflops = flops / (ms * 1e-3) / 1e12 + writer.writerow( + [num_tokens, name, mode, f"{ms:.3f}", f"{tflops:.2f}"] + ) + console.print(f"Results saved to [bold]{csv_path}[/bold]") + + # ═══════════════════════════════════════════════════════════════════════════ + # Plot + # ═══════════════════════════════════════════════════════════════════════════ + colors = { + "grouped": "#e74c3c", + "grouped_fused": "#2ecc71", + "grouped_atomic": "#3498db", + "batched": "#c0392b", + "batched_fused": "#27ae60", + "batched_atomic": "#2980b9", + } + linestyles = { + "grouped": "--", + "grouped_fused": "-", + "grouped_atomic": ":", + "batched": "--", + "batched_fused": "-", + "batched_atomic": ":", + } + + def compute_tflops(results_list): + data = {} + for name in names: + xs, ys = [], [] + for num_tokens, results in results_list: + S = num_tokens * top_k + flops = 2 * S * K * 2 * N_inter + 2 * S * N_inter * K + ms = 
results.get(name, float("inf")) + if ms == float("inf"): + continue + tflops = flops / (ms * 1e-3) / 1e12 + xs.append(num_tokens) + ys.append(tflops) + data[name] = (xs, ys) + return data + + tflops_eager = compute_tflops(all_results) + tflops_cg = compute_tflops(all_results_cg) + + fig, axes = plt.subplots(1, 2, figsize=(18, 7)) + + for col, (tflops_data, mode) in enumerate( + [(tflops_eager, "Eager"), (tflops_cg, "CUDA Graphs")] + ): + ax = axes[col] + for name in names: + xs, ys = tflops_data[name] + if xs: + ax.plot( + xs, + ys, + linestyles.get(name, "-"), + marker="o", + markersize=4, + color=colors.get(name, "gray"), + label=name, + linewidth=2, + ) + ax.set_xlabel("Tokens", fontsize=12) + ax.set_ylabel("TFLOPS", fontsize=12) + ax.set_title(f"{mode} — TFLOPS", fontsize=13, fontweight="bold") + ax.set_xscale("log", base=2) + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + fig.suptitle( + f"FP8 MoE Expert Dispatch — {model_name} (E={E}, N_inter={N_inter}, K={K}, top_k={top_k})", + fontsize=14, + fontweight="bold", + ) + fig.tight_layout() + plot_path = f"moe_tflops_{variant}.png" + fig.savefig(plot_path, dpi=150) + console.print(f"\nPlot saved to [bold]{plot_path}[/bold]") + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py b/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py index a4171be8..24ffef0a 100644 --- a/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py +++ b/finegrained-fp8/torch-ext/finegrained_fp8/__init__.py @@ -14,6 +14,9 @@ w8a8_block_fp8_matmul, w8a8_tensor_fp8_matmul, ) +from .moe import moe_grouped, moe_batched +from .fused import moe_grouped_fused, moe_batched_fused +from .atomic import moe_grouped_atomic, moe_batched_atomic __all__ = [ "fp8_act_quant", @@ -29,4 +32,13 @@ "w8a8_fp8_matmul_grouped", "w8a8_block_fp8_matmul_grouped", "w8a8_tensor_fp8_matmul_grouped", + # End-to-end MoE (unfused) + "moe_grouped", + "moe_batched", + # End-to-end MoE 
(fused, deterministic)
+    "moe_grouped_fused",
+    "moe_batched_fused",
+    # End-to-end MoE (fused, atomic)
+    "moe_grouped_atomic",
+    "moe_batched_atomic",
 ]
diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py
new file mode 100644
index 00000000..4284d66d
--- /dev/null
+++ b/finegrained-fp8/torch-ext/finegrained_fp8/atomic.py
@@ -0,0 +1,655 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Atomic fused MoE: gate_up + SiLU + down in one pass with atomic split-K.

The intermediate tensor NEVER hits HBM — gate_up, SiLU, FP8 quantization,
and down projection all happen in registers. Output is accumulated via
atomic_add (split-K across intermediate N-tiles).

Non-deterministic due to atomic_add ordering. Faster than the two-kernel
deterministic path at small token counts (decode), but degrades at high
token counts due to atomic contention.
"""

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

from .utils import device_context


@triton.autotune(
    configs=[
        triton.Config({"GROUP_SIZE_M": g}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
        for g in [1, 8]
    ],
    key=["N_inter", "K", "BLOCK_SIZE_M"],
    reset_to_zero=["Out"],
)
@triton.jit
def moe_atomic_kernel(
    # Unsorted inputs
    A,  # (num_tokens, K) raw BF16/FP16 — NOT sorted
    Perm,  # (S,) int64 — sorted_pos → original flat index
    SampleWeights,  # (S,) routing weights in sorted order
    # NOTE(review): SampleWeights is accepted but never read inside this
    # kernel — routing weights are applied on the host after the launch.
    # Expert weights
    W_gu,  # (E, 2*N_inter, K) FP8 gate_up weights
    W_down,  # (E, hidden, N_inter) FP8 down weights
    Ws_gu,  # gate_up scales
    Ws_down,  # down scales
    # Output
    Out,  # (S, hidden) — accumulated via atomic_add
    # Expert scheduling
    Offsets,
    TileOffsets,
    TileToExpert,  # (max_M_tiles,) int32 — tile → expert lookup
    # Shapes
    N_inter,
    K,
    hidden,
    num_top_k,
    num_M_tiles,
    # Strides — A, W_gu, Ws_gu, W_down, Ws_down, Out
    stride_am,
    stride_ak,
    stride_be_gu,
    stride_bk_gu,
    stride_bn_gu,
    stride_bs_e_gu,
    stride_bs_k_gu,
    stride_bs_n_gu,
    stride_be_down,
    stride_bk_down,
    stride_bn_down,
    stride_bs_e_down,
    stride_bs_k_down,
    stride_bs_n_down,
    stride_om,
    stride_oh,
    # Constexprs
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_H_TILES: tl.constexpr,
    SIMULATE_UNFUSED: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Single fused MoE kernel: gather + gate_up + SiLU + down in one pass.

    Grid: (M-tiles, N-tiles). Each program:
      1. Gathers A from unsorted hidden_states via Perm
      2. Gate+up GEMM → SiLU → FP8 quant (intermediate in registers)
      3. Loops over H-tiles: down projection → atomic_add to output
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_N_tiles = tl.cdiv(N_inter, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_M_tiles, num_N_tiles, GROUP_SIZE_M)

    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # O(1) tile → expert lookup
    expert_id = tl.load(TileToExpert + pid_m).to(tl.int64)

    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert
    sorted_indices = expert_start + offs_am

    # Gather from unsorted A via Perm
    perm_vals = tl.load(Perm + sorted_indices, mask=row_mask, other=0)
    original_tokens = perm_vals // num_top_k

    # ── Gate + Up projection — block pointers for weight loads ──
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + original_tokens[:, None] * stride_am + offs_k[None, :] * stride_ak

    b_gate_ptr = tl.make_block_ptr(
        base=W_gu + expert_id * stride_be_gu,
        shape=(K, N_inter * 2),
        strides=(stride_bk_gu, stride_bn_gu),
        offsets=(0, pid_n * BLOCK_SIZE_N),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
        order=(0, 1),
    )
    b_up_ptr = tl.make_block_ptr(
        base=W_gu + expert_id * stride_be_gu,
        shape=(K, N_inter * 2),
        strides=(stride_bk_gu, stride_bn_gu),
        offsets=(0, N_inter + pid_n * BLOCK_SIZE_N),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
        order=(0, 1),
    )

    # Scale indexing assumes N_inter is a multiple of BLOCK_SIZE_N (one scale
    # per N-block); the up-projection scales live n_scale_blocks past gate's.
    n_scale_blocks = N_inter // BLOCK_SIZE_N
    bs_base = Ws_gu + expert_id * stride_bs_e_gu
    bs_gate_ptrs = bs_base + pid_n * stride_bs_n_gu
    bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n_gu

    acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        b_gate = tl.load(b_gate_ptr)
        b_up = tl.load(b_up_ptr)
        bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k_gu)
        bs_up = tl.load(bs_up_ptrs + k * stride_bs_k_gu)

        # Per-row dynamic FP8 quantization of the activations for this K-block.
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)

        acc_gate += tl.dot(a, b_gate) * a_s[:, None] * bs_gate[None, :]
        acc_up += tl.dot(a, b_up) * a_s[:, None] * bs_up[None, :]

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_gate_ptr = tl.advance(b_gate_ptr, (BLOCK_SIZE_K, 0))
        b_up_ptr = tl.advance(b_up_ptr, (BLOCK_SIZE_K, 0))

    # ── SiLU(gate) * up ──
    if SIMULATE_UNFUSED:
        # Round through bf16 at each step to bit-match the unfused pipeline,
        # which materializes bf16 tensors between ops.
        acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32)
        acc_up = acc_up.to(tl.bfloat16).to(tl.float32)
        intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to(
            tl.float32
        ) * acc_up
        intermediate = intermediate.to(tl.bfloat16).to(tl.float32)
    else:
        intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up

    # ── Quantize intermediate to FP8 ──
    inter_s = tl.max(tl.abs(intermediate), axis=1) / 448.0
    inter_fp8 = (intermediate / tl.maximum(inter_s[:, None], 1e-12)).to(tl.float8e4nv)

    # ── Down projection + atomic accumulate — block pointer for weights ──
    w_down_ptr = tl.make_block_ptr(
        base=W_down + expert_id * stride_be_down,
        shape=(N_inter, hidden),
        strides=(stride_bk_down, stride_bn_down),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_H),
        order=(0, 1),
    )

    for h in range(0, NUM_H_TILES):
        h_offs = h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
        w_down = tl.load(w_down_ptr)
        ws_down = tl.load(
            Ws_down
            + expert_id * stride_bs_e_down
            + h * stride_bs_n_down
            + pid_n * stride_bs_k_down
        )

        partial = tl.dot(inter_fp8, w_down) * inter_s[:, None] * ws_down

        out_ptrs = (
            Out + sorted_indices[:, None] * stride_om + h_offs[None, :] * stride_oh
        )
        tl.atomic_add(
            out_ptrs,
            partial.to(Out.dtype.element_ty),
            mask=row_mask[:, None],
            sem="relaxed",
        )
        w_down_ptr = tl.advance(w_down_ptr, (0, BLOCK_SIZE_H))


# ── Wrapper ──────────────────────────────────────────────────────────────────


@triton_op("finegrained_fp8::moe_grouped_atomic", mutates_args=())
def _moe_grouped_atomic(
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
    gate_up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    gate_up_proj_scale_inv: torch.Tensor,
    down_proj_scale_inv: torch.Tensor,
    block_size: list[int],
    simulate_unfused: bool = False,
) -> torch.Tensor:
    """Single-kernel fused MoE expert layer: non-deterministic (atomic split-K).

    Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights)
    Output: (num_tokens, hidden) — accumulated across top_k experts

    Pipeline: sort → ONE kernel (gather + gate_up + SiLU + down, atomic split-K) → routing + unsort + reduce
    Non-deterministic: atomic_add across intermediate N-tiles.
    """
    device = hidden_states.device
    num_top_k = top_k_index.size(-1)
    num_experts = gate_up_proj.size(0)
    num_tokens = hidden_states.size(0)
    hidden_dim = down_proj.size(1)
    intermediate_dim = down_proj.size(2)
    block_n, block_k = block_size

    # S is the number of selected token-expert pairs (S = num_tokens * num_top_k)
    sample_weights = top_k_weights.reshape(-1)  # (S,)
    expert_ids = top_k_index.reshape(-1)  # (S,)
    S = expert_ids.size(0)

    # Sort by expert for grouped processing
    _, perm = expert_ids.sort(stable=True)
    expert_ids_g = expert_ids[perm]
    sample_weights_g = sample_weights[perm]

    # Compute offsets for grouped processing
    # (float input on CPU — presumably because torch.histc does not accept
    #  integer tensors on CPU; TODO confirm against the torch version in use)
    histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int()
    tokens_per_expert = torch.histc(
        histc_input, bins=num_experts, min=0, max=num_experts - 1
    )
    offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)

    # Tile setup — BLOCK_SIZE_M capped at 64 for better SM utilization with many experts
    BLOCK_SIZE_M = min(
        max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 64
    )
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + num_experts

    # Tile-to-expert lookup via bucketize (CUDA-graph safe, replaces binary search)
    # right=True: tile t maps to the first expert whose cumulative tile count
    # strictly exceeds t.
    tile_to_expert = torch.bucketize(
        torch.arange(max_M_tiles, device=device, dtype=torch.int32),
        tile_offsets,
        right=True,
    )

    num_N_tiles = triton.cdiv(intermediate_dim, block_n)

    # fp32 output for atomic accumulation
    # (presumably fp32 to keep atomic adds precise across split-K partials —
    #  reduced-precision atomics would accumulate rounding error)
    proj_out = torch.zeros(S, hidden_dim, device=device, dtype=torch.float32)

    grid = (max_M_tiles, num_N_tiles)
    with device_context(device):
        wrap_triton(moe_atomic_kernel)[grid](
            hidden_states,
            perm,
            sample_weights_g,
            gate_up_proj,
            down_proj,
            gate_up_proj_scale_inv,
            down_proj_scale_inv,
            proj_out,
            offsets,
            tile_offsets,
            tile_to_expert,
            intermediate_dim,
            hidden_states.shape[1],
            hidden_dim,
            num_top_k,
            max_M_tiles,
            hidden_states.stride(0),
            hidden_states.stride(1),
            gate_up_proj.stride(0),
            gate_up_proj.stride(2),
            gate_up_proj.stride(1),
            gate_up_proj_scale_inv.stride(0),
            gate_up_proj_scale_inv.stride(2),
            gate_up_proj_scale_inv.stride(1),
            down_proj.stride(0),
            down_proj.stride(2),
            down_proj.stride(1),
            down_proj_scale_inv.stride(0),
            down_proj_scale_inv.stride(2),
            down_proj_scale_inv.stride(1),
            proj_out.stride(0),
            proj_out.stride(1),
            NUM_EXPERTS=num_experts,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_H=block_n,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_H_TILES=triton.cdiv(hidden_dim, block_n),
            SIMULATE_UNFUSED=simulate_unfused,
        )

    # Apply routing weights + unsort + reduce
    proj_out = proj_out.to(hidden_states.dtype)
    weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1)
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(S, device=device)
    final_hidden_states = (
        weighted_out[inv_perm].view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
    )

    return final_hidden_states.to(hidden_states.dtype)


def moe_grouped_atomic(
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
    gate_up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    gate_up_proj_scale_inv: torch.Tensor,
    down_proj_scale_inv: torch.Tensor,
    block_size: list[int],
    simulate_unfused: bool = False,
) -> torch.Tensor:
    """Single-kernel fused MoE expert layer: non-deterministic (atomic split-K).

    Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights)
    Output: (num_tokens, hidden) — accumulated across top_k experts

    Pipeline: sort → ONE kernel (gather + gate_up + SiLU + down, atomic split-K) → routing + unsort + reduce
    Non-deterministic: atomic_add across intermediate N-tiles. Intermediate stays in registers.
+ """ + return torch.ops.finegrained_fp8.moe_grouped_atomic( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) + + +# ── Batched atomic: gate_up + SiLU + down per token (atomic split-K) ──────── + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], + reset_to_zero=["Out"], +) +@triton.jit +def moe_batched_atomic_kernel( + A, # (S, K) raw BF16/FP16 activations + W_gu, # (E, 2*N_inter, K) FP8 gate_up weights + W_down, # (E, hidden, N_inter) FP8 down weights + Out, # (S, hidden) output — accumulated via atomic_add + Ws_gu, # gate_up scales + Ws_down, # down scales + ExpertIds, # (S,) expert index per token + SampleWeights, # (S,) routing weights + # Shapes + N_inter, + K, + hidden, + # Strides — A + stride_am, + stride_ak, + # Strides — W_gu, Ws_gu + stride_be_gu, + stride_bk_gu, + stride_bn_gu, + stride_bs_e_gu, + stride_bs_k_gu, + stride_bs_n_gu, + # Strides — W_down, Ws_down + stride_be_down, + stride_bk_down, + stride_bn_down, + stride_bs_e_down, + stride_bs_k_down, + stride_bs_n_down, + # Strides — Out + stride_om, + stride_oh, + # Constexprs + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_H_TILES: tl.constexpr, + NUM_EXPERTS_BIT_LENGTH: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Batched atomic fused MoE kernel: gate_up + SiLU + down per token. + + Grid: (S, N-tiles). Each program handles one (token, N-tile): + 1. Gate+up GEMM → SiLU → FP8 quant (intermediate in registers) + 2. 
Loops over H-tiles: down projection → atomic_add to output + """ + batch_id = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) + offs_m = tl.arange(0, BLOCK_SIZE_M) + + # ── Gate + Up projection ── + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + batch_id * stride_am + offs_k[None, :] * stride_ak + b_base = W_gu + expert_id * stride_be_gu + offs_k[:, None] * stride_bk_gu + b_gate_ptrs = b_base + offs_bn[None, :] * stride_bn_gu + b_up_ptrs = b_base + (N_inter + offs_bn)[None, :] * stride_bn_gu + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e_gu + bs_gate_ptrs = bs_base + pid_n * stride_bs_n_gu + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n_gu + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_raw = tl.load(a_ptrs + offs_m[:, None] * 0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw)) / 448.0 + a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv) + + b_gate = tl.load(b_gate_ptrs) + b_up = tl.load(b_up_ptrs) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k_gu) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k_gu) + + acc_gate += tl.dot(a, b_gate) * a_s * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptrs += BLOCK_SIZE_K * stride_bk_gu + b_up_ptrs += BLOCK_SIZE_K * stride_bk_gu + + # ── SiLU(gate) * up ── + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + # ── Quantize intermediate 
to FP8 ── + inter_s = tl.max(tl.abs(intermediate)) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s, 1e-12)).to(tl.float8e4nv) + + # ── Down projection + atomic accumulate ── + offs_h = tl.arange(0, BLOCK_SIZE_H) + for h in range(0, NUM_H_TILES): + h_offs = h * BLOCK_SIZE_H + offs_h + w_down_ptrs = ( + W_down + + expert_id * stride_be_down + + offs_bn[:, None] * stride_bk_down + + h_offs[None, :] * stride_bn_down + ) + w_down = tl.load(w_down_ptrs) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e_down + + h * stride_bs_n_down + + pid_n * stride_bs_k_down + ) + + partial = tl.dot(inter_fp8, w_down) * inter_s * ws_down + + out_ptrs = ( + Out + + batch_id * stride_om + + h_offs[None, :] * stride_oh + + offs_m[:, None] * 0 + ) + # Only write row 0 — all rows are identical (batched: 1 token per program) + row_mask = offs_m[:, None] == 0 + tl.atomic_add( + out_ptrs, partial.to(Out.dtype.element_ty), mask=row_mask, sem="relaxed" + ) + + +@triton_op("finegrained_fp8::moe_batched_atomic", mutates_args=()) +def _moe_batched_atomic( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Batched atomic fused MoE expert layer: non-deterministic (atomic split-K). + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → ONE kernel (gate_up + SiLU + down per token, atomic split-K) → routing + reduce + Non-deterministic: atomic_add across intermediate N-tiles. Intermediate stays in registers. 
+ """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.shape[1] + intermediate_dim = down_proj.shape[2] + block_n, block_k = block_size + num_experts = gate_up_proj.shape[0] + + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) + sample_weights = top_k_weights.reshape(-1) + expert_ids = top_k_index.reshape(-1) + + selected_hidden_states = hidden_states[token_idx] + S = expert_ids.size(0) + + BLOCK_SIZE_M = min( + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 64 + ) + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + + Out = torch.zeros(S, hidden_dim, device=device, dtype=torch.float32) + grid = (S, num_N_tiles) + + with device_context(device): + wrap_triton(moe_batched_atomic_kernel)[grid]( + selected_hidden_states, + gate_up_proj, + down_proj, + Out, + gate_up_proj_scale_inv, + down_proj_scale_inv, + expert_ids, + sample_weights, + intermediate_dim, + hidden_states.shape[1], + hidden_dim, + selected_hidden_states.stride(0), + selected_hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + Out.stride(0), + Out.stride(1), + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_H_TILES=triton.cdiv(hidden_dim, block_n), + NUM_EXPERTS_BIT_LENGTH=num_experts.bit_length(), + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Apply routing weights + reduce + Out = Out.to(hidden_states.dtype) + weighted_out = Out * sample_weights.to(Out.dtype).unsqueeze(-1) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, 
hidden_dim).sum( + dim=1 + ) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_batched_atomic( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Batched atomic fused MoE expert layer: non-deterministic (atomic split-K). + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → ONE kernel (gate_up + SiLU + down per token, atomic split-K) → routing + reduce + Non-deterministic: atomic_add across intermediate N-tiles. Intermediate stays in registers. + """ + return torch.ops.finegrained_fp8.moe_batched_atomic( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/fused.py b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py new file mode 100644 index 00000000..0ba1da9f --- /dev/null +++ b/finegrained-fp8/torch-ext/finegrained_fp8/fused.py @@ -0,0 +1,862 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused MoE: two-kernel deterministic approach. 
+ +Kernel 1 (M×N grid): gather + gate_up + SiLU + FP8 quant → fp8 intermediate buffer +Kernel 2 (M×H grid): fp8 intermediate → down projection → output + +Both kernels are deterministic (no atomic_add). The fp8 intermediate buffer +is ~half the size of the bf16 intermediate in the unfused path. + +The full pipeline: + 1. histc + cumsum (expert offsets) + 2. sort (expert grouping) + 3. Kernel 1: gather → gate_up → SiLU → FP8 quant → (S, N_inter) fp8 + scales + 4. Kernel 2: fp8 intermediate → down_proj → (S, hidden) + 5. routing weights + unsort + top_k reduce +""" + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +from .utils import device_context + + +# ── Kernel 1: gather + gate_up + SiLU + FP8 quant ────────────────────────── +# +# Optimizations: block pointers for weight loads, L2 swizzle (GROUP_SIZE_M), +# tile-to-expert via bucketize (replaces binary search), BLOCK_SIZE_M capped at 64. + + +@triton.autotune( + configs=[ + triton.Config({"GROUP_SIZE_M": g}, num_warps=w, num_stages=s) + for w in [2, 4, 8, 16] + for s in [2, 3, 4, 5] + for g in [1, 8] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], +) +@triton.jit +def fused_gate_up_silu_kernel( + # Unsorted inputs + A, # (num_tokens, K) raw BF16/FP16 — NOT sorted + Perm, # (S,) int64 — sorted_pos → original flat index + # Expert weights + W_gu, # (E, 2*N_inter, K) FP8 gate_up weights + Ws_gu, # gate_up scales + # Outputs + Inter, # (S, N_inter) FP8 intermediate + Inter_s, # (S,) fp32 per-row scales + # Expert scheduling + Offsets, + TileOffsets, + TileToExpert, # (max_M_tiles,) int32 — tile → expert lookup + # Shapes + N_inter, + K, + num_top_k, + num_M_tiles, + # Strides — A, W_gu, Ws_gu, Inter + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_im, + stride_in, + # Constexprs + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: 
tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Kernel 1: gather A from unsorted → gate_up GEMM → SiLU → FP8 quant. + + Grid: (M-tiles, N-tiles). Each program writes one (BLOCK_M, BLOCK_N) + tile of the fp8 intermediate + per-row scales. No atomics. + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + num_N_tiles = tl.cdiv(N_inter, BLOCK_SIZE_N) + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_M_tiles, num_N_tiles, GROUP_SIZE_M) + + total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) + if pid_m >= total_tiles: + return + + # O(1) tile → expert lookup + expert_id = tl.load(TileToExpert + pid_m).to(tl.int64) + + prev_eid = tl.maximum(expert_id - 1, 0) + expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) + expert_end = tl.load(Offsets + expert_id) + M_expert = expert_end - expert_start + + expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid)) + local_tile = pid_m - expert_tile_start + m_off = local_tile * BLOCK_SIZE_M + + offs_am = m_off + tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_am < M_expert + sorted_indices = expert_start + offs_am + + # Gather from unsorted A via Perm + perm_vals = tl.load(Perm + sorted_indices, mask=row_mask, other=0) + original_tokens = perm_vals // num_top_k + + # Gate + Up projection + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + original_tokens[:, None] * stride_am + offs_k[None, :] * stride_ak + + b_gate_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + b_up_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, N_inter + pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + + 
n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e + bs_gate_ptrs = bs_base + pid_n * stride_bs_n + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + b_gate = tl.load(b_gate_ptr) + b_up = tl.load(b_up_ptr) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k) + + a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0 + a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv) + + acc_gate += tl.dot(a, b_gate) * a_s[:, None] * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s[:, None] * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptr = tl.advance(b_gate_ptr, (BLOCK_SIZE_K, 0)) + b_up_ptr = tl.advance(b_up_ptr, (BLOCK_SIZE_K, 0)) + + # SiLU(gate) * up + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + # FP8 quantize — per-row scale across this N-tile + inter_s = tl.max(tl.abs(intermediate), axis=1) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s[:, None], 1e-12)).to(tl.float8e4nv) + + # Store fp8 intermediate tile + inter_ptrs = ( + Inter + sorted_indices[:, None] * stride_im + offs_bn[None, :] * stride_in + ) + tl.store(inter_ptrs, inter_fp8, mask=row_mask[:, None]) + + # Store per-row scale (one per row per N-tile) + scale_ptrs = Inter_s + sorted_indices * tl.cdiv(N_inter, BLOCK_SIZE_N) + pid_n + tl.store(scale_ptrs, inter_s, mask=row_mask) + + +# ── Kernel 2: down projection from fp8 
intermediate ───────────────────────── + + +@triton.autotune( + configs=[ + triton.Config({"GROUP_SIZE_M": g}, num_warps=w, num_stages=s) + for w in [2, 4, 8, 16] + for s in [2, 3, 4, 5] + for g in [1, 8] + ], + key=["N_inter", "hidden", "BLOCK_SIZE_M"], +) +@triton.jit +def fused_down_proj_kernel( + # Inputs + Inter, # (S, N_inter) FP8 intermediate + Inter_s, # (S, num_N_tiles) fp32 per-row-per-N-tile scales + W_down, # (E, hidden, N_inter) FP8 down weights + Ws_down, # down scales + SampleWeights, # (S,) routing weights in sorted order + Perm, # (S,) int64 — sorted_pos → original flat index + # Output + Out, # (num_tokens * top_k, hidden) output in original flat order + # Expert scheduling + Offsets, + TileOffsets, + TileToExpert, # (max_M_tiles,) int32 — tile → expert lookup + # Shapes + N_inter, + hidden, + num_M_tiles, + # Strides — Inter, W_down, Ws_down, Out + stride_im, + stride_in, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_om, + stride_oh, + # Constexprs + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_N_TILES: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Kernel 2: fp8 intermediate → down_proj → output. + + Grid: (M-tiles, H-tiles). Each program reads (BLOCK_M, N_inter) fp8 + intermediate, does the down projection tiled over N, stores (BLOCK_M, BLOCK_H). + No atomics — deterministic. 
+ """ + pid_m = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + num_H_tiles = tl.cdiv(hidden, BLOCK_SIZE_H) + pid_m, pid_h = tl.swizzle2d(pid_m, pid_h, num_M_tiles, num_H_tiles, GROUP_SIZE_M) + + total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1) + if pid_m >= total_tiles: + return + + # O(1) tile → expert lookup + expert_id = tl.load(TileToExpert + pid_m).to(tl.int64) + + prev_eid = tl.maximum(expert_id - 1, 0) + expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid)) + expert_end = tl.load(Offsets + expert_id) + M_expert = expert_end - expert_start + + expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid)) + local_tile = pid_m - expert_tile_start + m_off = local_tile * BLOCK_SIZE_M + + offs_am = m_off + tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_am < M_expert + sorted_indices = expert_start + offs_am + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + + # Block pointer for down weights + w_down_ptr = tl.make_block_ptr( + base=W_down + expert_id * stride_be, + shape=(N_inter, hidden), + strides=(stride_bk, stride_bn), + offsets=(0, pid_h * BLOCK_SIZE_H), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_H), + order=(0, 1), + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) + + for n_tile in range(0, NUM_N_TILES): + n_offs = n_tile * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # Load fp8 intermediate tile + inter_ptrs = ( + Inter + sorted_indices[:, None] * stride_im + n_offs[None, :] * stride_in + ) + inter_fp8 = tl.load(inter_ptrs, mask=row_mask[:, None], other=0.0) + + # Load per-row scale for this N-tile + scale_ptrs = Inter_s + sorted_indices * NUM_N_TILES + n_tile + inter_s = tl.load(scale_ptrs, mask=row_mask, other=0.0) + + # Load down weights via block pointer + w_down = tl.load(w_down_ptr) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e + + pid_h * stride_bs_n + + n_tile * stride_bs_k + ) + + acc += tl.dot(inter_fp8, w_down) * inter_s[:, None] * ws_down + w_down_ptr = 
tl.advance(w_down_ptr, (BLOCK_SIZE_N, 0)) + + # Apply routing weights and scatter to original flat order via Perm + if SIMULATE_UNFUSED: + acc = acc.to(tl.bfloat16).to(tl.float32) + routing_w = tl.load(SampleWeights + sorted_indices, mask=row_mask, other=0.0) + acc = acc * routing_w[:, None] + original_flat_idx = tl.load(Perm + sorted_indices, mask=row_mask, other=0) + out_ptrs = ( + Out + original_flat_idx[:, None] * stride_om + offs_h[None, :] * stride_oh + ) + tl.store(out_ptrs, acc.to(Out.dtype.element_ty), mask=row_mask[:, None]) + + +# ── Wrapper ────────────────────────────────────────────────────────────────── + + +@triton_op("finegrained_fp8::moe_grouped_fused", mutates_args=()) +def _moe_grouped_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Two-kernel fused MoE expert layer: deterministic, no atomics. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → fused_gate_up_silu_kernel → fused_down_proj_kernel → routing + unsort + reduce + Deterministic: no atomic_add, intermediate goes through HBM as fp8. 
+ """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_experts = gate_up_proj.size(0) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.size(1) + intermediate_dim = down_proj.size(2) + block_n, block_k = block_size + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + S = expert_ids.size(0) + + # Sort by expert for grouped processing + _, perm = expert_ids.sort(stable=True) + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + + # Compute offsets for grouped processing + histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() + tokens_per_expert = torch.histc( + histc_input, bins=num_experts, min=0, max=num_experts - 1 + ) + offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + + # Tile setup — BLOCK_SIZE_M capped at 64 for better SM utilization with many experts + BLOCK_SIZE_M = min( + max(triton.next_power_of_2((S + num_experts - 1) // num_experts), 16), 64 + ) + tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32) + max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + num_experts + tile_to_expert = torch.bucketize( + torch.arange(max_M_tiles, device=device, dtype=torch.int32), + tile_offsets, + right=True, + ) + + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + num_H_tiles = triton.cdiv(hidden_dim, block_n) + + # Temp buffer: fp8 intermediate + per-row-per-N-tile scales + inter_fp8 = torch.empty( + S, intermediate_dim, device=device, dtype=torch.float8_e4m3fn + ) + inter_scales = torch.empty(S, num_N_tiles, device=device, dtype=torch.float32) + + # --- Kernel 1: gate_up + SiLU + FP8 quant --- + grid1 = (max_M_tiles, num_N_tiles) + with device_context(device): + wrap_triton(fused_gate_up_silu_kernel)[grid1]( + hidden_states, + perm, + 
gate_up_proj, + gate_up_proj_scale_inv, + inter_fp8, + inter_scales, + offsets, + tile_offsets, + tile_to_expert, + intermediate_dim, + hidden_states.shape[1], + num_top_k, + max_M_tiles, + hidden_states.stride(0), + hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + inter_fp8.stride(0), + inter_fp8.stride(1), + NUM_EXPERTS=num_experts, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_M=BLOCK_SIZE_M, + SIMULATE_UNFUSED=simulate_unfused, + ) + + # --- Kernel 2: down projection + routing weights + scatter to original order --- + proj_out = torch.empty(S, hidden_dim, device=device, dtype=hidden_states.dtype) + grid2 = (max_M_tiles, num_H_tiles) + with device_context(device): + wrap_triton(fused_down_proj_kernel)[grid2]( + inter_fp8, + inter_scales, + down_proj, + down_proj_scale_inv, + sample_weights_g, + perm, + proj_out, + offsets, + tile_offsets, + tile_to_expert, + intermediate_dim, + hidden_dim, + max_M_tiles, + inter_fp8.stride(0), + inter_fp8.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + proj_out.stride(0), + proj_out.stride(1), + NUM_EXPERTS=num_experts, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_N_TILES=num_N_tiles, + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Output already in original flat order — just reduce across top_k + final_hidden_states = proj_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_grouped_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + 
block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Two-kernel fused MoE expert layer: deterministic, no atomics. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → fused_gate_up_silu_kernel → fused_down_proj_kernel → routing + unsort + reduce + Deterministic: no atomic_add, intermediate goes through HBM as fp8. + """ + return torch.ops.finegrained_fp8.moe_grouped_fused( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) + + +# ── Batched fused: two-kernel approach (no sorting, no atomics) ────────────── +# +# Same two-kernel architecture as grouped fused but with per-token dispatch: +# Kernel 1: (S, N-tiles) — gate_up + SiLU + FP8 quant → intermediate buffer +# Kernel 2: (S, H-tiles) — fp8 intermediate → down proj → output +# No sorting needed — expert lookup is per-token via ExpertIds. + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "K", "BLOCK_SIZE_M"], +) +@triton.jit +def batched_gate_up_silu_kernel( + A, + W_gu, + Ws_gu, + Inter, + Inter_s, + ExpertIds, + N_inter, + K, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_im, + stride_in, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Batched kernel 1: per-token gate_up + SiLU + FP8 quant. 
Grid: (S, N-tiles).""" + batch_id = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + a_ptrs = A + batch_id * stride_am + offs_k[None, :] * stride_ak + + b_gate_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + b_up_ptr = tl.make_block_ptr( + base=W_gu + expert_id * stride_be, + shape=(K, N_inter * 2), + strides=(stride_bk, stride_bn), + offsets=(0, N_inter + pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + + n_scale_blocks = N_inter // BLOCK_SIZE_N + bs_base = Ws_gu + expert_id * stride_bs_e + bs_gate_ptrs = bs_base + pid_n * stride_bs_n + bs_up_ptrs = bs_base + (n_scale_blocks + pid_n) * stride_bs_n + + acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_raw = tl.load(a_ptrs + offs_m[:, None] * 0).to(tl.float32) + a_s = tl.max(tl.abs(a_raw)) / 448.0 + a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv) + + b_gate = tl.load(b_gate_ptr) + b_up = tl.load(b_up_ptr) + bs_gate = tl.load(bs_gate_ptrs + k * stride_bs_k) + bs_up = tl.load(bs_up_ptrs + k * stride_bs_k) + + acc_gate += tl.dot(a, b_gate) * a_s * bs_gate[None, :] + acc_up += tl.dot(a, b_up) * a_s * bs_up[None, :] + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_gate_ptr = tl.advance(b_gate_ptr, (BLOCK_SIZE_K, 0)) + b_up_ptr = tl.advance(b_up_ptr, (BLOCK_SIZE_K, 0)) + + if SIMULATE_UNFUSED: + acc_gate = acc_gate.to(tl.bfloat16).to(tl.float32) + acc_up = acc_up.to(tl.bfloat16).to(tl.float32) + intermediate = (acc_gate * tl.sigmoid(acc_gate)).to(tl.bfloat16).to( + 
tl.float32 + ) * acc_up + intermediate = intermediate.to(tl.bfloat16).to(tl.float32) + else: + intermediate = acc_gate * tl.sigmoid(acc_gate) * acc_up + + inter_s = tl.max(tl.abs(intermediate)) / 448.0 + inter_fp8 = (intermediate / tl.maximum(inter_s, 1e-12)).to(tl.float8e4nv) + + inter_ptrs = ( + Inter + + batch_id * stride_im + + offs_bn[None, :] * stride_in + + offs_m[:, None] * 0 + ) + tl.store(inter_ptrs, inter_fp8) + + num_N_tiles = tl.cdiv(N_inter, BLOCK_SIZE_N) + tl.store(Inter_s + batch_id * num_N_tiles + pid_n, inter_s) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=w, num_stages=s) + for w in [4, 8, 16] + for s in [2, 3, 4] + ], + key=["N_inter", "hidden", "BLOCK_SIZE_M"], +) +@triton.jit +def batched_down_proj_kernel( + Inter, + Inter_s, + W_down, + Ws_down, + ExpertIds, + SampleWeights, + Out, + N_inter, + hidden, + stride_im, + stride_in, + stride_be, + stride_bk, + stride_bn, + stride_bs_e, + stride_bs_k, + stride_bs_n, + stride_om, + stride_oh, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + NUM_N_TILES: tl.constexpr, + SIMULATE_UNFUSED: tl.constexpr, +): + """Batched kernel 2: fp8 intermediate → down proj → output. 
Grid: (S, H-tiles).""" + batch_id = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + expert_id = tl.load(ExpertIds + batch_id).to(tl.int64) + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + + w_down_ptr = tl.make_block_ptr( + base=W_down + expert_id * stride_be, + shape=(N_inter, hidden), + strides=(stride_bk, stride_bn), + offsets=(0, pid_h * BLOCK_SIZE_H), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_H), + order=(0, 1), + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_H), dtype=tl.float32) + + for n_tile in range(0, NUM_N_TILES): + n_offs = n_tile * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + inter_ptrs = ( + Inter + + batch_id * stride_im + + n_offs[None, :] * stride_in + + offs_m[:, None] * 0 + ) + inter_fp8 = tl.load(inter_ptrs) + + inter_s = tl.load(Inter_s + batch_id * NUM_N_TILES + n_tile) + + w_down = tl.load(w_down_ptr) + ws_down = tl.load( + Ws_down + + expert_id * stride_bs_e + + pid_h * stride_bs_n + + n_tile * stride_bs_k + ) + + acc += tl.dot(inter_fp8, w_down) * inter_s * ws_down + w_down_ptr = tl.advance(w_down_ptr, (BLOCK_SIZE_N, 0)) + + if SIMULATE_UNFUSED: + acc = acc.to(tl.bfloat16).to(tl.float32) + routing_w = tl.load(SampleWeights + batch_id) + acc = acc * routing_w + + if Out.dtype.element_ty == tl.bfloat16: + c = acc.to(tl.bfloat16) + elif Out.dtype.element_ty == tl.float16: + c = acc.to(tl.float16) + else: + c = acc.to(tl.float32) + + c_ptrs = ( + Out + batch_id * stride_om + offs_h[None, :] * stride_oh + offs_m[:, None] * 0 + ) + tl.store(c_ptrs, c) + + +@triton_op("finegrained_fp8::moe_batched_fused", mutates_args=()) +def _moe_batched_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Two-kernel batched fused MoE: 
deterministic, no sorting, no atomics. + + Pipeline: expand → kernel 1 (gate_up + SiLU + FP8 quant) → kernel 2 (down proj + routing) → reduce + """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = down_proj.shape[1] + intermediate_dim = down_proj.shape[2] + block_n, block_k = block_size + + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) + sample_weights = top_k_weights.reshape(-1) + expert_ids = top_k_index.reshape(-1) + selected_hidden_states = hidden_states[token_idx] + S = expert_ids.size(0) + + num_N_tiles = triton.cdiv(intermediate_dim, block_n) + num_H_tiles = triton.cdiv(hidden_dim, block_n) + + inter_fp8 = torch.empty( + S, intermediate_dim, device=device, dtype=torch.float8_e4m3fn + ) + inter_scales = torch.empty(S, num_N_tiles, device=device, dtype=torch.float32) + + # Kernel 1: gate_up + SiLU + FP8 quant — grid (S, N-tiles) + BLOCK_SIZE_M = min( + max( + triton.next_power_of_2( + (S + gate_up_proj.shape[0] - 1) // gate_up_proj.shape[0] + ), + 16, + ), + 64, + ) + grid1 = (S, num_N_tiles) + with device_context(device): + wrap_triton(batched_gate_up_silu_kernel)[grid1]( + selected_hidden_states, + gate_up_proj, + gate_up_proj_scale_inv, + inter_fp8, + inter_scales, + expert_ids, + intermediate_dim, + hidden_states.shape[1], + selected_hidden_states.stride(0), + selected_hidden_states.stride(1), + gate_up_proj.stride(0), + gate_up_proj.stride(2), + gate_up_proj.stride(1), + gate_up_proj_scale_inv.stride(0), + gate_up_proj_scale_inv.stride(2), + gate_up_proj_scale_inv.stride(1), + inter_fp8.stride(0), + inter_fp8.stride(1), + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + BLOCK_SIZE_M=BLOCK_SIZE_M, + SIMULATE_UNFUSED=simulate_unfused, + ) + + # Kernel 2: down proj + routing — grid (S, H-tiles) + Out = selected_hidden_states.new_empty(S, hidden_dim) + grid2 = (S, num_H_tiles) + with device_context(device): + 
wrap_triton(batched_down_proj_kernel)[grid2]( + inter_fp8, + inter_scales, + down_proj, + down_proj_scale_inv, + expert_ids, + sample_weights, + Out, + intermediate_dim, + hidden_dim, + inter_fp8.stride(0), + inter_fp8.stride(1), + down_proj.stride(0), + down_proj.stride(2), + down_proj.stride(1), + down_proj_scale_inv.stride(0), + down_proj_scale_inv.stride(2), + down_proj_scale_inv.stride(1), + Out.stride(0), + Out.stride(1), + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_H=block_n, + BLOCK_SIZE_M=BLOCK_SIZE_M, + NUM_N_TILES=num_N_tiles, + SIMULATE_UNFUSED=simulate_unfused, + ) + + final_hidden_states = Out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + return final_hidden_states.to(hidden_states.dtype) + + +def moe_batched_fused( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], + simulate_unfused: bool = False, +) -> torch.Tensor: + """Two-kernel batched fused MoE: deterministic, no sorting, no atomics.""" + return torch.ops.finegrained_fp8.moe_batched_fused( + hidden_states, + top_k_index, + top_k_weights, + gate_up_proj, + down_proj, + gate_up_proj_scale_inv, + down_proj_scale_inv, + block_size, + simulate_unfused, + ) diff --git a/finegrained-fp8/torch-ext/finegrained_fp8/moe.py b/finegrained-fp8/torch-ext/finegrained_fp8/moe.py new file mode 100644 index 00000000..aa53f000 --- /dev/null +++ b/finegrained-fp8/torch-ext/finegrained_fp8/moe.py @@ -0,0 +1,189 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end MoE expert layer using unfused grouped GEMM primitives. + +Pipeline: + 1. Sort tokens by expert (histc + cumsum + argsort) + 2. Gate+up grouped GEMM → bf16 intermediate + 3. SiLU(gate) * up → bf16 intermediate + 4. Down grouped GEMM → bf16 output + 5. Apply routing weights + unsort + top_k reduce + +All operations are deterministic (no atomic_add). This serves as the +baseline implementation; see fused.py for the fused variant that +merges steps 2-3 into a single kernel. +""" + +import torch +import torch.nn.functional as F + +from .batched import w8a8_block_fp8_matmul_batched +from .grouped import w8a8_block_fp8_matmul_grouped + + +def moe_grouped( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], +) -> torch.Tensor: + """End-to-end MoE expert layer using unfused grouped GEMM primitives. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: sort → gate_up GEMM → SiLU → down GEMM → routing + unsort + reduce + Deterministic: no atomic_add, all operations have fixed execution order. 
+ """ + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_experts = gate_up_proj.size(0) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + selected_hidden_states_g = hidden_states[token_idx[perm]] + + # Compute offsets for grouped processing. + # histc instead of bincount avoids cuda-graph issues; + # CPU requires float input, CUDA requires int input (deterministic mode). + histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() + tokens_per_expert = torch.histc( + histc_input, bins=num_experts, min=0, max=num_experts - 1 + ) + offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + + # --- Gate+up projection per expert (FP8 grouped) --- + proj_out = w8a8_block_fp8_matmul_grouped( + selected_hidden_states_g, + gate_up_proj, + gate_up_proj_scale_inv, + offsets, + tokens_per_expert, + block_size, + ) # (S, 2 * intermediate_dim) + + # Apply SiLU gating + gate, up = proj_out.chunk(2, dim=-1) + proj_out = F.silu(gate) * up # (S, intermediate_dim) + + # --- Down projection per expert (FP8 grouped) --- + proj_out = w8a8_block_fp8_matmul_grouped( + proj_out, down_proj, down_proj_scale_inv, offsets, tokens_per_expert, block_size + ) # (S, hidden_dim) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze( + -1 + ) # (S, hidden_dim) + + # Restore original order + weighted_out = 
weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum( + dim=1 + ) + + return final_hidden_states.to(hidden_states.dtype) + + +def moe_batched( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + gate_up_proj_scale_inv: torch.Tensor, + down_proj_scale_inv: torch.Tensor, + block_size: list[int], +) -> torch.Tensor: + """End-to-end MoE expert layer using unfused batched GEMM primitives. + + Input: unsorted hidden_states + router outputs (top_k_index, top_k_weights) + Output: (num_tokens, hidden) — accumulated across top_k experts + + Pipeline: expand → batched_mm(gate_up) → SiLU → batched_mm(down) → routing + reduce + Deterministic: no sorting, no atomic_add. Each token processed independently. 
+ """ + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + ) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Get current hidden states for selected samples (no sorting needed) + selected_hidden_states = hidden_states[token_idx] + + # --- Gate+up projection per expert (FP8 batched) --- + proj_out = w8a8_block_fp8_matmul_batched( + selected_hidden_states, + gate_up_proj, + gate_up_proj_scale_inv, + expert_ids, + block_size, + ) # (S, 2 * intermediate_dim) + + # Apply SiLU gating + gate, up = proj_out.chunk(2, dim=-1) + proj_out = F.silu(gate) * up # (S, intermediate_dim) + + # --- Down projection per expert (FP8 batched) --- + proj_out = w8a8_block_fp8_matmul_batched( + proj_out, down_proj, down_proj_scale_inv, expert_ids, block_size + ) # (S, hidden_dim) + + # Apply routing weights + weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze( + -1 + ) # (S, hidden_dim) + + # Accumulate results using deterministic reshape+sum instead of index_add_ + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum( + dim=1 + ) + + return final_hidden_states.to(hidden_states.dtype)