From cfe29c47913ede8d1b9213063728489d35c882c6 Mon Sep 17 00:00:00 2001 From: xueweil Date: Tue, 2 Dec 2025 18:48:03 -0800 Subject: [PATCH 1/6] minor --- csrc/apis/gemm.hpp | 42 ++++++------- csrc/jit_kernels/heuristics/sm100.hpp | 3 + csrc/jit_kernels/impls/runtime_utils.hpp | 21 +++++++ .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 18 ++++-- .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 61 +++++++++++++++++-- tests/generators.py | 32 +++++----- tests/test_fp8.py | 4 +- 7 files changed, 134 insertions(+), 47 deletions(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index f50960f2..f9770c97 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -22,27 +22,27 @@ static bool early_return(const int& m, const int &n, const int& k, if (m == 0 or n == 0) return true; - // Checks - const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); - if (is_cd_same) - DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); - if (c.has_value()) { - check_major_type_cd(c.value()); - DG_HOST_ASSERT((d.scalar_type() == torch::kFloat) or (d.scalar_type() == torch::kBFloat16)); - DG_HOST_ASSERT((c.value().scalar_type() == torch::kFloat) or (c.value().scalar_type() == torch::kBFloat16)); - DG_HOST_ASSERT(d.scalar_type() == c.value().scalar_type()); - } - - // No accumulation - if (k == 0) { - if (not is_cd_same) - c.has_value() ? d.copy_(c.value()) : d.zero_(); - return true; - } - - // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) - if (c.has_value() and not is_cd_same) - d.copy_(c.value()); + // // Checks + // const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); + // if (is_cd_same) + // DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + // if (c.has_value()) { + // check_major_type_cd(c.value()); + // DG_HOST_ASSERT((d.scalar_type() == torch::kFloat) or (d.scalar_type() == torch::kBFloat16)); + // DG_HOST_ASSERT((c.value().scalar_type() == torch::kFloat) or (c.value().scalar_type() == torch::kBFloat16)); + // DG_HOST_ASSERT(d.scalar_type() == c.value().scalar_type()); + // } + + // // No accumulation + // if (k == 0) { + // if (not is_cd_same) + // c.has_value() ? d.copy_(c.value()) : d.zero_(); + // return true; + // } + + // // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) + // if (c.has_value() and not is_cd_same) + // d.copy_(c.value()); return false; } diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index 0ac4cc28..0016166c 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -156,6 +156,9 @@ struct SM100ArchSpec { // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages // NOTES: the last barrier is for tensor core utilization control + // NOTES: add 2 * 8 for bias_barriers when with_accumulation is true (2 TMA store stages) + // const int bias_barrier_size = with_accumulation ? 2 * 8 : 0; + // return num_stages * 8 * 3 + 2 * 8 * 2 + 8 + bias_barrier_size; return num_stages * 8 * 3 + 2 * 8 * 2 + 8; } diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 8f8504d5..ae8ab6b0 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -222,4 +222,25 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, allow_tf32); } +static CUtensorMap make_tma_bias_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + int shape_mn, int shape_k, + const int& block_mn, const int& block_k, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + DG_HOST_ASSERT(major == cute::UMMA::Major::MN); + + // TODO: maybe swizzle SF as well + DG_HOST_ASSERT(swizzle_mode == 0); + + shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); + return make_tma_2d_desc(t, + shape_mn, num_groups, + block_mn, 1, + shape_mn, + swizzle_mode, swizzle_base, + allow_tf32); +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 896c2485..1dac819c 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -31,6 +31,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntimeget_num_sms()); - const auto& cd = c.value_or(d); + // const auto& cd = d; const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, @@ -114,6 +115,14 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, config.block_n, config.block_k, 1, 0); + // Create tensor map for bias only if c has a value + CUtensorMap tensor_map_bias{}; + if (c.has_value()) { + tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, c.value(), n, 1, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), 1, + 1, 0); + } + // Launch const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = aligned_k, @@ -122,14 +131,15 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa .epilogue_type = epilogue_type, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, + 232448-128, config.multicast_config.num_multicast), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .tensor_map_bias = tensor_map_bias }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 886f5ffb..39132a5c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -31,7 +31,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_b, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, - const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { + const __grid_constant__ cute::TmaDescriptor tensor_map_cd, + const __grid_constant__ cute::TmaDescriptor tensor_map_bias) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; @@ -84,6 +85,12 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + + uint32_t SMEM_BIAS_SIZE_PER_STAGE = 0; + if constexpr (kWithAccumulation) { + SMEM_BIAS_SIZE_PER_STAGE = LOAD_BLOCK_N * sizeof(cd_dtype_t); + } + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); @@ -111,6 +118,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::prefetch_tma_descriptor(&tensor_map_sfa); cute::prefetch_tma_descriptor(&tensor_map_sfb); cute::prefetch_tma_descriptor(&tensor_map_cd); + if constexpr (kWithAccumulation){ + cute::prefetch_tma_descriptor(&tensor_map_bias); + } } // D/A/B shared memory @@ -133,16 +143,24 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); }); + auto bias_start_ptr = sf_start_ptr + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + auto smem_bias = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(bias_start_ptr + i * SMEM_BIAS_SIZE_PER_STAGE); + }); + // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + - kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE) + + (kWithAccumulation ? kNumStages * SMEM_BIAS_SIZE_PER_STAGE : 0)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + // auto bias_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages * 2 + i); }); // Fill the tensor memory pointer auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); @@ -165,6 +183,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Arrive only at the leader CTA tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); } + + // // Initialize bias barriers + // if constexpr (kWithAccumulation) { + // #pragma unroll + // for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) { + // bias_barriers[i]->init(1); + // } + // } // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); @@ -249,6 +275,11 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); } + if (k_block_idx == 0 and kWithAccumulation){ + tma_copy(&tensor_map_bias, full_barriers[stage_idx], smem_bias[stage_idx], n_idx, 0, 1, 0); + num_arrival_bytes += LOAD_BLOCK_N * sizeof(cd_dtype_t); + } + // Arrive at full barriers full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); } @@ -457,6 +488,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + // Store into shared memory #pragma unroll for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { @@ -493,9 +525,30 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // For BF16 output, read, cast and store DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); + values[0], values[1], values[2], values[3], values[4], values[5], values[6], values[7]); cutlass::arch::fence_view_async_tmem_load(); + + // Add bias (BF16 case - all lanes in same column read same bias value) + // Use 'i' (logical index) not 'col' (swizzled index) because bias has no swizzle + if constexpr (kWithAccumulation) { + uint32_t n_offset_in_block = i * kNumElemsPerBankGroup; + uint32_t store_block_offset = tma_stage_idx * STORE_BLOCK_N * sizeof(cd_dtype_t) ; + cd_dtype_t* bias_ptr = smem_bias[0] + store_block_offset + n_offset_in_block; + float bias_vals[8]; + #pragma unroll + for (int b = 0; b < 8; ++b) { + bias_vals[b] = static_cast(bias_ptr[b]); + } + values[0] = __float_as_uint(__uint_as_float(values[0]) + bias_vals[0]); + values[1] = __float_as_uint(__uint_as_float(values[1]) + bias_vals[1]); + values[2] = __float_as_uint(__uint_as_float(values[2]) + bias_vals[2]); + values[3] = __float_as_uint(__uint_as_float(values[3]) + bias_vals[3]); + values[4] = __float_as_uint(__uint_as_float(values[4]) + bias_vals[4]); + values[5] = __float_as_uint(__uint_as_float(values[5]) + bias_vals[5]); + values[6] = __float_as_uint(__uint_as_float(values[6]) + bias_vals[6]); + values[7] = __float_as_uint(__uint_as_float(values[7]) + bias_vals[7]); + } + st_shared(smem_ptr, cast_into_bf16_and_pack(values[0], values[1]), cast_into_bf16_and_pack(values[2], values[3]), diff --git a/tests/generators.py b/tests/generators.py index c4fdc69d..1ab877b4 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -64,8 +64,8 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: assert dtype in (torch.float8_e4m3fn, torch.bfloat16) fp32_output_nk = [(256, 7168), (129280, 7168)] - bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] - m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ] + bf16_output_nk = [(5120, 5120), (5120, 13824), (13824, 5120)] + m_fwd_list, m_bwd_list = [9614], [4096, ] nk_list = list(bf16_output_nk) # Only BF16 GEMM needs FP32 outputs @@ -82,22 +82,22 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: if dtype == torch.float8_e4m3fn and get_arch_major() == 10: yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, True, out_dtype - # Backward - for m in m_bwd_list: - for n, k in nk_list: - override_major = MajorTypeAB.MNMajor - override_kernel_type = kernel_type - if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: - override_major = MajorTypeAB.KMajor - override_kernel_type = KernelType.Kernel1D1D - yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad - yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad - yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad + # # Backward + # for m in m_bwd_list: + # for n, k in nk_list: + # override_major = MajorTypeAB.MNMajor + # override_kernel_type = kernel_type + # if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + # override_major = MajorTypeAB.KMajor + # override_kernel_type = KernelType.Kernel1D1D + # yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad + # yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad + # yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: for kernel_type in get_kernel_types(dtype): - for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): + for num_groups, expected_m_per_group, n, k in ((1, 9614, 5120, 5120), (1, 9614, 13824, 5120), (1, 9614, 5120, 13824)): for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b @@ -156,7 +156,7 @@ def generate_normal(m: int, n: int, k: int, b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ torch.empty((m, n), device='cuda', dtype=out_dtype) - c = d if accumulate else None + c = torch.randn((n), device='cuda', dtype=out_dtype) if accumulate else None ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) if use_bf16: @@ -175,7 +175,7 @@ def generate_normal(m: int, n: int, k: int, def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool = False, use_bf16: bool = False): - actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + actual_ms = [int(expected_m_per_group ) for _ in range(num_groups)] aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] m = sum(aligned_ms) diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 50d25c7c..95343953 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -171,5 +171,5 @@ def test_func(): test_gemm() test_m_grouped_gemm_contiguous() - test_m_grouped_gemm_masked() - test_k_grouped_gemm_contiguous() + # test_m_grouped_gemm_masked() + # test_k_grouped_gemm_contiguous() From d96986f627c0ddb0a2d315eb7816735308413555 Mon Sep 17 00:00:00 2001 From: xueweil Date: Sun, 14 Dec 2025 23:37:41 -0800 Subject: [PATCH 2/6] accumulate and bias true works --- .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 6 ++- .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 41 +++++++++---------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 1dac819c..75f06f8c 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -94,6 +94,10 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); + // Debug: print block sizes + printf("DEBUG: block_m=%d, block_n=%d, block_k=%d, num_multicast=%d\n", + config.block_m, config.block_n, config.block_k, config.multicast_config.num_multicast); + // const auto& cd = d; const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), @@ -119,7 +123,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa CUtensorMap tensor_map_bias{}; if (c.has_value()) { tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, c.value(), n, 1, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), 1, + config.block_n, 1, 1, 0); } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 39132a5c..4ac88bca 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -22,7 +22,7 @@ template __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, @@ -37,6 +37,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; + constexpr bool kWithBias = true; // GEMM with accumulation must have FP32/BF16 output if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v or cute::is_same_v, "Invalid C/D data dtype"); @@ -48,6 +49,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t kNumTMAStoreStages = 2; constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t); constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t kNumBiasAlignBytes = 128; DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); @@ -87,8 +89,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); uint32_t SMEM_BIAS_SIZE_PER_STAGE = 0; - if constexpr (kWithAccumulation) { - SMEM_BIAS_SIZE_PER_STAGE = LOAD_BLOCK_N * sizeof(cd_dtype_t); + if constexpr (kWithBias) { + constexpr uint32_t BiasSizePerBlock = BLOCK_N * sizeof(cd_dtype_t); + SMEM_BIAS_SIZE_PER_STAGE = constexpr_align(BiasSizePerBlock, kNumBiasAlignBytes); } DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, @@ -118,7 +121,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::prefetch_tma_descriptor(&tensor_map_sfa); cute::prefetch_tma_descriptor(&tensor_map_sfb); cute::prefetch_tma_descriptor(&tensor_map_cd); - if constexpr (kWithAccumulation){ + if constexpr (kWithBias){ cute::prefetch_tma_descriptor(&tensor_map_bias); } } @@ -153,14 +156,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE) + - (kWithAccumulation ? kNumStages * SMEM_BIAS_SIZE_PER_STAGE : 0)); + (kWithBias ? kNumEpilogueStages * SMEM_BIAS_SIZE_PER_STAGE : 0)); auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); - // auto bias_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages * 2 + i); }); // Fill the tensor memory pointer auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); @@ -183,14 +185,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Arrive only at the leader CTA tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); } - - // // Initialize bias barriers - // if constexpr (kWithAccumulation) { - // #pragma unroll - // for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) { - // bias_barriers[i]->init(1); - // } - // } // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); @@ -220,6 +214,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); @@ -275,9 +271,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); } - if (k_block_idx == 0 and kWithAccumulation){ - tma_copy(&tensor_map_bias, full_barriers[stage_idx], smem_bias[stage_idx], n_idx, 0, 1, 0); - num_arrival_bytes += LOAD_BLOCK_N * sizeof(cd_dtype_t); + if (k_block_idx == 0 and kWithBias){ + tma_copy(&tensor_map_bias, full_barriers[stage_idx], smem_bias[accum_stage_idx], n_block_idx * BLOCK_N, 0, 1, 0); + num_arrival_bytes += BLOCK_N * sizeof(cd_dtype_t); } // Arrive at full barriers @@ -530,15 +526,16 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Add bias (BF16 case - all lanes in same column read same bias value) // Use 'i' (logical index) not 'col' (swizzled index) because bias has no swizzle - if constexpr (kWithAccumulation) { + if constexpr (kWithBias) { uint32_t n_offset_in_block = i * kNumElemsPerBankGroup; - uint32_t store_block_offset = tma_stage_idx * STORE_BLOCK_N * sizeof(cd_dtype_t) ; - cd_dtype_t* bias_ptr = smem_bias[0] + store_block_offset + n_offset_in_block; + uint32_t store_block_offset = s * STORE_BLOCK_N; + cd_dtype_t* bias_ptr = smem_bias[accum_stage_idx] + store_block_offset + n_offset_in_block; float bias_vals[8]; #pragma unroll for (int b = 0; b < 8; ++b) { bias_vals[b] = static_cast(bias_ptr[b]); } + //TODO:fadd2 values[0] = __float_as_uint(__uint_as_float(values[0]) + bias_vals[0]); values[1] = __float_as_uint(__uint_as_float(values[1]) + bias_vals[1]); values[2] = __float_as_uint(__uint_as_float(values[2]) + bias_vals[2]); @@ -569,12 +566,12 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx, scheduler.current_group_idx); } else { - using cute_tma_t = cute::conditional_t; cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); } From 6e38032647e4fc8dc6c0440ae4bea46f6504ce19 Mon Sep 17 00:00:00 2001 From: xueweil Date: Mon, 15 Dec 2025 03:10:55 -0800 Subject: [PATCH 3/6] minor --- csrc/apis/attention.hpp | 2 +- csrc/apis/gemm.hpp | 78 ++++++++++++------- .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 11 ++- .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 36 +++++++-- tests/generators.py | 35 ++++++--- tests/test_fp8.py | 19 +++-- 6 files changed, 122 insertions(+), 59 deletions(-) diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index bf146147..f316172b 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -66,7 +66,7 @@ static void fp8_gemm_nt_skip_head_mid(const std::pairdata_ptr() == d.data_ptr(); - // if (is_cd_same) - // DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); - // if (c.has_value()) { - // check_major_type_cd(c.value()); - // DG_HOST_ASSERT((d.scalar_type() == torch::kFloat) or (d.scalar_type() == torch::kBFloat16)); - // DG_HOST_ASSERT((c.value().scalar_type() == torch::kFloat) or (c.value().scalar_type() == torch::kBFloat16)); - // DG_HOST_ASSERT(d.scalar_type() == c.value().scalar_type()); - // } - - // // No accumulation - // if (k == 0) { - // if (not is_cd_same) - // c.has_value() ? d.copy_(c.value()) : d.zero_(); - // return true; - // } - - // // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) - // if (c.has_value() and not is_cd_same) - // d.copy_(c.value()); + // Checks + const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); + if (is_cd_same) + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT((d.scalar_type() == torch::kFloat) or (d.scalar_type() == torch::kBFloat16)); + DG_HOST_ASSERT((c.value().scalar_type() == torch::kFloat) or (c.value().scalar_type() == torch::kBFloat16)); + DG_HOST_ASSERT(d.scalar_type() == c.value().scalar_type()); + } + + // No accumulation + if (k == 0) { + if (not is_cd_same) + c.has_value() ? d.copy_(c.value()) : d.zero_(); + return true; + } + + // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) + if (c.has_value() and not is_cd_same) + d.copy_(c.value()); return false; } @@ -51,6 +51,7 @@ static void fp8_gemm_nt(const std::pair& a, const std::pair& b, const torch::Tensor& d, const std::optional& c, + const std::optional& bias, std::optional> recipe, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { @@ -74,6 +75,14 @@ static void fp8_gemm_nt(const std::pair& a, DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + if (bias.has_value()) { + const auto& arch_major = device_runtime->get_arch_major(); + DG_HOST_ASSERT(arch_major == 10); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(bias.value().scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == bias.value().scalar_type()); + } + // Early return for trivial cases if (early_return(m, n, k, d, c)) return; @@ -95,7 +104,7 @@ static void fp8_gemm_nt(const std::pair& a, sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, bias, d, m, n, k, major_a, major_b, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -105,34 +114,37 @@ static void fp8_gemm_nn(const std::pair& a, const std::pair& b, const torch::Tensor& d, const std::optional& c, + const std::optional& bias, const std::optional>& recipe, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); + d, c, bias, recipe, compiled_dims, disable_ue8m0_cast); } static void fp8_gemm_tn(const std::pair& a, const std::pair& b, const torch::Tensor& d, const std::optional& c, + const std::optional& bias, const std::optional>& recipe, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); + d, c, bias, recipe, compiled_dims, disable_ue8m0_cast); } static void fp8_gemm_tt(const std::pair& a, const std::pair& b, const torch::Tensor& d, const std::optional& c, + const std::optional& bias, const std::optional>& recipe, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, - d, c, recipe, compiled_dims, disable_ue8m0_cast); + d, c, bias, recipe, compiled_dims, disable_ue8m0_cast); } static void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, @@ -567,22 +579,30 @@ static void register_apis(pybind11::module_& m) { // FP8 GEMMs m.def("fp8_gemm_nt", &fp8_gemm_nt, py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("c") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); m.def("fp8_gemm_nn", &fp8_gemm_nn, py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("c") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); m.def("fp8_gemm_tn", &fp8_gemm_tn, py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("c") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "mn", py::arg("disable_ue8m0_cast") = false); m.def("fp8_gemm_tt", &fp8_gemm_tt, py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("c") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "mn", py::arg("disable_ue8m0_cast") = false); m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous, diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 75f06f8c..ff9f8963 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -19,6 +19,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime& epilogue_type; @@ -51,7 +52,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, - {}, {}, {}, + {}, {}, {}, {}, {} >); }}; @@ -65,7 +66,7 @@ static void __instantiate_kernel() {{ args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, args.with_bias, to_string(args.gemm_config.cd_dtype), get_default_epilogue_type(args.epilogue_type)); } @@ -82,6 +83,7 @@ static void __instantiate_kernel() {{ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const std::optional& c, + const std::optional& bias, const torch::Tensor& d, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, @@ -121,8 +123,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa // Create tensor map for bias only if c has a value CUtensorMap tensor_map_bias{}; - if (c.has_value()) { - tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, c.value(), n, 1, + if (bias.has_value()) { + tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, bias.value(), n, 1, config.block_n, 1, 1, 0); } @@ -131,6 +133,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = aligned_k, .num_groups = 1, + .with_bias = bias.has_value(), .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, .gemm_config = config, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 4ac88bca..1a03ae47 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -22,7 +22,7 @@ template __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, @@ -37,7 +37,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; - constexpr bool kWithBias = true; // GEMM with accumulation must have FP32/BF16 output if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v or cute::is_same_v, "Invalid C/D data dtype"); @@ -271,9 +270,11 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); } - if (k_block_idx == 0 and kWithBias){ - tma_copy(&tensor_map_bias, full_barriers[stage_idx], smem_bias[accum_stage_idx], n_block_idx * BLOCK_N, 0, 1, 0); - num_arrival_bytes += BLOCK_N * sizeof(cd_dtype_t); + if constexpr (kWithBias){ + if (k_block_idx == 0){ + tma_copy(&tensor_map_bias, full_barriers[stage_idx], smem_bias[accum_stage_idx], n_block_idx * BLOCK_N, 0, 1, 0); + num_arrival_bytes += BLOCK_N * sizeof(cd_dtype_t); + } } // Arrive at full barriers @@ -535,7 +536,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, for (int b = 0; b < 8; ++b) { bias_vals[b] = static_cast(bias_ptr[b]); } - //TODO:fadd2 values[0] = __float_as_uint(__uint_as_float(values[0]) + bias_vals[0]); values[1] = __float_as_uint(__uint_as_float(values[1]) + bias_vals[1]); values[2] = __float_as_uint(__uint_as_float(values[2]) + bias_vals[2]); @@ -544,6 +544,26 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, values[5] = __float_as_uint(__uint_as_float(values[5]) + bias_vals[5]); values[6] = __float_as_uint(__uint_as_float(values[6]) + bias_vals[6]); values[7] = __float_as_uint(__uint_as_float(values[7]) + bias_vals[7]); + + // float2 bias_vals[4]; + // #pragma unroll + // for (int b = 0; b < 4; ++b) { + // bias_vals[b] = make_float2(static_cast(bias_ptr[b * 2]), static_cast(bias_ptr[b * 2 + 1])); + // } + // //TODO:fadd2 + // bias_vals[0] = __fadd2_rd(make_float2(__uint_as_float(values[0]), __uint_as_float(values[1])), bias_vals[0]); + // bias_vals[1] = __fadd2_rd(make_float2(__uint_as_float(values[2]), __uint_as_float(values[3])), bias_vals[1]); + // bias_vals[2] = __fadd2_rd(make_float2(__uint_as_float(values[4]), __uint_as_float(values[5])), bias_vals[2]); + // bias_vals[3] = __fadd2_rd(make_float2(__uint_as_float(values[6]), __uint_as_float(values[7])), bias_vals[3]); + + // values[0] = __float_as_uint(bias_vals[0].x); + // values[1] = __float_as_uint(bias_vals[0].y); + // values[2] = __float_as_uint(bias_vals[1].x); + // values[3] = __float_as_uint(bias_vals[1].y); + // values[4] = __float_as_uint(bias_vals[2].x); + // values[5] = __float_as_uint(bias_vals[2].y); + // values[6] = __float_as_uint(bias_vals[3].x); + // values[7] = __float_as_uint(bias_vals[3].y); } st_shared(smem_ptr, @@ -566,12 +586,12 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx, scheduler.current_group_idx); } else { - using cute_tma_t = cute::conditional_t; cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); } diff --git a/tests/generators.py b/tests/generators.py index 1ab877b4..2f7b884f 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -64,8 +64,8 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: assert dtype in (torch.float8_e4m3fn, torch.bfloat16) fp32_output_nk = [(256, 7168), (129280, 7168)] - bf16_output_nk = [(5120, 5120), (5120, 13824), (13824, 5120)] - m_fwd_list, m_bwd_list = [9614], [4096, ] + bf16_output_nk = [(512, 512), (5120, 13824), (13824, 5120)] + m_fwd_list, m_bwd_list = [5120], [4096, ] nk_list = list(bf16_output_nk) # Only BF16 GEMM needs FP32 outputs @@ -78,9 +78,16 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: for i in range(len(nk_list)): n, k = nk_list[i] out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float - yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype + yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, False, out_dtype if dtype == torch.float8_e4m3fn and get_arch_major() == 10: - yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, True, out_dtype + # with accumulation, output = A[m,n] @ B[n,k] + C[m,n] + # with bias, output = A[m,n] @ B[n,k] + bias[n] + # With accumulation, no bias. + yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, True, False, out_dtype + if out_dtype == torch.bfloat16: + # With bias, no accumulation. + yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, True, out_dtype + # # Backward # for m in m_bwd_list: @@ -97,7 +104,7 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: for kernel_type in get_kernel_types(dtype): - for num_groups, expected_m_per_group, n, k in ((1, 9614, 5120, 5120), (1, 9614, 13824, 5120), (1, 9614, 5120, 13824)): + for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b @@ -149,33 +156,39 @@ def enumerate_transpose(): def generate_normal(m: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, - accumulate: bool, out_dtype: torch.dtype, + accumulate: bool, with_bias: bool, out_dtype: torch.dtype, kernel_type: KernelType, use_ue8m0: bool = False, use_bf16: bool = False): a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ torch.empty((m, n), device='cuda', dtype=out_dtype) - c = torch.randn((n), device='cuda', dtype=out_dtype) if accumulate else None - ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) + c = d if accumulate else None + bias = torch.randn((n), device='cuda', dtype=out_dtype) * 16 if with_bias else None + if accumulate: + ref_d = (a.float() @ b.float().t() + c).to(out_dtype) + elif with_bias: + ref_d = (a.float() @ b.float().t() + bias).to(out_dtype) + else: + ref_d = (a.float() @ b.float().t()).to(out_dtype) if use_bf16: a = a if major_a.is_k_major() else a.T.contiguous().T b = b if major_b.is_k_major() else b.T.contiguous().T - return a, b, c, d, ref_d + return a, b, c, bias, d, ref_d a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \ else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) - return a_fp8, b_fp8, c, d, ref_d + return a_fp8, b_fp8, c, bias, d, ref_d def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool = False, use_bf16: bool = False): - actual_ms = [int(expected_m_per_group ) for _ in range(num_groups)] + actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] m = sum(aligned_ms) diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 95343953..535bdac2 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -21,7 +21,10 @@ def test_gemm() -> None: print('Testing GEMM:') scores = [] - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): + for kernel_type, m, n, k, major_a, major_b, accumulate, with_bias, out_dtype in enumerate_normal(torch.float8_e4m3fn): + if not((not accumulate ) and with_bias): + continue + major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' @@ -32,19 +35,23 @@ def test_gemm() -> None: recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None for test_alias in (False, True): - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + a, b, c, bias, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, with_bias, out_dtype, kernel_type, use_ue8m0=use_ue8m0) func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}' if test_alias: a = a if major_a.is_k_major() else (a[0].T, a[1].T) b = b if major_b.is_k_major() else (b[0].T, b[1].T) assert a[0].is_contiguous() and b[0].is_contiguous() - getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe) + getattr(deep_gemm, func_name)(a, b, d, c=c, bias=bias, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe) + # for i in range(m): + # print("line: ", i, "max diff: ", torch.max(torch.abs(c - d[i, :].reshape(n)))) + print(d) + print(ref_d) diff = calc_diff(d, ref_d) assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' f'{diff:.5f}, alias={test_alias}') - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) - t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe), + a, b, c, bias, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, with_bias, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, bias=bias, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe), 'fp8_gemm', suppress_kineto_output=True) cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' @@ -170,6 +177,6 @@ def test_func(): print(f' > {deep_gemm.__path__}\n') test_gemm() - test_m_grouped_gemm_contiguous() + # test_m_grouped_gemm_contiguous() # test_m_grouped_gemm_masked() # test_k_grouped_gemm_contiguous() From 7dd9fbc719861dc5d8068e5813d2ca2a3b714d8f Mon Sep 17 00:00:00 2001 From: xueweil Date: Wed, 17 Dec 2025 01:14:46 -0800 Subject: [PATCH 4/6] update smem --- csrc/jit_kernels/heuristics/common.hpp | 13 +++++-- csrc/jit_kernels/heuristics/sm100.hpp | 3 -- csrc/jit_kernels/impls/runtime_utils.hpp | 6 +-- .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 8 ++-- .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 2 - tests/generators.py | 39 ++++++++++++------- tests/test_attention.py | 2 +- tests/test_bf16.py | 10 ++--- tests/test_fp8.py | 15 +++---- 9 files changed, 50 insertions(+), 48 deletions(-) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index be66454a..fef85abe 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -100,7 +100,8 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& block_m, const int& block_n, const int& block_k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const int& num_stages, const MulticastConfig& multicast_config) { + const int& num_stages, const MulticastConfig& multicast_config, + const bool& with_bias = false) { const int& ab_elem_size = static_cast(c10::elementSize(ab_dtype)); const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); @@ -113,6 +114,9 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne // Different archs have different epilogue pipelines const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); + // Bias shared memory (aligned to 128 bytes for TMA), 2 = NumTMAStoreStages + const int& smem_bias = with_bias ? align(cd_elem_size * block_n, 128) * 2 : 0; + // A/B shared memory const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size; @@ -131,6 +135,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne int smem_size = 0; smem_size += smem_tensor_map; smem_size += smem_cd; + smem_size += smem_bias; smem_size += num_stages * smem_a_per_stage; smem_size += num_stages * smem_b_per_stage; smem_size += num_stages * smem_sfa_per_stage; @@ -152,7 +157,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const int& m, const int& n, const int& k, const int& num_groups, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const bool& with_accumulation, const int& num_sms) { + const bool& with_accumulation, const int& num_sms, + const bool& with_bias = false) { DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); @@ -244,7 +250,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k best_block_m, best_block_n, block_k, major_a, major_b, ab_dtype, cd_dtype, - num_stages, best_multicast_config); + num_stages, best_multicast_config, + with_bias); if (best_smem_config.smem_size <= smem_capacity) { best_num_stages = num_stages; break; diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index 0016166c..0ac4cc28 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -156,9 +156,6 @@ struct SM100ArchSpec { // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages // NOTES: the last barrier is for tensor core utilization control - // NOTES: add 2 * 8 for bias_barriers when with_accumulation is true (2 TMA store stages) - // const int bias_barrier_size = with_accumulation ? 2 * 8 : 0; - // return num_stages * 8 * 3 + 2 * 8 * 2 + 8 + bias_barrier_size; return num_stages * 8 * 3 + 2 * 8 * 2 + 8; } diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index ae8ab6b0..5de663a7 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -224,16 +224,14 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, static CUtensorMap make_tma_bias_desc(const cute::UMMA::Major& major, const torch::Tensor& t, - int shape_mn, int shape_k, - const int& block_mn, const int& block_k, + int shape_mn, + const int& block_mn, const int& num_groups, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false) { DG_HOST_ASSERT(major == cute::UMMA::Major::MN); - // TODO: maybe swizzle SF as well DG_HOST_ASSERT(swizzle_mode == 0); - shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); return make_tma_2d_desc(t, shape_mn, num_groups, diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index ff9f8963..d14f4f37 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -94,13 +94,13 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa GemmType::Normal, KernelType::Kernel1D1D, m, n, k, 1, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + device_runtime->get_num_sms(), bias.has_value()); // Debug: print block sizes printf("DEBUG: block_m=%d, block_n=%d, block_k=%d, num_multicast=%d\n", config.block_m, config.block_n, config.block_k, config.multicast_config.num_multicast); - // const auto& cd = d; + const auto& cd = c.value_or(d); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, @@ -124,9 +124,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa // Create tensor map for bias only if c has a value CUtensorMap tensor_map_bias{}; if (bias.has_value()) { - tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, bias.value(), n, 1, - config.block_n, 1, - 1, 0); + tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, bias.value(), n, config.block_n, 1, 0); } // Launch diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 1a03ae47..5437fb15 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -525,8 +525,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, values[0], values[1], values[2], values[3], values[4], values[5], values[6], values[7]); cutlass::arch::fence_view_async_tmem_load(); - // Add bias (BF16 case - all lanes in same column read same bias value) - // Use 'i' (logical index) not 'col' (swizzled index) because bias has no swizzle if constexpr (kWithBias) { uint32_t n_offset_in_block = i * kNumElemsPerBankGroup; uint32_t store_block_offset = s * STORE_BLOCK_N; diff --git a/tests/generators.py b/tests/generators.py index 2f7b884f..89f3b993 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -64,8 +64,8 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: assert dtype in (torch.float8_e4m3fn, torch.bfloat16) fp32_output_nk = [(256, 7168), (129280, 7168)] - bf16_output_nk = [(512, 512), (5120, 13824), (13824, 5120)] - m_fwd_list, m_bwd_list = [5120], [4096, ] + bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] + m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ] nk_list = list(bf16_output_nk) # Only BF16 GEMM needs FP32 outputs @@ -89,17 +89,17 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, True, out_dtype - # # Backward - # for m in m_bwd_list: - # for n, k in nk_list: - # override_major = MajorTypeAB.MNMajor - # override_kernel_type = kernel_type - # if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: - # override_major = MajorTypeAB.KMajor - # override_kernel_type = KernelType.Kernel1D1D - # yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad - # yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad - # yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad + # Backward + for m in m_bwd_list: + for n, k in nk_list: + override_major = MajorTypeAB.MNMajor + override_kernel_type = kernel_type + if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + override_major = MajorTypeAB.KMajor + override_kernel_type = KernelType.Kernel1D1D + yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, False, torch.bfloat16 # Dgrad + yield override_kernel_type, n, m, k, override_major, override_major, True, False, torch.float # Wgrad + yield override_kernel_type, n, m, k, override_major, override_major, False, False, torch.bfloat16 # Wgrad def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: @@ -163,8 +163,17 @@ def generate_normal(m: int, n: int, k: int, b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ torch.empty((m, n), device='cuda', dtype=out_dtype) - c = d if accumulate else None - bias = torch.randn((n), device='cuda', dtype=out_dtype) * 16 if with_bias else None + + if accumulate: + if out_dtype == torch.bfloat16: + c = torch.ones_like(d) * 10 + else: + c = d + else: + c = None + + bias = torch.ones((n), device='cuda', dtype=out_dtype) * 10 if with_bias else None + if accumulate: ref_d = (a.float() @ b.float().t() + c).to(out_dtype) elif with_bias: diff --git a/tests/test_attention.py b/tests/test_attention.py index 1c8befc1..66dbb7d6 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -42,7 +42,7 @@ def test_gemm_skip_head_mid() -> None: use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 - a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + a, b, _, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, False, out_dtype, kernel_type, use_ue8m0=use_ue8m0) d = apply_skip_head_mid(d, head_splits) ref_d = apply_skip_head_mid(ref_d, head_splits) diff --git a/tests/test_bf16.py b/tests/test_bf16.py index f2f41c4a..2b7c617d 100644 --- a/tests/test_bf16.py +++ b/tests/test_bf16.py @@ -18,7 +18,7 @@ def test_gemm() -> None: print('Testing GEMM:') scores = [] - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): + for kernel_type, m, n, k, major_a, major_b, accumulate, _, out_dtype in enumerate_normal(torch.bfloat16): # TODO: support accumulation for SM90 BF16 GEMM if get_arch_major() == 9 and accumulate: continue @@ -29,7 +29,7 @@ def test_gemm() -> None: acc_opt = f'acc={int(accumulate)}' for test_alias in (False, True): - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + a, b, c, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, False, out_dtype, kernel_type, use_bf16=True) func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}' if test_alias: a = a if major_a.is_k_major() else a.T @@ -39,7 +39,7 @@ def test_gemm() -> None: diff = calc_diff(d, ref_d) assert diff < 1e-5, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' f'{diff:.5f}, alias={test_alias}') - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + a, b, c, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, False, out_dtype, kernel_type, use_bf16=True) t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) @@ -148,13 +148,13 @@ def test_func(): def test_cublaslt_gemm() -> None: print('Testing cuBLASLt GEMM:') - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): + for kernel_type, m, n, k, major_a, major_b, accumulate, _, out_dtype in enumerate_normal(dtype=torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' acc_opt = f'acc={int(accumulate)}' - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + a, b, c, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, False, out_dtype, kernel_type, use_bf16=True) deep_gemm.cublaslt_gemm_nt(a, b, d, c=c) diff = calc_diff(d, ref_d) assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})' diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 535bdac2..ef3b136c 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -22,13 +22,12 @@ def test_gemm() -> None: print('Testing GEMM:') scores = [] for kernel_type, m, n, k, major_a, major_b, accumulate, with_bias, out_dtype in enumerate_normal(torch.float8_e4m3fn): - if not((not accumulate ) and with_bias): - continue major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' acc_opt = f'acc={int(accumulate)}' + bias_opt = f'bias={int(with_bias)}' kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 @@ -42,10 +41,6 @@ def test_gemm() -> None: b = b if major_b.is_k_major() else (b[0].T, b[1].T) assert a[0].is_contiguous() and b[0].is_contiguous() getattr(deep_gemm, func_name)(a, b, d, c=c, bias=bias, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe) - # for i in range(m): - # print("line: ", i, "max diff: ", torch.max(torch.abs(c - d[i, :].reshape(n)))) - print(d) - print(ref_d) diff = calc_diff(d, ref_d) assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' f'{diff:.5f}, alias={test_alias}') @@ -54,7 +49,7 @@ def test_gemm() -> None: t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, bias=bias, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe), 'fp8_gemm', suppress_kineto_output=True) cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) - print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}, {bias_opt}): ' f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') @@ -177,6 +172,6 @@ def test_func(): print(f' > {deep_gemm.__path__}\n') test_gemm() - # test_m_grouped_gemm_contiguous() - # test_m_grouped_gemm_masked() - # test_k_grouped_gemm_contiguous() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() From e314311edbf7423cdce64e13d1a820e1b6124b6b Mon Sep 17 00:00:00 2001 From: xueweil Date: Wed, 17 Dec 2025 02:46:44 -0800 Subject: [PATCH 5/6] minor --- .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 4 ---- .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 20 ------------------- 2 files changed, 24 deletions(-) diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index d14f4f37..454f9e2f 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -96,10 +96,6 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), device_runtime->get_num_sms(), bias.has_value()); - // Debug: print block sizes - printf("DEBUG: block_m=%d, block_n=%d, block_k=%d, num_multicast=%d\n", - config.block_m, config.block_n, config.block_k, config.multicast_config.num_multicast); - const auto& cd = c.value_or(d); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 5437fb15..8b72cb6c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -542,26 +542,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, values[5] = __float_as_uint(__uint_as_float(values[5]) + bias_vals[5]); values[6] = __float_as_uint(__uint_as_float(values[6]) + bias_vals[6]); values[7] = __float_as_uint(__uint_as_float(values[7]) + bias_vals[7]); - - // float2 bias_vals[4]; - // #pragma unroll - // for (int b = 0; b < 4; ++b) { - // bias_vals[b] = make_float2(static_cast(bias_ptr[b * 2]), static_cast(bias_ptr[b * 2 + 1])); - // } - // //TODO:fadd2 - // bias_vals[0] = __fadd2_rd(make_float2(__uint_as_float(values[0]), __uint_as_float(values[1])), bias_vals[0]); - // bias_vals[1] = __fadd2_rd(make_float2(__uint_as_float(values[2]), __uint_as_float(values[3])), bias_vals[1]); - // bias_vals[2] = __fadd2_rd(make_float2(__uint_as_float(values[4]), __uint_as_float(values[5])), bias_vals[2]); - // bias_vals[3] = __fadd2_rd(make_float2(__uint_as_float(values[6]), __uint_as_float(values[7])), bias_vals[3]); - - // values[0] = __float_as_uint(bias_vals[0].x); - // values[1] = __float_as_uint(bias_vals[0].y); - // values[2] = __float_as_uint(bias_vals[1].x); - // values[3] = __float_as_uint(bias_vals[1].y); - // values[4] = __float_as_uint(bias_vals[2].x); - // values[5] = __float_as_uint(bias_vals[2].y); - // values[6] = __float_as_uint(bias_vals[3].x); - // values[7] = __float_as_uint(bias_vals[3].y); } st_shared(smem_ptr, From 3ae9b1a2fa4062cb73583250b7509bfae6d10e14 Mon Sep 17 00:00:00 2001 From: xueweil Date: Mon, 19 Jan 2026 18:28:47 -0800 Subject: [PATCH 6/6] bug fix --- csrc/jit_kernels/heuristics/common.hpp | 4 +- .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 4 +- .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 52 +++++++++++++++---- tests/generators.py | 11 +--- 4 files changed, 49 insertions(+), 22 deletions(-) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index fef85abe..12eaeaa0 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -114,8 +114,8 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne // Different archs have different epilogue pipelines const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); - // Bias shared memory (aligned to 128 bytes for TMA), 2 = NumTMAStoreStages - const int& smem_bias = with_bias ? align(cd_elem_size * block_n, 128) * 2 : 0; + // Bias shared memory (aligned to 128 bytes for TMA) and 32 bytes for bias barrier, 2 = NumTMAStoreStages + const int& smem_bias = with_bias ? align(cd_elem_size * block_n, 128) * 2 + 32: 0; // A/B shared memory const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 454f9e2f..dacb6df4 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -95,7 +95,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa m, n, k, 1, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), device_runtime->get_num_sms(), bias.has_value()); - + const auto& cd = c.value_or(d); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), @@ -132,7 +132,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa .epilogue_type = epilogue_type, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - 232448-128, + config.smem_config.smem_size, config.multicast_config.num_multicast), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 8b72cb6c..ae465bc4 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -162,9 +162,15 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + auto bias_full_barriers = PatternVisitor([=](const uint32_t& i) { if constexpr (! kWithBias) { std::abort(); } return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages * 2 + i); }); + auto bias_empty_barriers = PatternVisitor([=](const uint32_t& i) { if constexpr (! kWithBias) { std::abort(); } return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages * 3 + i); }); // Fill the tensor memory pointer - auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + uint32_t barrier_offset = kNumStages * 3 + kNumEpilogueStages * 2; + if constexpr (kWithBias) { + barrier_offset += kNumEpilogueStages * 2; + } + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + barrier_offset); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); // Initialize barriers @@ -185,6 +191,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); } + if constexpr (kWithBias) { + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + bias_empty_barriers[i]->init(kNumUMMAStoreThreads); + bias_full_barriers[i]->init(1); + } + } + // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); } else if (warp_idx == 2) { @@ -270,18 +284,31 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); } - if constexpr (kWithBias){ - if (k_block_idx == 0){ - tma_copy(&tensor_map_bias, full_barriers[stage_idx], smem_bias[accum_stage_idx], n_block_idx * BLOCK_N, 0, 1, 0); - num_arrival_bytes += BLOCK_N * sizeof(cd_dtype_t); - } - } - // Arrive at full barriers full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); } } - } else if (warp_idx == 1 and is_leader_cta) { + } else if (warp_idx == 3 and cute::elect_one_sync()) { + if constexpr (kWithBias) { + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + uint32_t wait_bytes = BLOCK_N * sizeof(cd_dtype_t); + bias_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tma_copy(&tensor_map_bias, + bias_full_barriers[accum_stage_idx], smem_bias[accum_stage_idx], n_block_idx * BLOCK_N, 0, 1, 0); + bias_full_barriers[accum_stage_idx]->arrive_and_expect_tx(wait_bytes); + } + + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + bias_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } + else if (warp_idx == 1 and is_leader_cta) { // MMA issue warp // NOTES: only the leader CTA will do this // Make instruction descriptor @@ -465,6 +492,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); tcgen05_after_thread_sync(); + if constexpr (kWithBias) { + bias_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + } + // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); @@ -557,6 +588,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { tcgen05_before_thread_sync(); tmem_empty_barriers[accum_stage_idx]->arrive(0u); + if constexpr (kWithBias) { + bias_empty_barriers[accum_stage_idx]->arrive(); + } } // Synchronize all threads and issue TMA diff --git a/tests/generators.py b/tests/generators.py index 89f3b993..6c146fb6 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -163,16 +163,9 @@ def generate_normal(m: int, n: int, k: int, b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ torch.empty((m, n), device='cuda', dtype=out_dtype) + c = d if accumulate else None - if accumulate: - if out_dtype == torch.bfloat16: - c = torch.ones_like(d) * 10 - else: - c = d - else: - c = None - - bias = torch.ones((n), device='cuda', dtype=out_dtype) * 10 if with_bias else None + bias = torch.randn((n), device='cuda', dtype=out_dtype) * 10 if with_bias else None if accumulate: ref_d = (a.float() @ b.float().t() + c).to(out_dtype)