Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/apis/attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
const auto& major_sfb = get_major_type_ab(sfb);
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand Down
36 changes: 28 additions & 8 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<torch::Tensor>& bias,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
Expand All @@ -74,6 +75,14 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

if (bias.has_value()) {
const auto& arch_major = device_runtime->get_arch_major();
DG_HOST_ASSERT(arch_major == 10);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(bias.value().scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(d.scalar_type() == bias.value().scalar_type());
}

// Early return for trivial cases
if (early_return(m, n, k, d, c))
return;
Expand All @@ -95,7 +104,7 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, bias, d, m, n, k, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand All @@ -105,34 +114,37 @@ static void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<torch::Tensor>& bias,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, compiled_dims, disable_ue8m0_cast);
d, c, bias, recipe, compiled_dims, disable_ue8m0_cast);
}

static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<torch::Tensor>& bias,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, compiled_dims, disable_ue8m0_cast);
d, c, bias, recipe, compiled_dims, disable_ue8m0_cast);
}

static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<torch::Tensor>& bias,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
d, c, recipe, compiled_dims, disable_ue8m0_cast);
d, c, bias, recipe, compiled_dims, disable_ue8m0_cast);
}

static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand Down Expand Up @@ -567,22 +579,30 @@ static void register_apis(pybind11::module_& m) {
// FP8 GEMMs
m.def("fp8_gemm_nt", &fp8_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("c") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_nn", &fp8_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("c") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_tn", &fp8_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("c") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_tt", &fp8_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("c") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
Expand Down
13 changes: 10 additions & 3 deletions csrc/jit_kernels/heuristics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& num_stages, const MulticastConfig& multicast_config) {
const int& num_stages, const MulticastConfig& multicast_config,
const bool& with_bias = false) {
const int& ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));

Expand All @@ -113,6 +114,9 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
// Different archs have different epilogue pipelines
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);

// Bias shared memory (aligned to 128 bytes for TMA) and 32 bytes for bias barrier, 2 = NumTMAStoreStages
const int& smem_bias = with_bias ? align(cd_elem_size * block_n, 128) * 2 + 32: 0;

// A/B shared memory
const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size;
const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size;
Expand All @@ -131,6 +135,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
int smem_size = 0;
smem_size += smem_tensor_map;
smem_size += smem_cd;
smem_size += smem_bias;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
smem_size += num_stages * smem_sfa_per_stage;
Expand All @@ -152,7 +157,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
const bool& with_accumulation, const int& num_sms,
const bool& with_bias = false) {
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);

Expand Down Expand Up @@ -244,7 +250,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
best_block_m, best_block_n, block_k,
major_a, major_b,
ab_dtype, cd_dtype,
num_stages, best_multicast_config);
num_stages, best_multicast_config,
with_bias);
if (best_smem_config.smem_size <= smem_capacity) {
best_num_stages = num_stages;
break;
Expand Down
19 changes: 19 additions & 0 deletions csrc/jit_kernels/impls/runtime_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,23 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
allow_tf32);
}

// Build a TMA descriptor for a bias tensor, treated as a 2D (num_groups x shape_mn)
// MN-major layout with no swizzling; each TMA box covers (block_mn x 1).
static CUtensorMap make_tma_bias_desc(const cute::UMMA::Major& major,
                                      const torch::Tensor& t,
                                      int shape_mn,
                                      const int& block_mn,
                                      const int& num_groups,
                                      const int& swizzle_mode, const int& swizzle_base = 0,
                                      const bool& allow_tf32 = false) {
    // Bias is only supported in MN-major layout, and only without swizzling
    DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
    DG_HOST_ASSERT(swizzle_mode == 0);

    // Pad the MN extent so the global stride satisfies TMA alignment for this element size
    const int aligned_shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));
    return make_tma_2d_desc(t,
                            aligned_shape_mn, num_groups,
                            block_mn, 1,
                            aligned_shape_mn,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}

} // namespace deep_gemm
25 changes: 18 additions & 7 deletions csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntim
public:
struct Args {
int m, n, k, num_groups;
bool with_bias;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;

Expand All @@ -31,6 +32,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntim
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
CUtensorMap tensor_map_bias;
};

static std::string generate_impl(const Args& args) {
Expand All @@ -50,7 +52,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{}, {}, {},
{}, {}, {}, {},
{}
>);
}};
Expand All @@ -64,7 +66,7 @@ static void __instantiate_kernel() {{
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, args.with_bias, to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}

Expand All @@ -74,13 +76,14 @@ static void __instantiate_kernel() {{
args.grouped_layout, args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
args.tensor_map_cd, args.tensor_map_bias));
}
};

static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const std::optional<torch::Tensor>& bias,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
Expand All @@ -91,8 +94,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());

device_runtime->get_num_sms(), bias.has_value());
const auto& cd = c.value_or(d);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
Expand All @@ -114,22 +117,30 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, 1, 0);

    // Create the tensor map for bias only if a bias tensor was provided
CUtensorMap tensor_map_bias{};
if (bias.has_value()) {
tensor_map_bias = make_tma_bias_desc(cute::UMMA::Major::MN, bias.value(), n, config.block_n, 1, 0);
}

// Launch
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.with_bias = bias.has_value(),
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
.tensor_map_cd = tensor_map_cd,
.tensor_map_bias = tensor_map_bias
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
Expand Down
Loading