Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,31 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h)
find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{ROCM_PATH})
find_package(hiprtc REQUIRED)

rocm_setup_version(VERSION 1.0)

list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${CK_ROOT}/include/ck/*.hpp)

add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)

# Embed CK Tile headers (ck_tile/*.hpp) for FMHA RTC API
file(GLOB_RECURSE CK_TILE_KERNEL_FILES CONFIGURE_DEPENDS
${CK_ROOT}/include/ck_tile/*.hpp)
add_embed_library(ck_tile_headers ${CK_TILE_KERNEL_FILES} RELATIVE ${CK_ROOT}/include)

# Embed codegen device headers (wrapper.hpp for FMHA RTC)
file(GLOB_RECURSE CK_CODEGEN_DEVICE_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/include/ck/host/device_fmha_fwd/wrapper.hpp)
add_embed_library(ck_codegen_headers ${CK_CODEGEN_DEVICE_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include)

add_compile_options(-std=c++20)

file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc)
target_link_libraries(ck_host PRIVATE ck_headers ck_tile_headers ck_codegen_headers)

set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
Expand All @@ -46,12 +53,12 @@ add_executable(ck-template-driver driver/main.cpp)
target_link_libraries(ck-template-driver ck_host)

rocm_install_targets(
TARGETS ck_host ck_headers
TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers
EXPORT ck_host_targets
INCLUDE include
)
rocm_export_targets(
TARGETS ck_host ck_headers
TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers
EXPORT ck_host_targets
NAMESPACE composable_kernel::
)
Expand Down
83 changes: 83 additions & 0 deletions codegen/include/ck/host/device_fmha_fwd/operation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/device_fmha_fwd/problem.hpp"

namespace ck {
namespace host {
namespace device_fmha_fwd {

// Derived from fmha_fwd.py FmhaFwdTileSize.
struct TileConfig
{
// Block tile
std::size_t bm0;
std::size_t bn0;
std::size_t bk0;
std::size_t bn1;
std::size_t bk1;
std::size_t bk0max;

// Gemm0 block warps
std::size_t rm0;
std::size_t rn0;
std::size_t rk0;

// Gemm1 block warps
std::size_t rm1;
std::size_t rn1;
std::size_t rk1;

// Gemm0 warp tile
std::size_t wm0;
std::size_t wn0;
std::size_t wk0;

// Gemm1 warp tile
std::size_t wm1;
std::size_t wn1;
std::size_t wk1;
};

struct Operation
{
TileConfig tile = {};

std::string pipeline = "qr_async";

bool is_causal = false;
bool is_v_rowmajor = true;
bool has_bias = false;
DataType dtype = DataType::Half;

bool pad_m = true; // pad seqlen_q
bool pad_n = true; // pad seqlen_k
bool pad_k = true; // pad hdim_q
bool pad_o = true; // pad hdim_v

static std::vector<Operation> CreateOperations(const Problem& prob, const std::string& arch);

Solution ToSolution() const;
};

struct HdimBucketResult
{
std::size_t bucket_hdim = 0;
std::size_t bucket_hdim_v = 0;
std::vector<TileConfig> tiles;
};

HdimBucketResult
GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O);

bool IsSupportedArch(const std::string& arch);

} // namespace device_fmha_fwd
} // namespace host
} // namespace ck
38 changes: 38 additions & 0 deletions codegen/include/ck/host/device_fmha_fwd/problem.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"

namespace ck {
namespace host {
namespace device_fmha_fwd {

struct Problem
{
std::size_t M = 0; // seqlen_q
std::size_t N = 0; // seqlen_k
std::size_t K = 0; // hdim_q
std::size_t O = 0; // hdim_v

std::size_t batch = 0;
std::size_t nhead = 0; // nhead_q == nhead_k

DataType dtype = DataType::Half;

bool is_v_rowmajor = true; // true=[N,O], false=[O,N]
bool is_causal = false;
bool has_bias = false;

std::string GetIncludeHeader() const;

std::vector<Solution> GetSolutions(const std::string& arch) const;
};

} // namespace device_fmha_fwd
} // namespace host
} // namespace ck
Loading
Loading