diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 69a6a71de21..6c5e557fb7c 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -15,8 +15,6 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) -list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{ROCM_PATH}) -find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -24,15 +22,24 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS ${CK_ROOT}/include/ck/*.hpp) - add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) +# Embed CK Tile headers (ck_tile/*.hpp) for FMHA RTC API +file(GLOB_RECURSE CK_TILE_KERNEL_FILES CONFIGURE_DEPENDS + ${CK_ROOT}/include/ck_tile/*.hpp) +add_embed_library(ck_tile_headers ${CK_TILE_KERNEL_FILES} RELATIVE ${CK_ROOT}/include) + +# Embed codegen device headers (wrapper.hpp for FMHA RTC) +file(GLOB_RECURSE CK_CODEGEN_DEVICE_FILES CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/include/ck/host/device_fmha_fwd/wrapper.hpp) +add_embed_library(ck_codegen_headers ${CK_CODEGEN_DEVICE_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include) + add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) +target_link_libraries(ck_host PRIVATE ck_headers ck_tile_headers ck_codegen_headers) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX @@ -46,12 +53,12 @@ add_executable(ck-template-driver driver/main.cpp) target_link_libraries(ck-template-driver ck_host) rocm_install_targets( - TARGETS ck_host ck_headers + TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers EXPORT ck_host_targets INCLUDE include ) rocm_export_targets( - TARGETS ck_host ck_headers + TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers EXPORT ck_host_targets NAMESPACE composable_kernel:: ) diff --git a/codegen/include/ck/host/device_fmha_fwd/operation.hpp b/codegen/include/ck/host/device_fmha_fwd/operation.hpp new file mode 100644 index 00000000000..245c818e6ed --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/operation.hpp @@ -0,0 +1,83 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/device_fmha_fwd/problem.hpp" + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +// Derived from fmha_fwd.py FmhaFwdTileSize. +struct TileConfig +{ + // Block tile + std::size_t bm0; + std::size_t bn0; + std::size_t bk0; + std::size_t bn1; + std::size_t bk1; + std::size_t bk0max; + + // Gemm0 block warps + std::size_t rm0; + std::size_t rn0; + std::size_t rk0; + + // Gemm1 block warps + std::size_t rm1; + std::size_t rn1; + std::size_t rk1; + + // Gemm0 warp tile + std::size_t wm0; + std::size_t wn0; + std::size_t wk0; + + // Gemm1 warp tile + std::size_t wm1; + std::size_t wn1; + std::size_t wk1; +}; + +struct Operation +{ + TileConfig tile = {}; + + std::string pipeline = "qr_async"; + + bool is_causal = false; + bool is_v_rowmajor = true; + bool has_bias = false; + DataType dtype = DataType::Half; + + bool pad_m = true; // pad seqlen_q + bool pad_n = true; // pad seqlen_k + bool pad_k = true; // pad hdim_q + bool pad_o = true; // pad hdim_v + + static std::vector CreateOperations(const Problem& prob, const std::string& arch); + + Solution ToSolution() const; +}; + +struct HdimBucketResult +{ + std::size_t bucket_hdim = 0; + std::size_t bucket_hdim_v = 0; + std::vector tiles; +}; + +HdimBucketResult +GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O); + +bool IsSupportedArch(const std::string& arch); + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_fmha_fwd/problem.hpp b/codegen/include/ck/host/device_fmha_fwd/problem.hpp new file mode 100644 index 00000000000..c19e1bc90e5 --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/problem.hpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +struct Problem +{ + std::size_t M = 0; // seqlen_q + std::size_t N = 0; // seqlen_k + std::size_t K = 0; // hdim_q + std::size_t O = 0; // hdim_v + + std::size_t batch = 0; + std::size_t nhead = 0; // nhead_q == nhead_k + + DataType dtype = DataType::Half; + + bool is_v_rowmajor = true; // true=[N,O], false=[O,N] + bool is_causal = false; + bool has_bias = false; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_fmha_fwd/wrapper.hpp b/codegen/include/ck/host/device_fmha_fwd/wrapper.hpp new file mode 100644 index 00000000000..1c36f262531 --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/wrapper.hpp @@ -0,0 +1,268 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +// This header is designed to be embedded and used at RTC compilation time. + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" + +namespace ck_tile { + +enum class FmhaPipelineTag +{ + QR, // BlockFmhaPipelineQRKSVS + QR_ASYNC, // BlockFmhaPipelineQRKSVSAsync + QR_ASYNC_TRLOAD // BlockFmhaPipelineQRKSVSAsyncTrload +}; + +template +struct FmhaFwdWrapper +{ + using BlockTile = sequence; + + using Gemm0BlockWarps = sequence; + using Gemm0WarpTile = sequence; + using Gemm1BlockWarps = sequence; + using Gemm1WarpTile = sequence; + + using FmhaShape = TileFmhaShape; + + static constexpr auto BiasEnum = + kHasBias ? BlockAttentionBiasEnum::ELEMENTWISE_BIAS : BlockAttentionBiasEnum::NO_BIAS; + + using FmhaTraits = TileFmhaTraits; // kHasSink + + using FmhaMask = std::conditional_t, + SimplifiedGenericAttentionMask>; + + static constexpr bool kUseTrLoad = (kPipelineTag == FmhaPipelineTag::QR_ASYNC_TRLOAD); + + using PipelineProblem = + BlockFmhaPipelineProblem, + FmhaMask, + kUseTrLoad, + FmhaTraits>; + + using Pipeline = + std::conditional_t, + std::conditional_t, + BlockFmhaPipelineQRKSVS>>; + + using Epilogue = Default2DEpilogue>; + + using Kernel = FmhaFwdKernel; + + // Innermost dimension is always contiguous (stride=1): + // + // K is stored as [batch, nhead, N, K] (not transposed). + // The kernel internally handles the transpose for Q @ K^T. + // + // Q: [batch, nhead, M, K] + // K: [batch, nhead, N, K] + // V: [batch, nhead, N, O] (rowmajor) or [batch, nhead, O, N] (colmajor) + // O: [batch, nhead, M, O] + // Bias: [batch, nhead, M, N] + struct Descriptor + { + index_t batch, nhead, M, K; + index_t q_stride_batch, q_stride_nhead, q_stride_m; + + index_t N; + index_t k_stride_batch, k_stride_nhead, k_stride_n; + + index_t O; + index_t v_stride_batch, v_stride_nhead, v_stride_n; + + index_t o_stride_batch, o_stride_nhead, o_stride_m; + + index_t bias_stride_batch, bias_stride_nhead, bias_stride_m; + + CK_TILE_HOST_DEVICE constexpr bool IsValid() const { return Kernel::kIsAvailable; } + }; + + // Each tensor is specified as (batch, nhead, dim0, dim1) and (stride0, stride1, stride2) + // Innermost stride is always 1 and not passed. + template + CK_TILE_HOST_DEVICE static constexpr auto make_descriptor(QDims q_dims, + QStrides q_strides, + KDims k_dims, + KStrides k_strides, + VDims v_dims, + VStrides v_strides, + ODims o_dims, + OStrides o_strides, + BiasDims bias_dims, + BiasStrides bias_strides) + { + return Descriptor{q_dims[number<0>{}], + q_dims[number<1>{}], + q_dims[number<2>{}], + q_dims[number<3>{}], + q_strides[number<0>{}], + q_strides[number<1>{}], + q_strides[number<2>{}], + // + k_dims[number<2>{}], + k_strides[number<0>{}], + k_strides[number<1>{}], + k_strides[number<2>{}], + // + v_dims[number<3>{}], + v_strides[number<0>{}], + v_strides[number<1>{}], + v_strides[number<2>{}], + // + o_strides[number<0>{}], + o_strides[number<1>{}], + o_strides[number<2>{}], + // + bias_strides[number<0>{}], + bias_strides[number<1>{}], + bias_strides[number<2>{}]}; + } + + CK_TILE_DEVICE static void Run(const Descriptor& desc, + float scale_s, + const DataType_* q_ptr, + const DataType_* k_ptr, + const DataType_* v_ptr, + const DataType_* bias_ptr, + DataType_* o_ptr) + { + using Kargs = typename Kernel::Kargs; + Kargs kargs{}; + + kargs.q_ptr = q_ptr; + kargs.k_ptr = k_ptr; + kargs.v_ptr = v_ptr; + kargs.o_ptr = o_ptr; + kargs.sink_ptr = nullptr; + + kargs.seqlen_q = desc.M; + kargs.seqlen_k = desc.N; + kargs.hdim_q = desc.K; + kargs.hdim_v = desc.O; + + kargs.num_head_q = desc.nhead; + kargs.nhead_ratio_qk = 1; // nhead_q == nhead_k + + kargs.scale_s = scale_s; + + kargs.stride_q = desc.q_stride_m; + kargs.stride_k = desc.k_stride_n; + kargs.stride_v = desc.v_stride_n; + kargs.stride_o = desc.o_stride_m; + + kargs.nhead_stride_q = desc.q_stride_nhead; + kargs.nhead_stride_k = desc.k_stride_nhead; + kargs.nhead_stride_v = desc.v_stride_nhead; + kargs.nhead_stride_o = desc.o_stride_nhead; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = desc.bias_stride_m; + kargs.nhead_stride_bias = desc.bias_stride_nhead; + kargs.batch_stride_bias = desc.bias_stride_batch; + } + + if constexpr(kIsCausal) + { + kargs.window_size_left = -1; + kargs.window_size_right = 0; + kargs.sink_size = 0; + kargs.mask_type = GenericAttentionMaskEnum::MASK_FROM_BOTTOM_RIGHT; + } + + kargs.batch_stride_q = desc.q_stride_batch; + kargs.batch_stride_k = desc.k_stride_batch; + kargs.batch_stride_v = desc.v_stride_batch; + kargs.batch_stride_o = desc.o_stride_batch; + + kargs.cu_seqlen_q_ptr = nullptr; + kargs.cu_seqlen_k_ptr = nullptr; + + Kernel{}(kargs); + } +}; + +} // namespace ck_tile diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp index 571ad472ea8..11d1e555221 100644 --- a/codegen/include/ck/host/headers.hpp +++ b/codegen/include/ck/host/headers.hpp @@ -13,5 +13,7 @@ namespace host { std::unordered_map GetHeaders(); +std::unordered_map GetTileHeaders(); + } // namespace host } // namespace ck diff --git a/codegen/include/rocm-cxx/rocm/algorithm.hpp b/codegen/include/rocm-cxx/rocm/algorithm.hpp new file mode 100644 index 00000000000..7613a752d62 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm.hpp @@ -0,0 +1,31 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#endif // ROCM_GUARD_ROCM_ALGORITHM_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/accumulate.hpp b/codegen/include/rocm-cxx/rocm/algorithm/accumulate.hpp new file mode 100644 index 00000000000..741b0453fa6 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/accumulate.hpp @@ -0,0 +1,27 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_ACCUMULATE_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_ACCUMULATE_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op) +{ + for(; first != last; ++first) + { + init = op(static_cast(init), *first); + } + return init; +} + +template +constexpr T accumulate(InputIt first, InputIt last, T init) +{ + return accumulate(first, last, init, [](auto x, auto y) { return x + y; }); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_ACCUMULATE_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/all_of.hpp b/codegen/include/rocm-cxx/rocm/algorithm/all_of.hpp new file mode 100644 index 00000000000..a8221427d8b --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/all_of.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_ALL_OF_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_ALL_OF_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p) +{ + return none_of(first, last, [=](auto&& x) { return not p(x); }); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_ALL_OF_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/any_of.hpp b/codegen/include/rocm-cxx/rocm/algorithm/any_of.hpp new file mode 100644 index 00000000000..61f14e1f547 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/any_of.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_ANY_OF_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_ANY_OF_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p) +{ + return find_if(first, last, p) != last; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_ANY_OF_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/copy.hpp b/codegen/include/rocm-cxx/rocm/algorithm/copy.hpp new file mode 100644 index 00000000000..451b0fba4b4 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/copy.hpp @@ -0,0 +1,21 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_COPY_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_COPY_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) +{ + while(first != last) + { + *d_first++ = *first++; + } + return d_first; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_COPY_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/copy_if.hpp b/codegen/include/rocm-cxx/rocm/algorithm/copy_if.hpp new file mode 100644 index 00000000000..08097116244 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/copy_if.hpp @@ -0,0 +1,25 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_COPY_IF_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_COPY_IF_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred) +{ + for(; first != last; ++first) + { + if(pred(*first)) + { + *d_first = *first; + ++d_first; + } + } + return d_first; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_COPY_IF_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/equal.hpp b/codegen/include/rocm-cxx/rocm/algorithm/equal.hpp new file mode 100644 index 00000000000..3c530d6981e --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/equal.hpp @@ -0,0 +1,28 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_EQUAL_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_EQUAL_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr bool equal(Iterator1 first1, Iterator1 last1, Iterator2 first2, BinaryPred p) +{ + for(; first1 != last1; ++first1, ++first2) + if(not p(*first1, *first2)) + { + return false; + } + return true; +} + +template +constexpr bool equal(Iterator1 first1, Iterator1 last1, Iterator2 first2) +{ + return equal(first1, last1, first2, [](auto&& x, auto&& y) { return x == y; }); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_EQUAL_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/fill.hpp b/codegen/include/rocm-cxx/rocm/algorithm/fill.hpp new file mode 100644 index 00000000000..1cc512290d8 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/fill.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_FILL_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_FILL_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr void fill(Iterator first, Iterator last, const T& value) +{ + for(; first != last; ++first) + *first = value; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_FILL_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/find.hpp b/codegen/include/rocm-cxx/rocm/algorithm/find.hpp new file mode 100644 index 00000000000..0ac3d0ddfbd --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/find.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_FIND_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_FIND_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator find(Iterator first, Iterator last, const T& value) +{ + return find_if(first, last, [&](const auto& x) { return x == value; }); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_FIND_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/find_if.hpp b/codegen/include/rocm-cxx/rocm/algorithm/find_if.hpp new file mode 100644 index 00000000000..db7815e679e --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/find_if.hpp @@ -0,0 +1,24 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_FIND_IF_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_FIND_IF_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator find_if(Iterator first, Iterator last, Predicate p) +{ + for(; first != last; ++first) + { + if(p(*first)) + { + return first; + } + } + return last; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_FIND_IF_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/for_each.hpp b/codegen/include/rocm-cxx/rocm/algorithm/for_each.hpp new file mode 100644 index 00000000000..36871a0448e --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/for_each.hpp @@ -0,0 +1,21 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_FOR_EACH_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_FOR_EACH_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr F for_each(Iterator first, Iterator last, F f) +{ + for(; first != last; ++first) + { + f(*first); + } + return f; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_FOR_EACH_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/inner_product.hpp b/codegen/include/rocm-cxx/rocm/algorithm/inner_product.hpp new file mode 100644 index 00000000000..474f7eb9d79 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/inner_product.hpp @@ -0,0 +1,40 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_INNER_PRODUCT_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_INNER_PRODUCT_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr T inner_product(InputIt1 first1, + InputIt1 last1, + InputIt2 first2, + T init, + BinaryOperation1 op1, + BinaryOperation2 op2) +{ + while(first1 != last1) + { + init = op1(init, op2(*first1, *first2)); + ++first1; + ++first2; + } + return init; +} + +template +constexpr T inner_product(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init) +{ + return inner_product( + first1, + last1, + first2, + init, + [](auto x, auto y) { return x + y; }, + [](auto x, auto y) { return x * y; }); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_INNER_PRODUCT_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/iota.hpp b/codegen/include/rocm-cxx/rocm/algorithm/iota.hpp new file mode 100644 index 00000000000..2701085b128 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/iota.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_IOTA_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_IOTA_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr void iota(Iterator first, Iterator last, T value) +{ + for(; first != last; ++first, ++value) + *first = value; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_IOTA_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/is_sorted.hpp b/codegen/include/rocm-cxx/rocm/algorithm/is_sorted.hpp new file mode 100644 index 00000000000..5d96dbde59f --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/is_sorted.hpp @@ -0,0 +1,25 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_IS_SORTED_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_IS_SORTED_HPP + +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr bool is_sorted(Iterator first, Iterator last, Compare comp) +{ + return is_sorted_until(first, last, comp) == last; +} + +template +constexpr bool is_sorted(Iterator first, Iterator last) +{ + return is_sorted(first, last, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_IS_SORTED_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/is_sorted_until.hpp b/codegen/include/rocm-cxx/rocm/algorithm/is_sorted_until.hpp new file mode 100644 index 00000000000..1d2eda40936 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/is_sorted_until.hpp @@ -0,0 +1,34 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_IS_SORTED_UNTIL_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_IS_SORTED_UNTIL_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp) +{ + if(first != last) + { + Iterator next = first; + while(++next != last) + { + if(comp(*next, *first)) + return next; + first = next; + } + } + return last; +} + +template +constexpr Iterator is_sorted_until(Iterator first, Iterator last) +{ + return is_sorted_until(first, last, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_IS_SORTED_UNTIL_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/iter_swap.hpp b/codegen/include/rocm-cxx/rocm/algorithm/iter_swap.hpp new file mode 100644 index 00000000000..7d1c8597f26 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/iter_swap.hpp @@ -0,0 +1,20 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_ITER_SWAP_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_ITER_SWAP_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr void iter_swap(Iterator1 a, Iterator2 b) +{ + if(a == b) + return; + swap(*a, *b); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_ITER_SWAP_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/lower_bound.hpp b/codegen/include/rocm-cxx/rocm/algorithm/lower_bound.hpp new file mode 100644 index 00000000000..e07d95cf7f1 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/lower_bound.hpp @@ -0,0 +1,25 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_LOWER_BOUND_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_LOWER_BOUND_HPP + +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator lower_bound(Iterator first, Iterator last, const T& value, Compare comp) +{ + return upper_bound(first, last, value, [&](auto&& a, auto&& b) { return not comp(b, a); }); +} + +template +constexpr Iterator lower_bound(Iterator first, Iterator last, const T& value) +{ + return lower_bound(first, last, value, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_LOWER_BOUND_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/max_element.hpp b/codegen/include/rocm-cxx/rocm/algorithm/max_element.hpp new file mode 100644 index 00000000000..f3c39a71fea --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/max_element.hpp @@ -0,0 +1,25 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_MAX_ELEMENT_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_MAX_ELEMENT_HPP + +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator max_element(Iterator first, Iterator last, Compare comp) +{ + return min_element(first, last, [&](auto&& a, auto&& b) { return comp(b, a); }); +} + +template +constexpr Iterator max_element(Iterator first, Iterator last) +{ + return max_element(first, last, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_MAX_ELEMENT_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/merge.hpp b/codegen/include/rocm-cxx/rocm/algorithm/merge.hpp new file mode 100644 index 00000000000..059da8a0c13 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/merge.hpp @@ -0,0 +1,39 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_MERGE_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_MERGE_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr OutputIterator merge(Iterator1 first1, + Iterator1 last1, + Iterator2 first2, + Iterator2 last2, + OutputIterator d_first, + Compare comp) +{ + for(; first1 != last1; ++d_first) + { + if(first2 == last2) + return copy(first1, last1, d_first); + + if(comp(*first2, *first1)) + { + *d_first = *first2; + ++first2; + } + else + { + *d_first = *first1; + ++first1; + } + } + return copy(first2, last2, d_first); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_MERGE_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/min_element.hpp b/codegen/include/rocm-cxx/rocm/algorithm/min_element.hpp new file mode 100644 index 00000000000..cad6cb3a4f7 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/min_element.hpp @@ -0,0 +1,33 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_MIN_ELEMENT_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_MIN_ELEMENT_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator min_element(Iterator first, Iterator last, Compare comp) +{ + if(first == last) + return last; + + Iterator smallest = first; + + while(++first != last) + if(comp(*first, *smallest)) + smallest = first; + + return smallest; +} + +template +constexpr Iterator min_element(Iterator first, Iterator last) +{ + return min_element(first, last, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_MIN_ELEMENT_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/none_of.hpp b/codegen/include/rocm-cxx/rocm/algorithm/none_of.hpp new file mode 100644 index 00000000000..c303f6f643f --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/none_of.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_NONE_OF_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_NONE_OF_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p) +{ + return find_if(first, last, p) == last; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_NONE_OF_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/rotate.hpp b/codegen/include/rocm-cxx/rocm/algorithm/rotate.hpp new file mode 100644 index 00000000000..cb891803d26 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/rotate.hpp @@ -0,0 +1,35 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_ROTATE_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_ROTATE_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator rotate(Iterator first, Iterator middle, Iterator last) +{ + if(first == middle) + return last; + + if(middle == last) + return first; + + Iterator write = first; + Iterator next_read = first; + + for(Iterator read = middle; read != last; ++write, ++read) + { + if(write == next_read) + next_read = read; + iter_swap(write, read); + } + + rotate(write, next_read, last); + return write; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_ROTATE_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/search.hpp b/codegen/include/rocm-cxx/rocm/algorithm/search.hpp new file mode 100644 index 00000000000..9ca4c3b0d6a --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/search.hpp @@ -0,0 +1,42 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_SEARCH_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_SEARCH_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator1 +search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last, BinaryPredicate pred) +{ + for(;; ++first) + { + Iterator1 it = first; + for(Iterator2 s_it = s_first;; ++it, ++s_it) + { + if(s_it == s_last) + { + return first; + } + if(it == last) + { + return last; + } + if(not pred(*it, *s_it)) + { + break; + } + } + } +} + +template +constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last) +{ + return search(first, last, s_first, s_last, [](auto&& x, auto&& y) { return x == y; }); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_SEARCH_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/sort.hpp b/codegen/include/rocm-cxx/rocm/algorithm/sort.hpp new file mode 100644 index 00000000000..47a74639bf2 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/sort.hpp @@ -0,0 +1,32 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_SORT_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_SORT_HPP + +#include +#include +#include +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr void sort(Iterator first, Iterator last, Compare comp) +{ + if(first == last) + return; + for(auto i = first; i != last - 1; ++i) + iter_swap(i, min_element(i, last, comp)); + ROCM_ASSERT(is_sorted(first, last, comp)); +} + +template +constexpr void sort(Iterator first, Iterator last) +{ + sort(first, last, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_SORT_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/stable_sort.hpp b/codegen/include/rocm-cxx/rocm/algorithm/stable_sort.hpp new file mode 100644 index 00000000000..92b25f1e086 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/stable_sort.hpp @@ -0,0 +1,32 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_STABLE_SORT_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_STABLE_SORT_HPP + +#include +#include +#include +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr void stable_sort(Iterator first, Iterator last, Compare comp) +{ + if(first == last) + return; + for(auto i = first; i != last; ++i) + rotate(upper_bound(first, i, *i, comp), i, i + 1); + ROCM_ASSERT(is_sorted(first, last, comp)); +} + +template +constexpr void stable_sort(Iterator first, Iterator last) +{ + stable_sort(first, last, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_STABLE_SORT_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/transform.hpp b/codegen/include/rocm-cxx/rocm/algorithm/transform.hpp new file mode 100644 index 00000000000..34317e28c82 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/transform.hpp @@ -0,0 +1,31 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_TRANSFORM_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_TRANSFORM_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr OutputIterator +transform(Iterator first1, Iterator last1, OutputIterator out, UnaryOp unary_op) +{ + for(; first1 != last1; ++out, ++first1) + *out = unary_op(*first1); + + return out; +} + +template +constexpr OutputIterator transform( + Iterator1 first1, Iterator1 last1, Iterator2 first2, OutputIterator out, BinaryOp binary_op) +{ + for(; first1 != last1; ++out, ++first1, ++first2) + *out = binary_op(*first1, *first2); + + return out; +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_TRANSFORM_HPP diff --git a/codegen/include/rocm-cxx/rocm/algorithm/upper_bound.hpp b/codegen/include/rocm-cxx/rocm/algorithm/upper_bound.hpp new file mode 100644 index 00000000000..914a64aac80 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/algorithm/upper_bound.hpp @@ -0,0 +1,42 @@ +#ifndef ROCM_GUARD_ROCM_ALGORITHM_UPPER_BOUND_HPP +#define ROCM_GUARD_ROCM_ALGORITHM_UPPER_BOUND_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value, Compare comp) +{ + auto count = last - first; + + while(count > 0) + { + // NOLINTNEXTLINE(readability-qualified-auto) + auto it = first; + auto step = count / 2; + it += step; + + if(not comp(value, *it)) + { + first = ++it; + count -= step + 1; + } + else + count = step; + } + + return first; +} + +template +constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value) +{ + return upper_bound(first, last, value, less<>{}); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ALGORITHM_UPPER_BOUND_HPP diff --git a/codegen/include/rocm-cxx/rocm/array.hpp b/codegen/include/rocm-cxx/rocm/array.hpp new file mode 100644 index 00000000000..e8dbb4e861a --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/array.hpp @@ -0,0 +1,218 @@ +#ifndef ROCM_GUARD_ROCM_ARRAY_HPP +#define ROCM_GUARD_ROCM_ARRAY_HPP + +#include +#include +#include +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +struct array +{ + using value_type = T; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; + using const_reference = const T&; + using size_type = size_t; + using difference_type = ptrdiff_t; + using iterator = T*; + using const_iterator = const T*; + using reverse_iterator = rocm::reverse_iterator; + using const_reverse_iterator = rocm::reverse_iterator; + + T elems[N]; // NOLINT + + // fill + constexpr void fill(const T& u) + { + for(size_type i = 0; i < N; ++i) + elems[i] = u; + } + + // swap + constexpr void swap(array& other) noexcept + { + for(size_type i = 0; i < N; ++i) + rocm::swap(elems[i], other.elems[i]); + } + + // iterators + constexpr iterator begin() noexcept { return elems; } + constexpr const_iterator begin() const noexcept { return elems; } + constexpr iterator end() noexcept { return elems + N; } + constexpr const_iterator end() const noexcept { return elems + N; } + + constexpr reverse_iterator rbegin() noexcept { return reverse_iterator(end()); } + constexpr const_reverse_iterator rbegin() const noexcept + { + return const_reverse_iterator(end()); + } + constexpr reverse_iterator rend() noexcept { return reverse_iterator(begin()); } + constexpr const_reverse_iterator rend() const noexcept + { + return const_reverse_iterator(begin()); + } + + constexpr const_iterator cbegin() const noexcept { return begin(); } + constexpr const_iterator cend() const noexcept { return end(); } + constexpr const_reverse_iterator crbegin() const noexcept { return rbegin(); } + constexpr const_reverse_iterator crend() const noexcept { return rend(); } + + // capacity + constexpr size_type size() const noexcept { return N; } + constexpr size_type max_size() const noexcept { return N; } + constexpr bool empty() const noexcept { return N == 0; } + + // element access + constexpr reference operator[](size_type n) { return elems[n]; } + constexpr const_reference operator[](size_type n) const { return elems[n]; } + + constexpr reference at(size_type n) { return elems[n]; } + constexpr const_reference at(size_type n) const { return elems[n]; } + + constexpr reference front() { return elems[0]; } + constexpr const_reference front() const { return elems[0]; } + constexpr reference back() { return elems[N - 1]; } + constexpr const_reference back() const { return elems[N - 1]; } + + constexpr T* data() noexcept { return elems; } + constexpr const T* data() const noexcept { return elems; } + + // comparison operators + friend constexpr bool operator==(const array& x, const array& y) + { + for(size_type i = 0; i < N; ++i) + { + if(not(x.elems[i] == y.elems[i])) + return false; + } + return true; + } + + friend constexpr bool operator!=(const array& x, const array& y) { return not(x == y); } + + friend constexpr bool operator<(const array& x, const array& y) + { + for(size_type i = 0; i < N; ++i) + { + if(x.elems[i] < y.elems[i]) + return true; + if(y.elems[i] < x.elems[i]) + return false; + } + return false; + } + + friend constexpr bool operator>(const array& x, const array& y) { return y < x; } + + friend constexpr bool operator<=(const array& x, const array& y) { return not(y < x); } + + friend constexpr bool operator>=(const array& x, const array& y) { return not(x < y); } +}; + +// zero-size specialization +template +struct array +{ + using value_type = T; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; + using const_reference = const T&; + using size_type = size_t; + using difference_type = ptrdiff_t; + using iterator = T*; + using const_iterator = const T*; + using reverse_iterator = rocm::reverse_iterator; + using const_reverse_iterator = rocm::reverse_iterator; + + constexpr void fill(const T&) {} + constexpr void swap(array&) noexcept {} + + constexpr iterator begin() noexcept { return nullptr; } + constexpr const_iterator begin() const noexcept { return nullptr; } + constexpr iterator end() noexcept { return nullptr; } + constexpr const_iterator end() const noexcept { return nullptr; } + + constexpr reverse_iterator rbegin() noexcept { return reverse_iterator(end()); } + constexpr const_reverse_iterator rbegin() const noexcept + { + return const_reverse_iterator(end()); + } + constexpr reverse_iterator rend() noexcept { return reverse_iterator(begin()); } + constexpr const_reverse_iterator rend() const noexcept + { + return const_reverse_iterator(begin()); + } + + constexpr const_iterator cbegin() const noexcept { return begin(); } + constexpr const_iterator cend() const noexcept { return end(); } + constexpr const_reverse_iterator crbegin() const noexcept { return rbegin(); } + constexpr const_reverse_iterator crend() const noexcept { return rend(); } + + constexpr size_type size() const noexcept { return 0; } + constexpr size_type max_size() const noexcept { return 0; } + constexpr bool empty() const noexcept { return true; } + + constexpr T* data() noexcept { return nullptr; } + constexpr const T* data() const noexcept { return nullptr; } + + friend constexpr bool operator==(const array&, const array&) { return true; } + friend constexpr bool operator!=(const array&, const array&) { return false; } + friend constexpr bool operator<(const array&, const array&) { return false; } + friend constexpr bool operator>(const array&, const array&) { return false; } + friend constexpr bool operator<=(const array&, const array&) { return true; } + friend constexpr bool operator>=(const array&, const array&) { return true; } +}; + +// CTAD +template +ROCM_HIP_HOST_DEVICE array(T, U...) -> array; + +// swap +template +constexpr void swap(array& x, array& y) noexcept(noexcept(x.swap(y))) +{ + x.swap(y); +} + +// to_array +namespace detail { + +template +constexpr array, N> to_array_lvalue(T (&a)[N], rocm::index_sequence) // NOLINT +{ + return {{a[Is]...}}; +} + +template +constexpr array, N> to_array_rvalue(T (&&a)[N], + rocm::index_sequence) // NOLINT +{ + return {{static_cast(a[Is])...}}; +} + +} // namespace detail + +template +constexpr array, N> to_array(T (&a)[N]) // NOLINT +{ + return detail::to_array_lvalue(a, rocm::make_index_sequence{}); +} + +template +constexpr array, N> to_array(T (&&a)[N]) // NOLINT +{ + return detail::to_array_rvalue(static_cast(a), + rocm::make_index_sequence{}); // NOLINT +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ARRAY_HPP diff --git a/codegen/include/rocm-cxx/rocm/assert.hpp b/codegen/include/rocm-cxx/rocm/assert.hpp new file mode 100644 index 00000000000..4ef82c09772 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/assert.hpp @@ -0,0 +1,152 @@ +#ifndef ROCM_GUARD_ROCM_ASSERT_HPP +#define ROCM_GUARD_ROCM_ASSERT_HPP + +#include +#include + +#ifndef __HIPCC_RTC__ +#include +#include +#endif + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +// Workaround hip's broken abort on device code +#ifdef __HIP_DEVICE_COMPILE__ +// NOLINTNEXTLINE +#define ROCM_HIP_NORETURN +#else +// NOLINTNEXTLINE +#define ROCM_HIP_NORETURN [[noreturn]] +#endif + +namespace debug { +struct swallow +{ + template + constexpr swallow(Ts&&...) + { + } +}; + +template +struct print_buffer +{ + char buffer[N + 1] = {0}; + char* pos = buffer; + + constexpr void append(char c) + { + if(c == 0) + return; + if(pos < buffer + N) + { + *pos = c; + pos++; + } + } + static constexpr void reverse(char* first, char* last) + { + if(first == last) + return; + last--; + while(first < last) + { + char tmp = *first; + *first = *last; + *last = tmp; + first++; + last--; + } + } + + template + constexpr void append(T i) + { + if(i < 0) + { + append('-'); + i = -i; + } + if(i == 0) + { + append('0'); + return; + } + char* start = pos; + while(i != 0) + { + char c = (i % 10) + '0'; + append(c); + i = i / 10; + } + reverse(start, pos); + } + + constexpr void append(const char* str) + { + if(str == nullptr) + return; + int i = 512; + while(*str != 0 and i > 0) + { + append(*str); + str++; + i--; + } + } + + template + constexpr void append(const char (&array)[M]) + { + for(int i = 0; i < M; i++) + append(array[i]); + } +}; + +template +ROCM_HIP_HOST_DEVICE void print(const Ts&... xs) +{ + print_buffer<1024> buffer; + swallow{(buffer.append(xs), 0)...}; + printf("%s", buffer.buffer); +} + +} // namespace debug + +// noreturn cannot be used on this function because abort in hip is broken +template +ROCM_HIP_NORETURN inline ROCM_HIP_HOST_DEVICE void +assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function) +{ + // printf is broken on hip with more than one argument, so use a simple print functions instead + debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n"); + // printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion); + abort(); +} + +// NOLINTNEXTLINE +#define ROCM_ASSERT_FAIL(cond, ...) \ + ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \ + assert_fail(private_migraphx_xs...); \ + }(__VA_ARGS__)) + +// NOLINTNEXTLINE +#define ROCM_CHECK(cond) ROCM_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__) + +#ifdef ROCM_DEBUG +// NOLINTNEXTLINE +#define ROCM_ASSERT ROCM_CHECK +#define ROCM_ASSUME ROCM_CHECK +#define ROCM_UNREACHABLE() ROCM_ASSERT(false) +#else +// NOLINTNEXTLINE +#define ROCM_ASSUME __builtin_assume +#define ROCM_UNREACHABLE __builtin_unreachable +#define ROCM_ASSERT(cond) +#endif + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ASSERT_HPP diff --git a/codegen/include/rocm-cxx/rocm/bit.hpp b/codegen/include/rocm-cxx/rocm/bit.hpp new file mode 100644 index 00000000000..3a5fca9419f --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/bit.hpp @@ -0,0 +1,107 @@ +#ifndef ROCM_GUARD_ROCM_BIT_HPP +#define ROCM_GUARD_ROCM_BIT_HPP + +#include +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template {} and + rocm::is_trivially_copyable{} and sizeof(To) == sizeof(From))> +constexpr To bit_cast(From fr) noexcept +{ + return __builtin_bit_cast(To, fr); +} + +template {})> +constexpr int countl_zero(T x) noexcept +{ + return __builtin_clzg(x, numeric_limits::digits); +} + +template {})> +constexpr int countl_one(T x) noexcept +{ + return countl_zero(T(~x)); +} + +template {})> +constexpr int countr_zero(T x) noexcept +{ + return __builtin_ctzg(x, numeric_limits::digits); +} + +template {})> +constexpr int countr_one(T x) noexcept +{ + return countr_zero(T(~x)); +} + +template {})> +constexpr int popcount(T x) noexcept +{ + return __builtin_popcountg(x); +} + +template {})> +constexpr int bit_width(T x) noexcept +{ + return numeric_limits::digits - countl_zero(x); +} + +template {})> +constexpr T bit_floor(T x) noexcept +{ + if(x != 0) + return T(1) << (bit_width(x) - 1); + return 0; +} + +template {})> +constexpr T bit_ceil(T x) noexcept +{ + if(x <= 1) + return 1; + auto e = bit_width(T(x - 1)); + ROCM_ASSERT(e < numeric_limits::digits); + if constexpr(is_same{}) + return T(1) << e; + constexpr int offset_for_ub = numeric_limits::digits - numeric_limits::digits; + return T(1u << (e + offset_for_ub) >> offset_for_ub); +} + +template {})> +constexpr bool has_single_bit(T x) noexcept +{ + return popcount(x) == 1; +} + +template {})> +constexpr T rotl(T x, int s) noexcept +{ + const int n = numeric_limits::digits; + int r = s % n; + + if(r == 0) + return x; + + if(r > 0) + return (x << r) | (x >> (n - r)); + + return (x >> -r) | (x << (n + r)); +} + +template {})> +constexpr T rotr(T x, int s) noexcept +{ + return rotl(x, -s); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_BIT_HPP diff --git a/codegen/include/rocm-cxx/rocm/config.hpp b/codegen/include/rocm-cxx/rocm/config.hpp new file mode 100644 index 00000000000..ef6daa8f373 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/config.hpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef ROCM_GUARD_ROCM_CONFIG_HPP +#define ROCM_GUARD_ROCM_CONFIG_HPP + +namespace rocm { + +#if !defined(ROCM_USE_CLANG_TIDY) +#define ROCM_INLINE_NS pocket_version_1 +#endif + +#ifdef __HIP_DEVICE_COMPILE__ +// NOLINTNEXTLINE +#define ROCM_HIP_HOST_DEVICE __host__ __device__ +#else +// NOLINTNEXTLINE +#define ROCM_HIP_HOST_DEVICE +#endif + +} // namespace rocm +#endif // ROCM_GUARD_ROCM_CONFIG_HPP diff --git a/codegen/include/rocm-cxx/rocm/functional.hpp b/codegen/include/rocm-cxx/rocm/functional.hpp new file mode 100644 index 00000000000..8e12f3bbe21 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/functional.hpp @@ -0,0 +1,12 @@ +#ifndef ROCM_GUARD_ROCM_FUNCTIONAL_HPP +#define ROCM_GUARD_ROCM_FUNCTIONAL_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_FUNCTIONAL_HPP diff --git a/codegen/include/rocm-cxx/rocm/functional/operations.hpp b/codegen/include/rocm-cxx/rocm/functional/operations.hpp new file mode 100644 index 00000000000..01b9f8e0a1c --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/functional/operations.hpp @@ -0,0 +1,70 @@ +#ifndef ROCM_GUARD_FUNCTIONAL_OPERATIONS_HPP +#define ROCM_GUARD_FUNCTIONAL_OPERATIONS_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +#define ROCM_FUNCTIONAL_BINARY_OP(name, op, result) \ + template \ + struct name \ + { \ + constexpr result operator()(const T& x, const T& y) const { return x op y; } \ + }; \ + template <> \ + struct name \ + { \ + using is_transparent = void; \ + template \ + constexpr auto operator()(T&& x, U&& y) const \ + noexcept(noexcept(static_cast(x) op static_cast(y))) \ + -> decltype(static_cast(x) op static_cast(y)) \ + { \ + return static_cast(x) op static_cast(y); \ + } \ + }; + +#define ROCM_FUNCTIONAL_UNARY_OP(name, op, result) \ + template \ + struct name \ + { \ + constexpr result operator()(const T& x) const { return op x; } \ + }; \ + template <> \ + struct name \ + { \ + using is_transparent = void; \ + template \ + constexpr auto operator()(T&& x) const noexcept(noexcept(op static_cast(x))) \ + -> decltype(op static_cast(x)) \ + { \ + return op static_cast(x); \ + } \ + }; + +ROCM_FUNCTIONAL_BINARY_OP(plus, +, T) +ROCM_FUNCTIONAL_BINARY_OP(minus, -, T) +ROCM_FUNCTIONAL_BINARY_OP(multiplies, *, T) +ROCM_FUNCTIONAL_BINARY_OP(divides, /, T) +ROCM_FUNCTIONAL_BINARY_OP(modulus, %, T) +ROCM_FUNCTIONAL_BINARY_OP(bit_and, &, T) +ROCM_FUNCTIONAL_BINARY_OP(bit_or, |, T) +ROCM_FUNCTIONAL_BINARY_OP(bit_xor, ^, T) + +ROCM_FUNCTIONAL_BINARY_OP(equal_to, ==, bool) +ROCM_FUNCTIONAL_BINARY_OP(not_equal_to, !=, bool) +ROCM_FUNCTIONAL_BINARY_OP(greater, >, bool) +ROCM_FUNCTIONAL_BINARY_OP(less, <, bool) +ROCM_FUNCTIONAL_BINARY_OP(greater_equal, >=, bool) +ROCM_FUNCTIONAL_BINARY_OP(less_equal, <=, bool) +ROCM_FUNCTIONAL_BINARY_OP(logical_and, and, bool) +ROCM_FUNCTIONAL_BINARY_OP(logical_or, or, bool) + +ROCM_FUNCTIONAL_UNARY_OP(negate, -, T) +ROCM_FUNCTIONAL_UNARY_OP(logical_not, not, bool) +ROCM_FUNCTIONAL_UNARY_OP(bit_not, ~, T) + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // MIGRAPHX_GUARD_FUNCTIONAL_OPERATIONS_HPP diff --git a/codegen/include/rocm-cxx/rocm/integral_constant.hpp b/codegen/include/rocm-cxx/rocm/integral_constant.hpp new file mode 100644 index 00000000000..c41901ae452 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/integral_constant.hpp @@ -0,0 +1,94 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef ROCM_GUARD_ROCM_INTEGRAL_CONSTANT_HPP +#define ROCM_GUARD_ROCM_INTEGRAL_CONSTANT_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +struct integral_constant +{ + static constexpr T value = V; + using value_type = T; + using type = integral_constant; + constexpr operator value_type() const noexcept { return value; } + constexpr value_type operator()() const noexcept { return value; } + static constexpr type to() { return {}; } +}; + +// NOLINTNEXTLINE +#define ROCM_INTEGRAL_CONSTANT_BINARY_OP(op) \ + template \ + constexpr inline integral_constant operator op( \ + integral_constant, integral_constant) noexcept \ + { \ + return {}; \ + } + +// NOLINTNEXTLINE +#define ROCM_INTEGRAL_CONSTANT_UNARY_OP(op) \ + template \ + constexpr inline integral_constant operator op( \ + integral_constant) noexcept \ + { \ + return {}; \ + } + +ROCM_INTEGRAL_CONSTANT_BINARY_OP(+) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(-) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(*) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(/) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(%) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(>>) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(<<) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(>) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(<) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(<=) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(>=) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(==) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(!=) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(&) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(^) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(|) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(and) +ROCM_INTEGRAL_CONSTANT_BINARY_OP(or) + +ROCM_INTEGRAL_CONSTANT_UNARY_OP(not) +ROCM_INTEGRAL_CONSTANT_UNARY_OP(~) +ROCM_INTEGRAL_CONSTANT_UNARY_OP(+) +ROCM_INTEGRAL_CONSTANT_UNARY_OP(-) + +template +using bool_constant = integral_constant; + +using true_type = bool_constant; +using false_type = bool_constant; + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_INTEGRAL_CONSTANT_HPP diff --git a/codegen/include/rocm-cxx/rocm/iterator.hpp b/codegen/include/rocm-cxx/rocm/iterator.hpp new file mode 100644 index 00000000000..d7d1f3923dd --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/iterator.hpp @@ -0,0 +1,13 @@ +#ifndef ROCM_GUARD_ROCM_ITERATOR_HPP +#define ROCM_GUARD_ROCM_ITERATOR_HPP + +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_ITERATOR_HPP diff --git a/codegen/include/rocm-cxx/rocm/iterator/iterator_traits.hpp b/codegen/include/rocm-cxx/rocm/iterator/iterator_traits.hpp new file mode 100644 index 00000000000..87fa45ec51c --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/iterator/iterator_traits.hpp @@ -0,0 +1,53 @@ +#ifndef ROCM_GUARD_ITERATOR_ITERATOR_TRAITS_HPP +#define ROCM_GUARD_ITERATOR_ITERATOR_TRAITS_HPP + +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +struct input_iterator_tag +{ +}; + +struct output_iterator_tag +{ +}; + +struct forward_iterator_tag : input_iterator_tag +{ +}; + +struct bidirectional_iterator_tag : forward_iterator_tag +{ +}; + +struct random_access_iterator_tag : bidirectional_iterator_tag +{ +}; + +template +struct iterator_traits +{ + using difference_type = typename Iterator::difference_type; + using value_type = typename Iterator::value_type; + using pointer = typename Iterator::pointer; + using reference = typename Iterator::reference; + using iterator_category = typename Iterator::iterator_category; +}; + +template +struct iterator_traits +{ + using difference_type = ptrdiff_t; + using value_type = remove_cv_t; + using pointer = T*; + using reference = T&; + using iterator_category = random_access_iterator_tag; +}; + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ITERATOR_ITERATOR_TRAITS_HPP diff --git a/codegen/include/rocm-cxx/rocm/iterator/reverse_iterator.hpp b/codegen/include/rocm-cxx/rocm/iterator/reverse_iterator.hpp new file mode 100644 index 00000000000..782a1e284e5 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/iterator/reverse_iterator.hpp @@ -0,0 +1,167 @@ +#ifndef ROCM_GUARD_ITERATOR_REVERSE_ITERATOR_HPP +#define ROCM_GUARD_ITERATOR_REVERSE_ITERATOR_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +struct reverse_iterator +{ + using iterator_type = Iterator; + using difference_type = typename iterator_traits::difference_type; + using value_type = typename iterator_traits::value_type; + using pointer = typename iterator_traits::pointer; + using reference = typename iterator_traits::reference; + using iterator_category = typename iterator_traits::iterator_category; + + iterator_type current; + + constexpr reverse_iterator() : current() {} + + constexpr explicit reverse_iterator(iterator_type it) : current(it) {} + + template + constexpr reverse_iterator(const reverse_iterator& other) : current(other.base()) + { + } + + template + constexpr reverse_iterator& operator=(const reverse_iterator& other) + { + current = other.base(); + return *this; + } + + constexpr iterator_type base() const { return current; } + + constexpr reference operator*() const + { + iterator_type tmp = current; + --tmp; + return *tmp; + } + + constexpr pointer operator->() const + { + iterator_type tmp = current; + --tmp; + return tmp; + } + + constexpr reference operator[](difference_type n) const { return *(*this + n); } + + constexpr reverse_iterator& operator++() + { + --current; + return *this; + } + + constexpr reverse_iterator& operator--() + { + ++current; + return *this; + } + + constexpr reverse_iterator operator++(int) // NOLINT + { + reverse_iterator tmp = *this; + --current; + return tmp; + } + + constexpr reverse_iterator operator--(int) // NOLINT + { + reverse_iterator tmp = *this; + ++current; + return tmp; + } + + constexpr reverse_iterator& operator+=(difference_type n) + { + current -= n; + return *this; + } + + constexpr reverse_iterator& operator-=(difference_type n) + { + current += n; + return *this; + } + + friend constexpr reverse_iterator operator+(reverse_iterator it, difference_type n) + { + return it += n; + } + + friend constexpr reverse_iterator operator+(difference_type n, reverse_iterator it) + { + return it += n; + } + + friend constexpr reverse_iterator operator-(reverse_iterator it, difference_type n) + { + return it -= n; + } + + template + friend constexpr auto operator-(const reverse_iterator& lhs, + const reverse_iterator& rhs) + -> decltype(rhs.base() - lhs.base()) + { + return rhs.base() - lhs.base(); + } + + template + friend constexpr bool operator==(const reverse_iterator& lhs, + const reverse_iterator& rhs) + { + return lhs.base() == rhs.base(); + } + + template + friend constexpr bool operator!=(const reverse_iterator& lhs, + const reverse_iterator& rhs) + { + return lhs.base() != rhs.base(); + } + + template + friend constexpr bool operator<(const reverse_iterator& lhs, + const reverse_iterator& rhs) + { + return lhs.base() > rhs.base(); + } + + template + friend constexpr bool operator>(const reverse_iterator& lhs, + const reverse_iterator& rhs) + { + return lhs.base() < rhs.base(); + } + + template + friend constexpr bool operator<=(const reverse_iterator& lhs, + const reverse_iterator& rhs) + { + return lhs.base() >= rhs.base(); + } + + template + friend constexpr bool operator>=(const reverse_iterator& lhs, + const reverse_iterator& rhs) + { + return lhs.base() <= rhs.base(); + } +}; + +template +constexpr reverse_iterator make_reverse_iterator(Iterator it) +{ + return reverse_iterator(it); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ITERATOR_REVERSE_ITERATOR_HPP diff --git a/codegen/include/rocm-cxx/rocm/limits.hpp b/codegen/include/rocm-cxx/rocm/limits.hpp new file mode 100644 index 00000000000..b2076409aed --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/limits.hpp @@ -0,0 +1,307 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef ROCM_GUARD_ROCM_LIMITS_HPP +#define ROCM_GUARD_ROCM_LIMITS_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +enum float_round_style +{ + round_indeterminate = -1, + round_toward_zero = 0, + round_to_nearest = 1, + round_toward_infinity = 2, + round_toward_neg_infinity = 3 +}; + +namespace detail { + +constexpr unsigned long int_max(unsigned long n) +{ + // Note, left shift cannot be used to get the maximum value of int64_type or + // uint64_type because it is undefined behavior to left shift 64 bits for + // these types + if(n == 8) + return -1; + return (1ul << (n * 8)) - 1; +} + +template +struct numeric_limits_integer +{ + static constexpr const bool is_specialized = true; + + static constexpr const bool is_signed = T(-1) < T(0); + static constexpr const int digits = + static_cast(sizeof(T) * 8 - static_cast(is_signed)); + static constexpr const int digits10 = digits * 3 / 10; + static constexpr const int max_digits10 = 0; + static constexpr T min() noexcept + { + if constexpr(is_signed) + return -max() - 1; + return 0; + } + static constexpr T max() noexcept + { + if constexpr(is_signed) + return int_max(sizeof(T)) / 2; + return int_max(sizeof(T)); + } + static constexpr T lowest() noexcept { return min(); } + + static constexpr const bool is_integer = true; + static constexpr const bool is_exact = true; + static constexpr const int radix = 2; + static constexpr T epsilon() noexcept { return T(0); } + static constexpr T round_error() noexcept { return T(0); } + + static constexpr const int min_exponent = 0; + static constexpr const int min_exponent10 = 0; + static constexpr const int max_exponent = 0; + static constexpr const int max_exponent10 = 0; + + static constexpr const bool has_infinity = false; + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr const bool has_quiet_NaN = false; + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr const bool has_signaling_NaN = false; + static constexpr T infinity() noexcept { return T(0); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr T quiet_NaN() noexcept { return T(0); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr T signaling_NaN() noexcept { return T(0); } + static constexpr T denorm_min() noexcept { return T(0); } + + static constexpr const bool is_iec559 = false; + static constexpr const bool is_bounded = true; + static constexpr const bool is_modulo = not is_signed; + + static constexpr const bool traps = false; + static constexpr const bool tinyness_before = false; + static constexpr const float_round_style round_style = round_toward_zero; +}; + +template +struct numeric_limits_fp_mixin : Base +{ + static constexpr const bool is_specialized = true; + + static constexpr const bool is_signed = true; + static constexpr const int max_digits10 = 2 + (Base::digits * 30103L) / 100000L; + static constexpr const bool is_integer = false; + static constexpr const bool is_exact = false; + static constexpr const int radix = __FLT_RADIX__; + static constexpr typename Base::type round_error() noexcept { return 0.5F; } + static constexpr typename Base::type lowest() noexcept { return -Base::max(); } + static constexpr const bool has_infinity = true; + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr const bool has_quiet_NaN = true; + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr const bool has_signaling_NaN = true; + static constexpr const bool is_iec559 = true; + static constexpr const bool is_bounded = true; + static constexpr const bool is_modulo = false; + + static constexpr const bool traps = false; + static constexpr const bool tinyness_before = false; + static constexpr const float_round_style round_style = round_to_nearest; +}; + +struct numeric_limits_float +{ + using type = float; + static constexpr const int digits = __FLT_MANT_DIG__; + static constexpr const int digits10 = __FLT_DIG__; + static constexpr type min() noexcept { return __FLT_MIN__; } + static constexpr type max() noexcept { return __FLT_MAX__; } + + static constexpr type epsilon() noexcept { return __FLT_EPSILON__; } + + static constexpr const int min_exponent = __FLT_MIN_EXP__; + static constexpr const int min_exponent10 = __FLT_MIN_10_EXP__; + static constexpr const int max_exponent = __FLT_MAX_EXP__; + static constexpr const int max_exponent10 = __FLT_MAX_10_EXP__; + + static constexpr type infinity() noexcept { return __builtin_huge_valf(); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr type quiet_NaN() noexcept { return __builtin_nanf(""); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr type signaling_NaN() noexcept { return __builtin_nansf(""); } + static constexpr type denorm_min() noexcept { return __FLT_DENORM_MIN__; } +}; + +struct numeric_limits_double +{ + using type = double; + + static constexpr const int digits = __DBL_MANT_DIG__; + static constexpr const int digits10 = __DBL_DIG__; + static constexpr type min() noexcept { return __DBL_MIN__; } + static constexpr type max() noexcept { return __DBL_MAX__; } + + static constexpr const int radix = __FLT_RADIX__; + static constexpr type epsilon() noexcept { return __DBL_EPSILON__; } + + static constexpr const int min_exponent = __DBL_MIN_EXP__; + static constexpr const int min_exponent10 = __DBL_MIN_10_EXP__; + static constexpr const int max_exponent = __DBL_MAX_EXP__; + static constexpr const int max_exponent10 = __DBL_MAX_10_EXP__; + + static constexpr type infinity() noexcept { return __builtin_huge_val(); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr type quiet_NaN() noexcept { return __builtin_nan(""); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr type signaling_NaN() noexcept { return __builtin_nans(""); } + static constexpr type denorm_min() noexcept { return __DBL_DENORM_MIN__; } +}; + +#ifdef __FLT16_MAX__ +struct numeric_limits_fp16 +{ + using type = _Float16; + + static constexpr const int digits = __FLT16_MANT_DIG__; + static constexpr const int digits10 = __FLT16_DIG__; + static constexpr type min() noexcept { return __FLT16_MIN__; } + static constexpr type max() noexcept { return __FLT16_MAX__; } + + static constexpr const int radix = __FLT_RADIX__; + static constexpr type epsilon() noexcept { return __FLT16_EPSILON__; } + + static constexpr const int min_exponent = __FLT16_MIN_EXP__; + static constexpr const int min_exponent10 = __FLT16_MIN_10_EXP__; + static constexpr const int max_exponent = __FLT16_MAX_EXP__; + static constexpr const int max_exponent10 = __FLT16_MAX_10_EXP__; + + static constexpr type infinity() noexcept { return __builtin_huge_valf16(); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr type quiet_NaN() noexcept { return __builtin_nanf16(""); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr type signaling_NaN() noexcept { return __builtin_nansf16(""); } + static constexpr type denorm_min() noexcept { return __FLT16_DENORM_MIN__; } +}; +#endif + +} // namespace detail + +template +struct numeric_limits +{ + static constexpr const bool is_specialized = false; + static constexpr T min() noexcept { return T(); } + static constexpr T max() noexcept { return T(); } + static constexpr T lowest() noexcept { return T(); } + + static constexpr const int digits = 0; + static constexpr const int digits10 = 0; + static constexpr const int max_digits10 = 0; + static constexpr const bool is_signed = false; + static constexpr const bool is_integer = false; + static constexpr const bool is_exact = false; + static constexpr const int radix = 0; + static constexpr T epsilon() noexcept { return T(); } + static constexpr T round_error() noexcept { return T(); } + + static constexpr const int min_exponent = 0; + static constexpr const int min_exponent10 = 0; + static constexpr const int max_exponent = 0; + static constexpr const int max_exponent10 = 0; + + static constexpr const bool has_infinity = false; + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr const bool has_quiet_NaN = false; + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr const bool has_signaling_NaN = false; + static constexpr T infinity() noexcept { return T(); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr T quiet_NaN() noexcept { return T(); } + // NOLINTNEXTLINE(readability-identifier-naming) + static constexpr T signaling_NaN() noexcept { return T(); } + static constexpr T denorm_min() noexcept { return T(); } + + static constexpr const bool is_iec559 = false; + static constexpr const bool is_bounded = false; + static constexpr const bool is_modulo = false; + + static constexpr const bool traps = false; + static constexpr const bool tinyness_before = false; + static constexpr const float_round_style round_style = round_toward_zero; +}; + +template +struct numeric_limits : numeric_limits +{ +}; + +template +struct numeric_limits : numeric_limits +{ +}; + +template +struct numeric_limits : numeric_limits +{ +}; + +#define ROCM_DEFINE_NUMERIC_LIMITS_INT(T) \ + template <> \ + struct numeric_limits : detail::numeric_limits_integer \ + { \ + } + +ROCM_DEFINE_NUMERIC_LIMITS_INT(char); +ROCM_DEFINE_NUMERIC_LIMITS_INT(signed char); +ROCM_DEFINE_NUMERIC_LIMITS_INT(unsigned char); +ROCM_DEFINE_NUMERIC_LIMITS_INT(wchar_t); +ROCM_DEFINE_NUMERIC_LIMITS_INT(char16_t); +ROCM_DEFINE_NUMERIC_LIMITS_INT(char32_t); +ROCM_DEFINE_NUMERIC_LIMITS_INT(short); +ROCM_DEFINE_NUMERIC_LIMITS_INT(unsigned short); +ROCM_DEFINE_NUMERIC_LIMITS_INT(int); +ROCM_DEFINE_NUMERIC_LIMITS_INT(unsigned int); +ROCM_DEFINE_NUMERIC_LIMITS_INT(long); +ROCM_DEFINE_NUMERIC_LIMITS_INT(unsigned long); +ROCM_DEFINE_NUMERIC_LIMITS_INT(long long); +ROCM_DEFINE_NUMERIC_LIMITS_INT(unsigned long long); + +#define ROCM_DEFINE_NUMERIC_LIMITS_FLOAT(T, base) \ + template <> \ + struct numeric_limits : detail::numeric_limits_fp_mixin \ + { \ + } + +ROCM_DEFINE_NUMERIC_LIMITS_FLOAT(float, numeric_limits_float); +ROCM_DEFINE_NUMERIC_LIMITS_FLOAT(double, numeric_limits_double); +#ifdef __FLT16_MAX__ +ROCM_DEFINE_NUMERIC_LIMITS_FLOAT(_Float16, numeric_limits_fp16); +#endif + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_LIMITS_HPP diff --git a/codegen/include/rocm-cxx/rocm/stddef.hpp b/codegen/include/rocm-cxx/rocm/stddef.hpp new file mode 100644 index 00000000000..a62466a7136 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/stddef.hpp @@ -0,0 +1,16 @@ +#ifndef ROCM_GUARD_ROCM_STDDEF_HPP +#define ROCM_GUARD_ROCM_STDDEF_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +using nullptr_t = decltype(nullptr); +using size_t = uint64_t; +using ptrdiff_t = int64_t; + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_STDDEF_HPP diff --git a/codegen/include/rocm-cxx/rocm/stdint.hpp b/codegen/include/rocm-cxx/rocm/stdint.hpp new file mode 100644 index 00000000000..8cd3d075842 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/stdint.hpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef ROCM_GUARD_ROCM_STDINT_HPP +#define ROCM_GUARD_ROCM_STDINT_HPP + +#include + +#ifndef __HIPCC_RTC__ +#include +#endif + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +#ifdef __HIPCC_RTC__ +using int8_t = __hip_int8_t; +using uint8_t = __hip_uint8_t; +using int16_t = __hip_int16_t; +using uint16_t = __hip_uint16_t; +using int32_t = __hip_int32_t; +using uint32_t = __hip_uint32_t; +using int64_t = __hip_int64_t; +using uint64_t = __hip_uint64_t; +#else +using int8_t = std::int8_t; +using uint8_t = std::uint8_t; +using int16_t = std::int16_t; +using uint16_t = std::uint16_t; +using int32_t = std::int32_t; +using uint32_t = std::uint32_t; +using int64_t = std::int64_t; +using uint64_t = std::uint64_t; +#endif + +using intptr_t = int64_t; +using uintptr_t = uint64_t; +using intmax_t = int64_t; +using uintmax_t = uint64_t; +using intmax_t = int64_t; + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_STDINT_HPP diff --git a/codegen/include/rocm-cxx/rocm/type_traits.hpp b/codegen/include/rocm-cxx/rocm/type_traits.hpp new file mode 100644 index 00000000000..3313aa53bbe --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/type_traits.hpp @@ -0,0 +1,317 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef ROCM_GUARD_ROCM_TYPE_TRAITS_HPP +#define ROCM_GUARD_ROCM_TYPE_TRAITS_HPP + +#include +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +struct type_identity +{ + using type = T; +}; + +template +using type_identity_t = typename type_identity::type; + +template +struct enable_if +{ +}; + +template +struct enable_if +{ + using type = T; +}; + +template +using enable_if_t = typename enable_if::type; + +template +struct conditional +{ + using type = T; +}; + +template +struct conditional +{ + using type = F; +}; + +template +using conditional_t = typename conditional::type; + +template +struct add_cv +{ + using type = const volatile T; +}; +template +using add_cv_t = typename add_cv::type; + +template +struct add_const +{ + using type = const T; +}; +template +using add_const_t = typename add_const::type; + +template +struct add_volatile +{ + using type = volatile T; +}; +template +using add_volatile_t = typename add_volatile::type; + +template +struct remove_cv +{ + using type = T; +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +using remove_cv_t = typename remove_cv::type; + +template +struct remove_const +{ + typedef T type; +}; +template +struct remove_const +{ + typedef T type; +}; +template +using remove_const_t = typename remove_const::type; + +template +struct remove_reference +{ + using type = T; +}; +template +struct remove_reference +{ + using type = T; +}; +template +struct remove_reference +{ + using type = T; +}; + +template +using remove_reference_t = typename remove_reference::type; + +template +struct remove_cvref : remove_cv> +{ +}; +template +using remove_cvref_t = typename remove_cvref::type; + +template +struct add_pointer : type_identity*> +{ +}; + +template +using add_pointer_t = typename add_pointer::type; + +template +struct remove_pointer +{ + using type = T; +}; +template +struct remove_pointer +{ + using type = T; +}; +template +struct remove_pointer +{ + using type = T; +}; +template +struct remove_pointer +{ + using type = T; +}; +template +struct remove_pointer +{ + using type = T; +}; + +template +using remove_pointer_t = typename remove_pointer::type; + +template +struct common_type; + +template +struct common_type +{ + using type = typename common_type::type; +}; + +template +struct common_type +{ + using type = remove_cv_t() : declval())>>; +}; + +template +struct common_type +{ + using type = typename common_type::type, Us...>::type; +}; + +template +using common_type_t = typename common_type::type; + +template +using void_t = void; + +// NOLINTNEXTLINE +#define ROCM_BUILTIN_TYPE_TRAIT1(name) \ + template \ + struct name : bool_constant<__##name(T)> \ + { \ + }; \ + template \ + inline constexpr bool name##_v = __##name(T) + +// NOLINTNEXTLINE +#define ROCM_BUILTIN_TYPE_TRAIT2(name) \ + template \ + struct name : bool_constant<__##name(T, U)> \ + { \ + }; \ + template \ + inline constexpr bool name##_v = __##name(T, U) + +// NOLINTNEXTLINE +#define ROCM_BUILTIN_TYPE_TRAITN(name) \ + template \ + struct name : bool_constant<__##name(Ts...)> \ + { \ + }; \ + template \ + inline constexpr bool name##_v = __##name(Ts...) + +// ROCM_BUILTIN_TYPE_TRAIT1(is_destructible); +// ROCM_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible); +// ROCM_BUILTIN_TYPE_TRAIT1(is_scalar); +// ROCM_BUILTIN_TYPE_TRAIT1(is_signed); +// ROCM_BUILTIN_TYPE_TRAIT1(is_void); +ROCM_BUILTIN_TYPE_TRAIT1(is_abstract); +ROCM_BUILTIN_TYPE_TRAIT1(is_aggregate); +ROCM_BUILTIN_TYPE_TRAIT1(is_arithmetic); +ROCM_BUILTIN_TYPE_TRAIT1(is_array); +ROCM_BUILTIN_TYPE_TRAIT1(is_class); +ROCM_BUILTIN_TYPE_TRAIT1(is_compound); +ROCM_BUILTIN_TYPE_TRAIT1(is_const); +ROCM_BUILTIN_TYPE_TRAIT1(is_empty); +ROCM_BUILTIN_TYPE_TRAIT1(is_enum); +ROCM_BUILTIN_TYPE_TRAIT1(is_final); +ROCM_BUILTIN_TYPE_TRAIT1(is_floating_point); +ROCM_BUILTIN_TYPE_TRAIT1(is_function); +ROCM_BUILTIN_TYPE_TRAIT1(is_fundamental); +ROCM_BUILTIN_TYPE_TRAIT1(is_integral); +ROCM_BUILTIN_TYPE_TRAIT1(is_literal_type); +ROCM_BUILTIN_TYPE_TRAIT1(is_lvalue_reference); +ROCM_BUILTIN_TYPE_TRAIT1(is_member_function_pointer); +ROCM_BUILTIN_TYPE_TRAIT1(is_member_object_pointer); +ROCM_BUILTIN_TYPE_TRAIT1(is_member_pointer); +ROCM_BUILTIN_TYPE_TRAIT1(is_object); +ROCM_BUILTIN_TYPE_TRAIT1(is_pod); +ROCM_BUILTIN_TYPE_TRAIT1(is_pointer); +ROCM_BUILTIN_TYPE_TRAIT1(is_polymorphic); +ROCM_BUILTIN_TYPE_TRAIT1(is_reference); +ROCM_BUILTIN_TYPE_TRAIT1(is_rvalue_reference); +ROCM_BUILTIN_TYPE_TRAIT1(is_standard_layout); +ROCM_BUILTIN_TYPE_TRAIT1(is_trivial); +ROCM_BUILTIN_TYPE_TRAIT1(is_trivially_copyable); +ROCM_BUILTIN_TYPE_TRAIT1(is_trivially_destructible); +ROCM_BUILTIN_TYPE_TRAIT1(is_union); +ROCM_BUILTIN_TYPE_TRAIT1(is_unsigned); +ROCM_BUILTIN_TYPE_TRAIT1(is_volatile); +ROCM_BUILTIN_TYPE_TRAIT2(is_assignable); +ROCM_BUILTIN_TYPE_TRAIT2(is_base_of); +ROCM_BUILTIN_TYPE_TRAIT2(is_convertible); +ROCM_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable); +ROCM_BUILTIN_TYPE_TRAIT2(is_same); +ROCM_BUILTIN_TYPE_TRAIT2(is_trivially_assignable); +ROCM_BUILTIN_TYPE_TRAITN(is_constructible); +ROCM_BUILTIN_TYPE_TRAITN(is_nothrow_constructible); +ROCM_BUILTIN_TYPE_TRAITN(is_trivially_constructible); + +template +struct is_void : is_same> +{ +}; +template +inline constexpr bool is_void_v = is_void::value; + +template +struct is_null_pointer : is_same> +{ +}; +template +inline constexpr bool is_null_pointer_v = is_null_pointer::value; + +#define ROCM_REQUIRES(...) class = enable_if_t<__VA_ARGS__> + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_TYPE_TRAITS_HPP diff --git a/codegen/include/rocm-cxx/rocm/utility.hpp b/codegen/include/rocm-cxx/rocm/utility.hpp new file mode 100644 index 00000000000..4aeb02d0a45 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/utility.hpp @@ -0,0 +1,16 @@ +#ifndef ROCM_GUARD_ROCM_UTILITY_HPP +#define ROCM_GUARD_ROCM_UTILITY_HPP + +#include +#include +#include +#include +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_UTILITY_HPP diff --git a/codegen/include/rocm-cxx/rocm/utility/declval.hpp b/codegen/include/rocm-cxx/rocm/utility/declval.hpp new file mode 100644 index 00000000000..303da3defd3 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/utility/declval.hpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef ROCM_GUARD_ROCM_DECLVAL_HPP +#define ROCM_GUARD_ROCM_DECLVAL_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +U private_declval(int); + +template +T private_declval(long); + +template +auto declval() noexcept -> decltype(private_declval(0)); + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_ROCM_DECLVAL_HPP diff --git a/codegen/include/rocm-cxx/rocm/utility/forward.hpp b/codegen/include/rocm-cxx/rocm/utility/forward.hpp new file mode 100644 index 00000000000..b6c07494479 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/utility/forward.hpp @@ -0,0 +1,25 @@ +#ifndef ROCM_GUARD_UTILITY_FORWARD_HPP +#define ROCM_GUARD_UTILITY_FORWARD_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr T&& forward(remove_reference_t& x) noexcept +{ + return static_cast(x); +} + +template +constexpr T&& forward(remove_reference_t&& x) noexcept +{ + static_assert(not is_lvalue_reference{}, "can not forward an rvalue as an lvalue"); + return static_cast(x); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_UTILITY_FORWARD_HPP diff --git a/codegen/include/rocm-cxx/rocm/utility/integer_sequence.hpp b/codegen/include/rocm-cxx/rocm/utility/integer_sequence.hpp new file mode 100644 index 00000000000..1ff52d35a9b --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/utility/integer_sequence.hpp @@ -0,0 +1,32 @@ +#ifndef ROCM_GUARD_UTILITY_INTEGER_SEQUENCE_HPP +#define ROCM_GUARD_UTILITY_INTEGER_SEQUENCE_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +struct integer_sequence +{ + using value_type = T; + + static constexpr size_t size() noexcept { return sizeof...(Ints); } +}; + +template +using index_sequence = integer_sequence; + +template +using make_integer_sequence = __make_integer_seq; + +template +using make_index_sequence = make_integer_sequence; + +template +using index_sequence_for = make_index_sequence; + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_UTILITY_INTEGER_SEQUENCE_HPP diff --git a/codegen/include/rocm-cxx/rocm/utility/move.hpp b/codegen/include/rocm-cxx/rocm/utility/move.hpp new file mode 100644 index 00000000000..f99c6dacaae --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/utility/move.hpp @@ -0,0 +1,18 @@ +#ifndef ROCM_GUARD_UTILITY_MOVE_HPP +#define ROCM_GUARD_UTILITY_MOVE_HPP + +#include +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr remove_reference_t&& move(T&& x) noexcept +{ + return static_cast&&>(x); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_UTILITY_MOVE_HPP diff --git a/codegen/include/rocm-cxx/rocm/utility/swap.hpp b/codegen/include/rocm-cxx/rocm/utility/swap.hpp new file mode 100644 index 00000000000..afa6228e972 --- /dev/null +++ b/codegen/include/rocm-cxx/rocm/utility/swap.hpp @@ -0,0 +1,19 @@ +#ifndef ROCM_GUARD_UTILITY_SWAP_HPP +#define ROCM_GUARD_UTILITY_SWAP_HPP + +#include + +namespace rocm { +inline namespace ROCM_INLINE_NS { + +template +constexpr void swap(T& a, T& b) noexcept +{ + T tmp = static_cast(a); + a = static_cast(b); + b = static_cast(tmp); +} + +} // namespace ROCM_INLINE_NS +} // namespace rocm +#endif // ROCM_GUARD_UTILITY_SWAP_HPP diff --git a/codegen/src/device_fmha_fwd.cpp b/codegen/src/device_fmha_fwd.cpp new file mode 100644 index 00000000000..d0c07b63f3a --- /dev/null +++ b/codegen/src/device_fmha_fwd.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/device_fmha_fwd/operation.hpp" +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +// Based on factories defined in fmha_fwd.py +bool IsSupportedArch(const std::string& arch) +{ + // Match Python's get_factory() logic using prefix matching + if(arch.find("gfx950") == 0) + return true; + if(arch.find("gfx9") == 0) + return true; + if(arch.find("gfx12") == 0) + return true; + return false; +} + +std::string Problem::GetIncludeHeader() const { return "ck/host/device_fmha_fwd/wrapper.hpp"; } + +std::vector Problem::GetSolutions(const std::string& arch) const +{ + if(!IsSupportedArch(arch)) + return {}; + + auto ops = Operation::CreateOperations(*this, arch); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/src/device_fmha_fwd_operation.cpp b/codegen/src/device_fmha_fwd_operation.cpp new file mode 100644 index 00000000000..76d021354d5 --- /dev/null +++ b/codegen/src/device_fmha_fwd_operation.cpp @@ -0,0 +1,416 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/operation.hpp" +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/stringutils.hpp" +#include +#include +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +static const char* const FmhaFwdWrapperTemplate = + "ck_tile::FmhaFwdWrapper<${DataType}, " + "${BM0}, ${BN0}, ${BK0}, ${BN1}, ${BK1}, ${BK0Max}, " + "${RM0}, ${RN0}, ${RK0}, ${RM1}, ${RN1}, ${RK1}, " + "${WM0}, ${WN0}, ${WK0}, ${WM1}, ${WN1}, ${WK1}, " + "${IsCausal}, ${IsVRowMajor}, ${HasBias}, " + "${PadM}, ${PadN}, ${PadK}, ${PadO}, " + "ck_tile::FmhaPipelineTag::${PipelineTag}>"; + +static bool IsGfx950(const std::string& arch) { return arch.find("gfx950") == 0; } +static bool IsGfx12(const std::string& arch) { return arch.find("gfx12") == 0; } + +using TileMap = std::map, std::vector>; + +// gfx9 fp16/bf16 tile configurations +// +// Constraints that must be satisfied: +// - rn0 = rk0 = rn1 = rk1 = 1 (only M-dimension warp distribution supported) +// - rm0 == rm1 (BlockGemm requires identical thread buffer sizes between GEMM0/GEMM1) +// - bk0max >= 2 * bk0 (k0_loops >= 2 required for correct pipelining) +// - bk0 >= wk0 (block K must be at least warp K; for fp16 min wk0 is 16) +// - bn1 = hdim_v (output head dimension processed per block) +// - bk1 = 32 (fixed for softmax/attention score reduction pipelining) +// - (wm0, wn0, wk0) and (wm1, wn1, wk1) must be valid MFMA sizes for the dtype +// - rm0=8 not supported when bn1 is not power-of-2 (V tensor distribution alignment) +// +// Valid fp16 MFMA sizes: (32,32,16), (16,16,16), (16,16,32), (4,64,16), (64,4,16) +// However, not all are usable in this kernel: +// - (64,4,16), (4,64,16): warp_gemm_dispatcher has no template specialization +// - (32,32,8): produces invalid results (likely internal kernel issue) +// - (16,16,32): requires bk0 >= 2*wk0 (bk0 >= 64), only usable when bk0max >= 128 +// +// clang-format off +static const TileMap gfx9_fp16_tiles = { + // bm0, bn0, bk0, bn1, bk1,bk0max,rm0,rn0,rk0,rm1,rn1,rk1, wm0,wn0,wk0, wm1,wn1,wk1 + {{32, 32}, {{128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 16, 32, 16, 32, 32, 32, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 64, 16, 32, 32, 32, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}}}, + // + {{64, 64}, {{128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 16, 64, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}}}, + // + {{80, 96}, {{128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 16, 96, 32, 80, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 16, 96, 32, 80, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 16, 96, 32, 80, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + // + {{96, 128}, {{128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 128, 32, 96, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 128, 32, 96, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 96, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 96, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}}}, + // + {{128, 128}, {{ 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + {128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + // MFMA 16x16x16 variants + { 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + // MFMA 32x32x16 variants + { 64, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + // MFMA 16x16x32 for GEMM0, 16x16x16 for GEMM1 (wk1=32 produces invalid results) + { 32, 128, 64, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + { 64, 128, 64, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 64, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + // + {{192, 128}, {{128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + // MFMA 16x16x16 variants (rm0=1,2,4,8) + { 16, 128, 32, 128, 32, 192, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + // MFMA 32x32x16 variants (rm0=2,8) + { 64, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + // MFMA 16x16x32 for GEMM0, 16x16x16 for GEMM1 (bk0=32, rm0=2,8) + { 32, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}, + // MFMA 16x16x32 for GEMM0, 16x16x16 for GEMM1 (bk0=64, rm0=2,4,8) + { 32, 128, 64, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + { 64, 128, 64, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 64, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + // + {{192, 192}, {// MFMA 32x32x16 (original + rm0=2,8) + {128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + // MFMA 16x16x16 (rm0=1,2,4,8) + { 16, 128, 32, 192, 32, 192, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + // MFMA 16x16x32 for GEMM0, 16x16x16 for GEMM1 (bk0=32, rm0=2,8) + { 32, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + // + {{256, 256}, {// MFMA 32x32x16 (original + rm0=2,8) + {128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + // MFMA 16x16x16 (rm0=1,2,4,8) + { 16, 128, 32, 256, 32, 256, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + // MFMA 16x16x32 for GEMM0, 16x16x16 for GEMM1 (bk0=32, rm0=2,8) + { 32, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, +}; + +// gfx12 fp16/bf16 tiles from KernelComponentFactoryGfx12::get_hdim_tile_size_dict +static const TileMap gfx12_fp16_tiles = { + // bm0, bn0, bk0, bn1, bk1,bk0max,rm0,rn0,rk0,rm1,rn1,rk1, wm0,wn0,wk0, wm1,wn1,wk1 + {{32, 32}, {{ 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{64, 64}, {{ 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{128, 128}, {{ 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{192, 128}, {{ 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{256, 256}, {{ 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, +}; +// clang-format on + +HdimBucketResult +GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O) +{ + HdimBucketResult result; + + if(dtype != DataType::Half) + return result; + + const TileMap& tile_map = IsGfx12(arch) ? gfx12_fp16_tiles : gfx9_fp16_tiles; + + for(const auto& [key, tiles] : tile_map) + { + if(K <= key.first && O <= key.second) + { + result.bucket_hdim = key.first; + result.bucket_hdim_v = key.second; + result.tiles = tiles; + return result; + } + } + + return result; +} + +struct PipelineConfig +{ + std::string name; + bool pad_m; + bool pad_n; + bool pad_k; + bool pad_o; +}; + +static std::vector GetPipelinesGfx12() +{ + // QR pipeline is handled separately in CreateOperations with exact padding + return {}; +} + +static std::vector +GetPipelinesGfx9(std::size_t bucket_hdim, std::size_t bucket_hdim_v, bool has_bias) +{ + // QR pipeline is handled separately in CreateOperations with exact padding + std::vector configs; + + // QR_ASYNC pipeline requires pad_m=true, pad_k=true, pad_o=true (enforced by static_assert + // in BlockFmhaPipelineQRKSVSAsync). Only pad_n is variable, giving us two variants. + if(!has_bias) + { + configs.push_back({"qr_async", true, false, true, true}); // pad_n=false + configs.push_back({"qr_async", true, true, true, true}); // pad_n=true + } + + // Note: qr_async_trload requires gfx950+ (uses buffer_load_dwordx3/x4 instructions) + + return configs; +} + +static std::vector +GetPipelinesGfx950(std::size_t bucket_hdim, std::size_t bucket_hdim_v, bool has_bias) +{ + auto configs = GetPipelinesGfx9(bucket_hdim, bucket_hdim_v, has_bias); + + bool is_hdim_256 = (bucket_hdim == 256 && bucket_hdim_v == 256); + if(!is_hdim_256 && !has_bias) + { + configs.push_back({"qr_async_trload", false, false, false, false}); + configs.push_back({"qr_async_trload", false, false, true, true}); + } + + return configs; +} + +static std::vector GetPipelineConfigs(const std::string& arch, + std::size_t bucket_hdim, + std::size_t bucket_hdim_v, + bool has_bias) +{ + if(IsGfx12(arch)) + return GetPipelinesGfx12(); + if(IsGfx950(arch)) + return GetPipelinesGfx950(bucket_hdim, bucket_hdim_v, has_bias); + return GetPipelinesGfx9(bucket_hdim, bucket_hdim_v, has_bias); +} + +static bool IsPaddingCompatible(const PipelineConfig& config, + const Problem& prob, + const TileConfig& tile, + std::size_t bucket_hdim, + std::size_t bucket_hdim_v) +{ + bool needs_pad_m = (prob.M % tile.bm0 != 0); + bool needs_pad_n = (prob.N % tile.bn0 != 0); + bool needs_pad_k = (prob.K != bucket_hdim); + bool needs_pad_o = (prob.O != bucket_hdim_v); + + // +------------+----------+------------+ + // | config.pad | needs_pad| compatible | + // +------------+----------+------------+ + // | false | false | true | + // | false | true | false | + // | true | false | true | + // | true | true | true | + // +------------+----------+------------+ + // + return (config.pad_m || !needs_pad_m) && (config.pad_n || !needs_pad_n) && + (config.pad_k || !needs_pad_k) && (config.pad_o || !needs_pad_o); +} + +std::vector Operation::CreateOperations(const Problem& prob, const std::string& arch) +{ + std::vector result; + + auto bucket = GetTileConfigsForHdim(arch, prob.dtype, prob.K, prob.O); + auto pipelines = + GetPipelineConfigs(arch, bucket.bucket_hdim, bucket.bucket_hdim_v, prob.has_bias); + + for(const auto& tile : bucket.tiles) + { + // Compute exact padding needs for this tile + bool needs_pad_m = (prob.M % tile.bm0 != 0); + bool needs_pad_n = (prob.N % tile.bn0 != 0); + bool needs_pad_k = (prob.K != bucket.bucket_hdim); + bool needs_pad_o = (prob.O != bucket.bucket_hdim_v); + + // QR pipeline: create one operation with exact padding + { + Operation op; + op.tile = tile; + op.pipeline = "qr"; + op.is_causal = prob.is_causal; + op.is_v_rowmajor = prob.is_v_rowmajor; + op.has_bias = prob.has_bias; + op.dtype = prob.dtype; + op.pad_m = needs_pad_m; + op.pad_n = needs_pad_n; + op.pad_k = needs_pad_k; + op.pad_o = needs_pad_o; + result.push_back(op); + } + + // Async pipelines: use predefined configs with filters + for(const auto& pipeline : pipelines) + { + if(prob.dtype == DataType::Half && (prob.K % 8 != 0 || prob.O % 8 != 0)) + continue; + // Single-warp configs (rm0=1) produce incorrect results with async pipelines + if(tile.rm0 == 1) + continue; + // (96, 128) bucket: rm0 >= 4 with pad_n=false produces incorrect results + if(bucket.bucket_hdim == 96 && bucket.bucket_hdim_v == 128) + { + if(!pipeline.pad_n && tile.rm0 >= 4) + continue; + } + // (128, 128) bucket filters for async pipelines: + // - bn0=64, bk1=16 config produces invalid results + // - bk0=64 configs (MFMA 16x16x32) produce invalid results + if(bucket.bucket_hdim == 128 && bucket.bucket_hdim_v == 128) + { + if(tile.bn0 == 64 && tile.bk1 == 16) + continue; + if(tile.bk0 == 64) + continue; + } + // (192, 128) bucket filters for async pipelines + if(bucket.bucket_hdim == 192 && bucket.bucket_hdim_v == 128) + { + // bk0=64 configs produce invalid results + if(tile.bk0 == 64) + continue; + // pad_n=false fails for wm0=32 (MFMA 32x32x16) or rm0>=4 + if(!pipeline.pad_n && (tile.wm0 == 32 || tile.rm0 >= 4)) + continue; + } + // (192, 192) bucket filters for async pipelines + if(bucket.bucket_hdim == 192 && bucket.bucket_hdim_v == 192) + { + // rm0=8 with wm0=32 always fails (even with pad_n=true) + if(tile.rm0 == 8 && tile.wm0 == 32) + continue; + // pad_n=false fails except for (rm0=2, wm0=32) and (rm0=8, wk0=16) + if(!pipeline.pad_n) + { + bool is_valid = (tile.rm0 == 2 && tile.wm0 == 32) || + (tile.rm0 == 8 && tile.wk0 == 16); + if(!is_valid) + continue; + } + } + + if(!IsPaddingCompatible(pipeline, prob, tile, bucket.bucket_hdim, bucket.bucket_hdim_v)) + continue; + + Operation op; + op.tile = tile; + op.pipeline = pipeline.name; + op.is_causal = prob.is_causal; + op.is_v_rowmajor = prob.is_v_rowmajor; + op.has_bias = prob.has_bias; + op.dtype = prob.dtype; + op.pad_m = pipeline.pad_m; + op.pad_n = pipeline.pad_n; + op.pad_k = pipeline.pad_k; + op.pad_o = pipeline.pad_o; + result.push_back(op); + } + } + + return result; +} + +static std::string ToDataTypeString(DataType dtype) +{ + switch(dtype) + { + case DataType::Half: return "ck_tile::fp16_t"; + case DataType::Float: return "float"; + default: return "ck_tile::fp16_t"; + } +} + +Solution Operation::ToSolution() const +{ + std::unordered_map values = { + {"DataType", ToDataTypeString(dtype)}, + + {"BM0", std::to_string(tile.bm0)}, + {"BN0", std::to_string(tile.bn0)}, + {"BK0", std::to_string(tile.bk0)}, + {"BN1", std::to_string(tile.bn1)}, + {"BK1", std::to_string(tile.bk1)}, + {"BK0Max", std::to_string(tile.bk0max)}, + + {"RM0", std::to_string(tile.rm0)}, + {"RN0", std::to_string(tile.rn0)}, + {"RK0", std::to_string(tile.rk0)}, + + {"RM1", std::to_string(tile.rm1)}, + {"RN1", std::to_string(tile.rn1)}, + {"RK1", std::to_string(tile.rk1)}, + + {"WM0", std::to_string(tile.wm0)}, + {"WN0", std::to_string(tile.wn0)}, + {"WK0", std::to_string(tile.wk0)}, + + {"WM1", std::to_string(tile.wm1)}, + {"WN1", std::to_string(tile.wn1)}, + {"WK1", std::to_string(tile.wk1)}, + + {"IsCausal", is_causal ? "true" : "false"}, + {"IsVRowMajor", is_v_rowmajor ? "true" : "false"}, + {"HasBias", has_bias ? "true" : "false"}, + + {"PadM", pad_m ? "true" : "false"}, + {"PadN", pad_n ? "true" : "false"}, + {"PadK", pad_k ? "true" : "false"}, + {"PadO", pad_o ? "true" : "false"}, + + {"PipelineTag", + pipeline == "qr_async_trload" ? "QR_ASYNC_TRLOAD" + : (pipeline == "qr_async" ? "QR_ASYNC" : "QR")}, + }; + + return Solution{InterpolateString(FmhaFwdWrapperTemplate, values), std::move(values)}; +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 6c79bd9fd81..a7acf46384c 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -3,6 +3,8 @@ #include "ck/host/headers.hpp" #include "ck_headers.hpp" +#include "ck_tile_headers.hpp" +#include "ck_codegen_headers.hpp" namespace ck { namespace host { @@ -19,5 +21,13 @@ std::unordered_map GetHeaders() return headers; } +std::unordered_map GetTileHeaders() +{ + auto headers = ck_tile_headers(); + auto codegen_hdrs = ck_codegen_headers(); + headers.insert(codegen_hdrs.begin(), codegen_hdrs.end()); + return headers; +} + } // namespace host } // namespace ck diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index ad9743ff83d..6d7256d47c2 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,6 +1,13 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) + +find_package(hip REQUIRED) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + find_package(hiprtc REQUIRED) +endif() + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) diff --git a/codegen/test/fmha_fwd.cpp b/codegen/test/fmha_fwd.cpp new file mode 100644 index 00000000000..deda5b62ac9 --- /dev/null +++ b/codegen/test/fmha_fwd.cpp @@ -0,0 +1,1220 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/device_fmha_fwd/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/host/headers.hpp" +#include "common.hpp" +#include "fmha_fwd_ref.hpp" +#include +#include +#include +#include +#include +#include +#include + +using ck::host::Solution; +using ck::host::device_fmha_fwd::cpu_attention_ref; +using ck::host::device_fmha_fwd::FmhaFwdRefParams; +using ck::host::device_fmha_fwd::Problem; + +using half = _Float16; + +const std::string kernel_template = R"__ck__( +#include <${include}> + +using KernelType = ${template}; + +extern "C" __launch_bounds__(KernelType::Kernel::kBlockSize, KernelType::Kernel::kBlockPerCu) +__global__ void f(const ${dtype}* q, const ${dtype}* k, const ${dtype}* v, const ${dtype}* bias, ${dtype}* o) { + + constexpr float scale_s = ${scale_s}; + + using Kernel = KernelType; + + constexpr auto desc = Kernel::make_descriptor( + // Q + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${k}), + ck_tile::make_tuple(${q_stride_batch}, ${q_stride_nhead}, ${q_stride_m}), + // K + ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${k}), + ck_tile::make_tuple(${k_stride_batch}, ${k_stride_nhead}, ${k_stride_n}), + // V + ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${o}), + ck_tile::make_tuple(${v_stride_batch}, ${v_stride_nhead}, ${v_stride_n}), + // O + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${o}), + ck_tile::make_tuple(${o_stride_batch}, ${o_stride_nhead}, ${o_stride_m}), + // Bias + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${n}), + ck_tile::make_tuple(${bias_stride_batch}, ${bias_stride_nhead}, ${bias_stride_m})); + + static_assert(desc.IsValid(), "Invalid FMHA kernel configuration"); + + Kernel::Run(desc, scale_s, q, k, v, bias, o); +} +)__ck__"; + +std::string make_kernel_source(const Problem& prob, + const Solution& solution, + const FmhaFwdRefParams& ref_params) +{ + auto template_string = solution.ToTemplateString(); + //std::cout << template_string << std::endl; + return ck::host::InterpolateString( + kernel_template, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"dtype", "ck_tile::fp16_t"}, + {"batch", std::to_string(ref_params.batch)}, + {"nhead", std::to_string(ref_params.nhead)}, + {"m", std::to_string(ref_params.M)}, + {"n", std::to_string(ref_params.N)}, + {"k", std::to_string(ref_params.K)}, + {"o", std::to_string(ref_params.O)}, + {"q_stride_batch", std::to_string(ref_params.q_stride_batch)}, + {"q_stride_nhead", std::to_string(ref_params.q_stride_nhead)}, + {"q_stride_m", std::to_string(ref_params.q_stride_m)}, + {"k_stride_batch", std::to_string(ref_params.k_stride_batch)}, + {"k_stride_nhead", std::to_string(ref_params.k_stride_nhead)}, + {"k_stride_n", std::to_string(ref_params.k_stride_n)}, + {"v_stride_batch", std::to_string(ref_params.v_stride_batch)}, + {"v_stride_nhead", std::to_string(ref_params.v_stride_nhead)}, + {"v_stride_n", std::to_string(ref_params.v_stride_n)}, + {"o_stride_batch", std::to_string(ref_params.o_stride_batch)}, + {"o_stride_nhead", std::to_string(ref_params.o_stride_nhead)}, + {"o_stride_m", std::to_string(ref_params.o_stride_m)}, + {"bias_stride_batch", std::to_string(ref_params.bias_stride_batch)}, + {"bias_stride_nhead", std::to_string(ref_params.bias_stride_nhead)}, + {"bias_stride_m", std::to_string(ref_params.bias_stride_m)}, + {"scale_s", std::to_string(ref_params.scale_s) + "f"}}); +} + +FmhaFwdRefParams make_ref_params(const Problem& prob, float scale_s) +{ + FmhaFwdRefParams p; + p.batch = prob.batch; + p.nhead = prob.nhead; + p.M = prob.M; + p.N = prob.N; + p.K = prob.K; + p.O = prob.O; + p.scale_s = scale_s; + + // Q - [batch, nhead, M, K] + p.q_stride_m = prob.K; + p.q_stride_nhead = prob.M * prob.K; + p.q_stride_batch = prob.nhead * prob.M * prob.K; + + // K - [batch, nhead, N, K] + p.k_stride_n = prob.K; + p.k_stride_nhead = prob.N * prob.K; + p.k_stride_batch = prob.nhead * prob.N * prob.K; + + // V - [batch, nhead, N, O] + p.v_stride_n = prob.O; + p.v_stride_nhead = prob.N * prob.O; + p.v_stride_batch = prob.nhead * prob.N * prob.O; + + // O - [batch, nhead, M, O] contiguous + p.o_stride_m = prob.O; + p.o_stride_nhead = prob.M * prob.O; + p.o_stride_batch = prob.nhead * prob.M * prob.O; + + return p; +} + +std::pair get_launch_dims(const Solution& solution, const Problem& prob) +{ + // Block tile sizes (from TileFmhaShape BlockTile sequence) + auto bm0 = solution.GetTemplateParameter("BM0"); + auto bn1 = solution.GetTemplateParameter("BN1"); + + // Block warps for Gemm0 - sequence + auto rm0 = solution.GetTemplateParameter("RM0"); + auto rn0 = solution.GetTemplateParameter("RN0"); + auto rk0 = solution.GetTemplateParameter("RK0"); + + // Block warps for Gemm1 - sequence + auto rm1 = solution.GetTemplateParameter("RM1"); + auto rn1 = solution.GetTemplateParameter("RN1"); + auto rk1 = solution.GetTemplateParameter("RK1"); + + const std::size_t warp_size = 64; // gfx9 + const std::size_t num_warps = std::max(rm0 * rn0 * rk0, rm1 * rn1 * rk1); + const std::size_t block_size = num_warps * warp_size; + + // Grid dimensions: (nhead, num_m_tiles * num_o_tiles, batch) + const auto grid_m = ck::host::integer_divide_ceil(prob.M, bm0); + const auto grid_o = ck::host::integer_divide_ceil(prob.O, bn1); + + dim3 grid(prob.nhead, grid_m * grid_o, prob.batch); + dim3 block(block_size, 1, 1); + + return {grid, block}; +} + +TEST_CASE(test_fmha_fwd_simple_validation) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 24; // seqlen_q + prob.N = 32; // seqlen_k + prob.K = 8; // hdim_q (must be multiple of 8) + prob.O = 16; // hdim_v + prob.batch = 2; + prob.nhead = 1; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f; + + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Number of solutions: " << solutions.size() << std::endl; + + EXPECT(!solutions.empty()); + + const std::vector q_data = { + -0.125460f, 0.450714f, 0.231994f, 0.098658f, -0.343981f, -0.344005f, -0.441916f, + 0.366176f, 0.101115f, 0.208073f, -0.479416f, 0.469910f, 0.332443f, -0.287661f, + -0.318175f, -0.316595f, -0.195758f, 0.024756f, -0.068055f, -0.208771f, 0.111853f, + -0.360506f, -0.207855f, -0.133638f, -0.043930f, 0.285176f, -0.300326f, 0.014234f, + 0.092415f, -0.453550f, 0.107545f, -0.329476f, -0.434948f, 0.448886f, 0.465632f, + 0.308397f, -0.195386f, -0.402328f, 0.184233f, -0.059848f, -0.377962f, -0.004823f, + -0.465611f, 0.409320f, -0.241220f, 0.162522f, -0.188289f, 0.020068f, 0.046710f, + -0.315146f, 0.469585f, 0.275133f, 0.439499f, 0.394827f, 0.097900f, 0.421874f, + -0.411507f, -0.304017f, -0.454773f, -0.174670f, -0.111323f, -0.228651f, 0.328737f, + -0.143247f, -0.219065f, 0.042696f, -0.359076f, 0.302197f, -0.425449f, 0.486887f, + 0.272245f, -0.301284f, -0.494478f, 0.315461f, 0.206857f, 0.229007f, 0.271270f, + -0.425955f, -0.141534f, -0.384131f, 0.363103f, 0.123298f, -0.169102f, -0.436442f, + -0.189018f, -0.174817f, 0.229606f, 0.137557f, 0.387213f, -0.027785f, -0.380406f, + 0.213245f, 0.260785f, 0.061277f, 0.270967f, -0.006204f, 0.022733f, -0.072459f, + -0.474581f, -0.392109f, -0.468571f, 0.136410f, -0.185644f, 0.008571f, 0.407566f, + -0.250708f, -0.089617f, 0.255551f, -0.271202f, -0.423020f, -0.210249f, -0.338779f, + 0.429698f, 0.308120f, 0.133404f, 0.371461f, 0.303672f, -0.313430f, 0.392559f, + 0.039342f, 0.307440f, 0.396091f, -0.181997f, -0.389948f, -0.272065f, -0.072892f, + 0.318015f, 0.360731f, -0.493048f, 0.010747f, -0.082589f, -0.277892f, -0.380135f, + -0.162385f, 0.442910f, -0.176797f, 0.018791f, 0.203019f, -0.136370f, 0.471782f, + 0.462447f, -0.248218f, -0.002751f, -0.199122f, -0.215160f, -0.463113f, 0.109564f, + 0.002679f, -0.448521f, -0.221354f, 0.408266f, -0.260438f, -0.355105f, -0.010547f, + 0.485650f, -0.257945f, 0.172136f, 0.261620f, -0.262362f, 0.228216f, -0.132217f, + 0.132306f, 0.133530f, 0.035775f, -0.409710f, 0.335303f, -0.179220f, -0.313481f, + -0.459225f, 0.090893f, 0.177564f, -0.483412f, 0.012093f, -0.273504f, 0.145173f, + -0.325634f, 0.190938f, -0.113265f, 0.436730f, -0.362479f, -0.158934f, -0.386526f, + 0.424694f, 0.377339f, -0.242058f, 0.159984f, 0.317222f, 0.055201f, 0.029651f, + -0.258148f, -0.406897f, 0.397216f, 0.400418f, 0.133101f, -0.160970f, -0.150790f, + 0.225956f, 0.397110f, 0.387086f, 0.279876f, 0.142032f, -0.415860f, -0.338371f, + 0.398554f, 0.106429f, -0.490803f, -0.398528f, 0.163502f, -0.494938f, -0.339192f, + 0.048734f, 0.191895f, 0.151961f, -0.275731f, 0.212179f, -0.262751f, -0.174600f, + 0.246491f, 0.149633f, 0.349223f, 0.157613f, 0.068309f, -0.406325f, -0.132284f, + -0.234798f, -0.256010f, 0.473011f, -0.106902f, 0.392047f, 0.131139f, 0.294811f, + 0.002637f, 0.076904f, -0.007482f, -0.304757f, 0.222452f, -0.219228f, -0.475684f, + 0.145472f, -0.322889f, 0.440459f, 0.453929f, 0.414864f, -0.129841f, -0.484543f, + 0.428319f, -0.071816f, 0.466655f, 0.463620f, 0.353009f, -0.205551f, -0.114902f, + 0.351137f, -0.183078f, -0.330507f, 0.056801f, 0.436155f, 0.196030f, 0.070061f, + -0.402824f, 0.115007f, 0.490054f, -0.359916f, 0.018330f, 0.377373f, 0.240769f, + 0.197016f, 0.202484f, -0.140509f, -0.206408f, 0.309361f, 0.310113f, 0.367072f, + 0.413241f, 0.011342f, 0.001516f, 0.298295f, 0.149964f, 0.201967f, 0.295793f, + 0.390005f, -0.162005f, -0.124417f, -0.406018f, 0.078280f, -0.464058f, -0.034402f, + 0.042645f, -0.213459f, 0.090833f, -0.469500f, -0.462652f, 0.322601f, -0.139809f, + -0.372939f, 0.022243f, 0.269994f, -0.284179f, 0.122890f, -0.414653f, -0.448318f, + 0.031355f, 0.040635f, 0.137430f, 0.226091f, 0.475852f, 0.016300f, -0.177044f, + 0.295186f, -0.229168f, -0.061029f, -0.421544f, -0.474649f, 0.462648f, 0.335980f, + 0.195974f, -0.091047f, -0.326706f, -0.343563f, -0.249757f, 0.049227f, 0.214596f, + 0.160197f, -0.220066f, 0.454865f, 0.237897f, 0.054354f, 0.111721f, -0.080400f, + -0.252269f, -0.144027f, 0.257846f, -0.485607f, -0.383927f, -0.453997f, -0.459271f, + 0.355461f, 0.203658f, -0.025826f, -0.402166f, -0.008384f, -0.026528f, -0.326798f, + -0.066148f, -0.101495f, 0.115850f, 0.135094f, -0.454696f, -0.125387f, 0.125860f, + 0.003136f, 0.356490f, 0.158694f, -0.337066f, -0.429431f, 0.142419f, -0.473489f, + 0.085776f, 0.440230f, 0.075474f, -0.111830f, 0.143288f, -0.041747f, 0.045617f, + 0.441465f, -0.113897f, 0.461191f, 0.405351f, -0.304209f, -0.430639f, -0.399222f, + -0.481778f, -0.405557f, 0.183007f, -0.428811f, -0.181024f, 0.344875f, -0.476728f, + 0.314468f, -0.218145f, -0.381835f, 0.196737f, 0.128943f, 0.377472f, + }; + + const std::vector k_data = { + 0.235071f, 0.303481f, -0.217965f, -0.322560f, 0.250615f, 0.306835f, 0.490505f, + -0.087382f, -0.127982f, 0.276413f, -0.159196f, 0.430757f, 0.358413f, -0.071006f, + 0.250871f, 0.254543f, -0.396876f, 0.402553f, 0.005252f, 0.326457f, -0.179950f, + 0.395523f, -0.110798f, -0.489162f, 0.405382f, -0.408713f, -0.180686f, 0.450062f, + 0.450607f, 0.073438f, 0.131837f, -0.051554f, -0.206789f, -0.171335f, 0.172518f, + 0.252375f, 0.291579f, 0.289618f, -0.408794f, -0.005580f, -0.442441f, 0.049529f, + -0.058470f, 0.387704f, -0.149085f, -0.382933f, -0.357008f, 0.261511f, 0.118218f, + -0.398877f, -0.415893f, 0.200969f, -0.427237f, 0.321860f, 0.206242f, -0.418651f, + -0.415162f, 0.486640f, -0.125729f, -0.129358f, 0.312800f, 0.447249f, 0.486001f, + 0.253378f, -0.123740f, -0.416499f, 0.277147f, 0.058404f, -0.075778f, 0.406354f, + -0.388803f, -0.007375f, -0.488646f, -0.031339f, -0.443697f, -0.381182f, -0.382474f, + 0.149210f, 0.246045f, 0.083369f, 0.462173f, -0.125129f, -0.214288f, 0.368599f, + -0.276404f, 0.463223f, -0.487846f, 0.469879f, -0.456840f, 0.391143f, 0.027701f, + 0.492965f, -0.426203f, 0.053854f, 0.469303f, 0.023098f, 0.129399f, 0.195749f, + -0.045459f, 0.127558f, 0.084314f, 0.401158f, -0.454554f, -0.219037f, 0.450411f, + 0.390264f, -0.044343f, 0.120133f, -0.222619f, -0.311879f, -0.036302f, -0.146648f, + 0.083656f, -0.422265f, 0.474395f, 0.486211f, 0.198162f, 0.036096f, -0.190472f, + 0.313795f, 0.184731f, -0.337383f, 0.410927f, 0.322537f, 0.449800f, 0.225720f, + 0.113415f, -0.081757f, 0.432728f, 0.366064f, -0.454781f, -0.473633f, -0.123537f, + 0.310553f, 0.487276f, -0.349583f, 0.094131f, -0.119109f, 0.469914f, 0.342119f, + 0.338329f, -0.031307f, -0.085180f, -0.226593f, -0.443624f, 0.364722f, 0.312901f, + 0.499718f, 0.496637f, 0.055432f, 0.268987f, 0.444766f, 0.349647f, -0.252652f, + -0.049456f, -0.370841f, 0.454051f, 0.106175f, -0.271357f, 0.171701f, 0.118128f, + -0.141837f, -0.386442f, 0.171573f, 0.020308f, 0.272318f, 0.020164f, 0.352181f, + 0.051907f, 0.060938f, 0.376654f, -0.096517f, -0.365985f, -0.471217f, 0.255137f, + 0.120310f, 0.204080f, -0.287036f, -0.363629f, -0.485455f, -0.149412f, 0.089918f, + -0.107756f, -0.062525f, 0.404159f, -0.151745f, 0.013989f, 0.283653f, -0.103457f, + 0.122087f, 0.362364f, 0.449521f, -0.352927f, 0.426588f, -0.007884f, -0.241756f, + -0.040864f, 0.480033f, -0.007382f, -0.171248f, 0.133401f, -0.259854f, -0.424137f, + -0.371120f, -0.371954f, -0.348097f, -0.361173f, 0.140875f, -0.318120f, -0.154333f, + 0.396788f, -0.026038f, 0.167558f, -0.327680f, -0.307711f, -0.459131f, -0.331065f, + -0.221410f, -0.322990f, -0.411297f, -0.379364f, -0.039221f, -0.293666f, -0.135730f, + 0.003417f, 0.190395f, -0.460688f, 0.299410f, 0.127900f, -0.418241f, 0.373579f, + 0.420872f, -0.438922f, -0.223122f, 0.306201f, 0.248260f, -0.315479f, -0.290651f, + -0.129528f, -0.015477f, 0.118255f, -0.131086f, -0.037465f, 0.247471f, -0.463317f, + -0.247563f, 0.213350f, 0.395207f, 0.011677f, 0.032113f, -0.392828f, -0.052588f, + 0.032617f, -0.257529f, -0.230757f, -0.122716f, -0.479929f, -0.177921f, -0.288552f, + -0.172503f, -0.380238f, 0.390527f, 0.093592f, 0.179102f, 0.289171f, -0.001558f, + -0.413080f, 0.037107f, 0.086841f, 0.245439f, -0.068340f, -0.372420f, -0.216224f, + -0.136918f, 0.145917f, 0.070778f, -0.143903f, 0.486515f, 0.105775f, -0.262773f, + -0.398218f, -0.347141f, -0.254042f, -0.339319f, -0.313433f, -0.214905f, -0.326626f, + 0.396765f, -0.419766f, 0.024511f, -0.089603f, 0.482379f, -0.387961f, -0.102144f, + 0.469470f, 0.365507f, 0.317072f, -0.242097f, -0.329112f, 0.168643f, 0.429376f, + 0.056763f, 0.071613f, -0.220021f, 0.269493f, -0.312956f, -0.176321f, -0.074564f, + 0.007610f, -0.257590f, -0.385163f, 0.110620f, -0.211369f, 0.081238f, -0.345637f, + -0.018860f, 0.032589f, -0.448176f, -0.163396f, -0.365585f, -0.436625f, 0.489960f, + -0.177646f, 0.309874f, -0.245359f, 0.181503f, 0.260228f, 0.095639f, -0.028424f, + -0.088159f, -0.151132f, 0.429529f, 0.330619f, 0.465027f, -0.375703f, 0.230867f, + 0.438340f, -0.318767f, -0.433504f, 0.241121f, 0.074473f, 0.341829f, -0.360228f, + 0.295267f, -0.298373f, -0.336344f, -0.335734f, 0.314575f, 0.165197f, 0.023065f, + -0.141170f, 0.377201f, -0.107555f, 0.316599f, -0.060865f, -0.123056f, -0.037320f, + -0.198622f, 0.247609f, 0.002720f, -0.267787f, 0.399575f, -0.116109f, 0.043553f, + 0.406472f, 0.124238f, -0.383102f, 0.439832f, 0.127708f, -0.165094f, -0.360728f, + 0.294025f, 0.120073f, 0.033461f, 0.393893f, 0.288597f, -0.348325f, -0.188278f, + -0.251511f, 0.243946f, -0.466468f, 0.069890f, 0.262459f, 0.376766f, -0.157918f, + 0.321257f, -0.389368f, 0.346452f, -0.372511f, -0.102713f, 0.297295f, -0.350083f, + -0.270749f, 0.222253f, 0.220037f, 0.141148f, 0.193948f, 0.042724f, -0.248201f, + -0.154304f, -0.318402f, 0.408451f, 0.083392f, -0.099149f, -0.037994f, 0.447283f, + -0.346649f, 0.086230f, 0.005889f, 0.111454f, -0.481890f, 0.372124f, 0.432118f, + 0.065133f, 0.196651f, 0.422499f, 0.207239f, -0.347461f, 0.076288f, 0.106715f, + -0.075869f, 0.236444f, 0.434367f, 0.425569f, -0.049161f, -0.386762f, 0.484841f, + 0.338898f, -0.375337f, 0.420842f, 0.369896f, 0.018838f, 0.091275f, -0.100997f, + -0.445238f, -0.164803f, 0.302853f, -0.495368f, -0.166501f, -0.101831f, 0.037396f, + 0.419856f, -0.153654f, -0.153047f, 0.237501f, -0.047782f, -0.275395f, -0.047560f, + -0.359143f, -0.323613f, -0.001632f, -0.081075f, 0.414846f, -0.137606f, 0.080588f, + 0.132264f, -0.486906f, 0.163537f, -0.321964f, 0.461070f, -0.351337f, -0.085376f, + -0.414650f, 0.496874f, 0.002195f, 0.095385f, -0.432924f, 0.249960f, -0.290094f, + 0.398054f, -0.294860f, -0.309312f, -0.463450f, -0.027933f, 0.064841f, -0.434291f, + 0.275528f, -0.046711f, 0.024390f, -0.059237f, -0.099237f, 0.059640f, -0.344760f, + -0.318072f, 0.361786f, 0.446115f, -0.126691f, -0.229255f, 0.144000f, -0.091266f, + -0.474614f, -0.343847f, 0.215972f, 0.158924f, -0.472904f, -0.278028f, -0.268925f, + 0.171893f, -0.480289f, -0.395891f, 0.299916f, -0.321455f, 0.152746f, -0.261817f, + -0.400559f, -0.256828f, 0.222267f, 0.355696f, 0.330220f, -0.102816f, 0.168085f, + -0.295016f, + }; + + const std::vector v_data = { + -0.206852f, 0.396336f, -0.486998f, -0.414491f, -0.292114f, -0.473468f, -0.318565f, + 0.083042f, -0.078575f, 0.392672f, 0.317444f, -0.158183f, -0.240577f, -0.120308f, + 0.090295f, -0.231936f, 0.124149f, -0.090588f, 0.052047f, -0.063873f, -0.205534f, + 0.448453f, 0.263606f, -0.359887f, 0.368468f, -0.012569f, 0.394552f, 0.299855f, + -0.074786f, -0.477531f, -0.231323f, 0.041634f, 0.133478f, -0.242112f, -0.360644f, + 0.334930f, 0.484402f, 0.025690f, -0.328321f, -0.227693f, -0.481609f, 0.414299f, + -0.382249f, 0.076516f, -0.225945f, 0.054178f, 0.151420f, 0.329742f, -0.293579f, + -0.489004f, -0.363114f, 0.400019f, 0.373890f, 0.097413f, 0.100517f, 0.165037f, + -0.324629f, 0.414412f, -0.081229f, -0.116861f, 0.018918f, -0.453034f, -0.333717f, + 0.238034f, -0.417201f, 0.103152f, -0.254651f, -0.110704f, -0.211306f, -0.144327f, + 0.219046f, -0.202878f, 0.066405f, -0.023950f, 0.163671f, 0.436830f, 0.232572f, + -0.285060f, -0.468817f, -0.237736f, 0.095078f, -0.448574f, -0.003634f, 0.096843f, + -0.165756f, 0.270912f, -0.393402f, -0.424862f, 0.228189f, -0.004509f, 0.188402f, + -0.065173f, -0.253598f, 0.319102f, 0.299416f, 0.194696f, -0.227855f, 0.090231f, + -0.139026f, -0.408418f, 0.417314f, -0.363181f, 0.450237f, -0.053994f, -0.314867f, + 0.041901f, 0.372946f, 0.232225f, 0.306561f, 0.158783f, 0.192277f, 0.349196f, + -0.250332f, -0.010575f, -0.278791f, 0.487668f, 0.444059f, -0.460573f, 0.205575f, + 0.425248f, -0.319425f, 0.067945f, 0.415488f, -0.466054f, 0.197420f, -0.202651f, + 0.424396f, 0.471058f, 0.444266f, -0.025786f, 0.362043f, 0.344549f, -0.180900f, + 0.328915f, -0.462992f, 0.096270f, -0.269991f, -0.379433f, -0.423047f, 0.196289f, + -0.160125f, 0.224767f, -0.434644f, -0.184710f, 0.039491f, 0.290723f, -0.181248f, + 0.125891f, 0.385978f, 0.115863f, -0.267041f, -0.475599f, 0.370099f, -0.478731f, + 0.374702f, 0.028937f, 0.439068f, 0.298783f, 0.497934f, -0.149288f, 0.267188f, + -0.098069f, -0.020124f, 0.127505f, 0.373677f, 0.484083f, 0.268273f, -0.082233f, + -0.078643f, 0.237582f, -0.261223f, -0.389526f, -0.145378f, -0.212761f, -0.203692f, + -0.266392f, -0.457907f, -0.482126f, 0.487722f, -0.072227f, -0.115673f, 0.179647f, + -0.281746f, 0.449961f, 0.286345f, -0.410589f, -0.082419f, 0.379118f, 0.444732f, + -0.032598f, 0.113411f, -0.332966f, 0.491169f, -0.268328f, 0.442732f, 0.149647f, + 0.107737f, 0.012689f, -0.269330f, -0.323472f, -0.279514f, -0.313562f, 0.279584f, + -0.149875f, -0.442157f, 0.469103f, 0.383786f, 0.427752f, 0.494908f, -0.326105f, + -0.103758f, 0.258238f, 0.196021f, -0.346104f, 0.315833f, -0.275559f, -0.276182f, + 0.036974f, 0.092940f, 0.080086f, -0.408513f, 0.377461f, -0.234400f, -0.370485f, + 0.388748f, 0.455651f, 0.362128f, 0.309516f, 0.155242f, 0.050857f, -0.413013f, + -0.091547f, -0.127311f, -0.240246f, 0.223420f, -0.004124f, -0.418954f, -0.279817f, + 0.183259f, -0.423869f, 0.351207f, -0.004853f, -0.019413f, 0.092408f, 0.324681f, + -0.152191f, 0.178016f, 0.065732f, -0.232972f, 0.378630f, 0.297426f, 0.158452f, + 0.350582f, 0.367294f, 0.208363f, 0.337013f, 0.197471f, 0.180141f, 0.118611f, + 0.252717f, -0.341395f, 0.380871f, 0.371844f, -0.470753f, 0.325817f, -0.371130f, + -0.164881f, 0.243508f, -0.339240f, 0.317967f, 0.332134f, 0.007468f, -0.493614f, + -0.212962f, 0.116927f, 0.481186f, 0.131814f, -0.240196f, 0.134006f, 0.039985f, + 0.279845f, -0.393019f, 0.261028f, 0.041267f, 0.462992f, -0.158128f, 0.132622f, + 0.432028f, -0.397490f, 0.437229f, 0.187886f, -0.432163f, -0.199036f, 0.208172f, + -0.432649f, 0.082170f, -0.154117f, 0.120916f, -0.454258f, 0.371537f, 0.473489f, + 0.468878f, 0.249652f, -0.369914f, 0.258263f, -0.475413f, -0.477876f, -0.176390f, + -0.011357f, 0.270407f, 0.183295f, -0.054097f, -0.226373f, 0.497124f, -0.073819f, + -0.048613f, -0.336376f, 0.294810f, 0.193682f, -0.279230f, -0.417619f, 0.180499f, + 0.154511f, -0.226740f, 0.450864f, -0.348942f, -0.067665f, 0.443616f, -0.080273f, + 0.138526f, -0.102406f, -0.225785f, 0.483978f, -0.090666f, 0.394099f, -0.270045f, + -0.286895f, -0.468866f, 0.151667f, -0.131474f, 0.364358f, -0.026790f, 0.468193f, + -0.314474f, 0.368623f, 0.276597f, 0.270922f, 0.344783f, 0.261024f, 0.126220f, + -0.368755f, -0.467474f, 0.420848f, 0.116650f, 0.296537f, -0.018478f, -0.382692f, + -0.374814f, 0.185565f, -0.069694f, -0.299475f, -0.008405f, -0.435791f, 0.081971f, + -0.231007f, 0.297559f, -0.189638f, -0.044780f, -0.488379f, -0.427553f, -0.107506f, + -0.020061f, 0.100021f, -0.208337f, 0.194982f, 0.360122f, 0.279851f, -0.460381f, + -0.019493f, -0.395070f, -0.257955f, 0.486663f, -0.357504f, -0.001112f, 0.118156f, + 0.202465f, 0.059649f, -0.490229f, -0.173539f, 0.017712f, -0.412134f, -0.149373f, + -0.466797f, -0.421421f, -0.103077f, -0.367284f, 0.067541f, 0.189465f, 0.300587f, + -0.299850f, -0.332517f, -0.395432f, 0.136430f, 0.206476f, -0.468414f, 0.436212f, + -0.448029f, 0.041296f, 0.209061f, 0.370969f, 0.214087f, 0.301728f, -0.160550f, + 0.314825f, -0.419885f, 0.394817f, 0.047592f, 0.317298f, -0.047682f, 0.143578f, + 0.026403f, 0.231590f, -0.418370f, -0.439648f, -0.252897f, -0.340455f, 0.371784f, + -0.280786f, 0.475865f, -0.163104f, -0.317882f, 0.289699f, 0.158708f, -0.001804f, + 0.055364f, 0.219202f, -0.271545f, 0.496334f, 0.474793f, 0.150326f, -0.300458f, + 0.180228f, -0.427802f, -0.469348f, -0.242317f, -0.037377f, 0.368273f, 0.227169f, + 0.242707f, -0.074507f, -0.154065f, -0.128961f, 0.487650f, -0.459891f, 0.367031f, + 0.078675f, -0.061385f, 0.225258f, -0.013331f, 0.373423f, 0.400702f, -0.078279f, + -0.223172f, 0.092350f, 0.412363f, -0.289338f, 0.122967f, 0.131560f, 0.233113f, + -0.368432f, 0.215825f, 0.409033f, -0.320317f, -0.262457f, 0.471395f, -0.319023f, + 0.354385f, -0.007722f, -0.252769f, 0.370750f, -0.054695f, 0.014817f, -0.140767f, + 0.092951f, -0.336476f, -0.108918f, 0.469412f, -0.241867f, 0.156737f, -0.174810f, + 0.273473f, -0.369126f, 0.469821f, -0.046210f, -0.263950f, -0.426503f, -0.330242f, + 0.019774f, -0.162997f, 0.328883f, -0.069112f, -0.251286f, 0.117145f, 0.206777f, + -0.332958f, -0.332381f, -0.463329f, 0.236402f, 0.163805f, -0.025369f, 0.344170f, + 0.305670f, 0.085354f, 0.368271f, -0.294159f, -0.388080f, -0.230250f, -0.442913f, + 0.031170f, 0.436606f, -0.460656f, -0.377890f, -0.047801f, 0.433875f, -0.183844f, + 0.007235f, -0.458427f, -0.351657f, 0.486630f, 0.465119f, -0.495060f, 0.451812f, + 0.139120f, 0.367918f, -0.045260f, 0.015596f, -0.011153f, 0.166864f, -0.360349f, + -0.470026f, -0.192070f, 0.204681f, -0.298147f, 0.173432f, 0.469912f, -0.406099f, + 0.172602f, -0.056250f, 0.368142f, -0.322850f, 0.192626f, 0.338115f, 0.444614f, + 0.183248f, -0.002825f, 0.117847f, 0.368905f, 0.070610f, -0.469613f, 0.430949f, + 0.189527f, 0.176513f, -0.284325f, 0.158885f, -0.106136f, 0.151233f, -0.393407f, + 0.157845f, 0.499414f, -0.451788f, 0.477174f, -0.093092f, 0.370753f, 0.282385f, + 0.067016f, 0.238449f, 0.378516f, -0.095860f, -0.172967f, 0.167593f, 0.307846f, + 0.262285f, 0.297814f, -0.064417f, 0.317834f, -0.379791f, 0.044489f, -0.494241f, + -0.175414f, -0.133538f, -0.103827f, 0.195467f, -0.111442f, -0.051306f, -0.262456f, + -0.126748f, -0.272730f, -0.426804f, 0.103449f, 0.168213f, 0.119490f, -0.036506f, + -0.120214f, 0.363334f, 0.019082f, -0.020818f, -0.474358f, -0.158752f, -0.119804f, + -0.101177f, 0.080172f, 0.033603f, 0.107905f, 0.264883f, 0.312986f, 0.218123f, + 0.455524f, -0.481767f, -0.304222f, -0.492437f, 0.147475f, 0.398031f, -0.256518f, + 0.427035f, -0.439733f, 0.434436f, -0.148377f, -0.398579f, -0.014128f, -0.243223f, + -0.215127f, -0.192710f, 0.303026f, 0.039161f, -0.188692f, 0.110334f, 0.216151f, + -0.227376f, -0.086451f, -0.378114f, -0.318851f, 0.181118f, -0.318562f, 0.025163f, + 0.209046f, -0.393123f, 0.067312f, -0.243437f, 0.462927f, -0.016454f, 0.305993f, + 0.050227f, -0.456587f, 0.133151f, 0.451403f, 0.101612f, 0.319189f, 0.384206f, + -0.271920f, -0.287955f, 0.110981f, -0.088972f, 0.339861f, 0.400023f, -0.146579f, + -0.263129f, 0.280526f, -0.225194f, 0.322614f, -0.076262f, 0.167550f, -0.404465f, + 0.123859f, -0.048232f, 0.086608f, -0.331986f, 0.236874f, 0.362797f, -0.283260f, + -0.404285f, -0.476361f, 0.141971f, 0.107094f, 0.046697f, -0.268053f, -0.109094f, + 0.094476f, -0.003233f, 0.487786f, -0.363560f, 0.195145f, -0.095681f, -0.071800f, + 0.217598f, 0.192436f, 0.491256f, -0.371606f, -0.395890f, 0.224339f, 0.078387f, + -0.225839f, -0.420581f, -0.414342f, 0.394191f, -0.308133f, -0.176628f, -0.273344f, + -0.145004f, -0.430576f, 0.019060f, -0.432387f, 0.300357f, -0.266288f, 0.040012f, + 0.380079f, 0.150877f, 0.032958f, -0.175666f, -0.166998f, 0.169487f, 0.494139f, + 0.161839f, 0.057783f, 0.230651f, -0.034794f, -0.439858f, 0.062297f, 0.457625f, + -0.324697f, 0.190005f, -0.299066f, 0.035828f, -0.403324f, -0.049629f, 0.256163f, + -0.152428f, 0.164912f, 0.295450f, 0.427178f, -0.265358f, -0.100684f, -0.347584f, + 0.492483f, 0.427001f, 0.039957f, 0.342033f, 0.020958f, 0.123586f, -0.410876f, + 0.255270f, -0.372287f, 0.326068f, 0.282028f, 0.208745f, -0.463840f, -0.196872f, + -0.236887f, -0.139864f, -0.412357f, 0.436958f, 0.053802f, -0.194476f, -0.103018f, + -0.052797f, 0.100594f, 0.015679f, 0.419392f, -0.003037f, 0.492158f, 0.351425f, + -0.291489f, 0.430595f, -0.383634f, 0.317450f, -0.119377f, 0.377974f, 0.368057f, + 0.305925f, 0.290030f, -0.195321f, -0.419081f, -0.097020f, -0.326475f, 0.194951f, + -0.153900f, 0.475610f, 0.140972f, 0.322481f, -0.367475f, 0.362014f, 0.422757f, + -0.012938f, 0.106253f, 0.264810f, -0.325161f, 0.002566f, -0.101337f, -0.353626f, + -0.132466f, -0.431828f, -0.474188f, -0.364834f, 0.463115f, 0.049530f, 0.465822f, + -0.067502f, -0.188184f, 0.006142f, -0.060488f, -0.394335f, 0.140826f, -0.283962f, + 0.119588f, 0.150201f, -0.347975f, -0.438650f, 0.280762f, -0.040200f, -0.441836f, + 0.494866f, -0.442219f, 0.195035f, 0.483679f, -0.260820f, -0.357751f, -0.378615f, + -0.196725f, -0.398954f, 0.192161f, -0.437708f, 0.009422f, 0.496697f, 0.313970f, + 0.115219f, -0.193746f, 0.123896f, 0.027041f, -0.073917f, -0.369290f, 0.386604f, + -0.050215f, -0.305377f, -0.132241f, -0.085870f, 0.327538f, 0.233614f, 0.269305f, + -0.488969f, -0.083846f, -0.018656f, -0.480808f, -0.240187f, 0.260290f, -0.362890f, + 0.035310f, -0.284798f, -0.487879f, -0.258799f, 0.475874f, 0.301537f, 0.459577f, + -0.012146f, -0.390264f, 0.047959f, -0.045623f, 0.344357f, -0.401917f, -0.011759f, + -0.349951f, -0.175324f, 0.237357f, -0.023982f, -0.124112f, -0.105524f, -0.040553f, + 0.285017f, 0.392085f, 0.455335f, 0.286903f, -0.184593f, 0.188135f, -0.062397f, + -0.245329f, 0.340872f, -0.461574f, 0.401762f, -0.038523f, 0.137201f, 0.159354f, + 0.395118f, 0.136670f, 0.113934f, -0.433348f, 0.018408f, -0.349831f, 0.237434f, + 0.012222f, 0.180228f, -0.458327f, -0.415208f, 0.216323f, -0.427916f, -0.428743f, + -0.487892f, 0.456501f, 0.237508f, -0.146749f, -0.203464f, -0.150297f, 0.274654f, + 0.161371f, -0.314804f, -0.325891f, -0.401604f, 0.160303f, 0.264373f, -0.234954f, + -0.479055f, -0.417828f, 0.467860f, -0.204555f, 0.269223f, 0.124664f, -0.118060f, + -0.294313f, -0.378614f, 0.115013f, 0.274634f, 0.143904f, 0.030302f, -0.458049f, + 0.468489f, 0.298714f, -0.207178f, 0.479970f, 0.101882f, 0.082423f, 0.248073f, + 0.311770f, 0.156479f, -0.371904f, -0.161732f, 0.428084f, -0.275384f, -0.127833f, + -0.067923f, -0.060595f, 0.112940f, 0.443076f, -0.259307f, -0.378499f, -0.302530f, + 0.386925f, 0.145811f, -0.214093f, 0.315947f, 0.361370f, 0.346514f, 0.418927f, + -0.247759f, 0.255042f, -0.039461f, 0.341999f, 0.228491f, 0.276447f, 0.156162f, + -0.322571f, 0.045027f, 0.484670f, 0.437388f, -0.456826f, -0.335185f, -0.368271f, + 0.225980f, 0.317785f, -0.286489f, 0.005853f, 0.340703f, 0.232802f, 0.042237f, + 0.090348f, 0.008361f, -0.202452f, 0.065022f, 0.188885f, 0.373323f, 0.136291f, + 0.261122f, -0.339928f, -0.038443f, -0.490668f, -0.253321f, 0.226462f, 0.491810f, + -0.400822f, -0.098506f, 0.300071f, -0.295964f, 0.055085f, 0.233071f, 0.115985f, + -0.311975f, -0.144615f, 0.283792f, 0.054227f, -0.494770f, 0.260991f, -0.464689f, + 0.245734f, -0.297519f, 0.458073f, -0.132059f, -0.173068f, -0.351112f, -0.194396f, + 0.376651f, 0.496334f, -0.131690f, -0.051389f, 0.222071f, 0.386196f, 0.093044f, + -0.108474f, -0.087378f, + }; + + const std::vector numpy_expected = { + 0.007383f, -0.085425f, 0.011838f, 0.062971f, 0.043929f, 0.007666f, 0.008439f, + -0.046630f, -0.058420f, -0.034030f, 0.050607f, 0.002766f, 0.056086f, 0.071142f, + 0.003148f, -0.008505f, 0.002715f, -0.076216f, -0.014847f, 0.068649f, 0.058922f, + -0.008740f, 0.021790f, -0.043732f, -0.082332f, -0.014314f, 0.041560f, 0.015328f, + 0.045330f, 0.052070f, 0.014844f, 0.026025f, 0.007508f, -0.065677f, -0.006289f, + 0.065917f, 0.036876f, 0.000431f, 0.013452f, -0.047478f, -0.076925f, -0.027326f, + 0.047549f, 0.003660f, 0.052550f, 0.068205f, 0.015890f, 0.019385f, -0.002520f, + -0.068157f, -0.014357f, 0.059441f, 0.046273f, -0.015606f, 0.029188f, -0.047057f, + -0.067481f, -0.025480f, 0.048960f, 0.016361f, 0.055688f, 0.066174f, 0.022904f, + 0.016228f, -0.017850f, -0.077436f, 0.015345f, 0.052739f, 0.056457f, -0.008167f, + -0.002618f, -0.035080f, -0.054646f, -0.047784f, 0.064118f, 0.021038f, 0.098352f, + 0.061559f, 0.014207f, -0.006122f, -0.002099f, -0.067341f, -0.000756f, 0.057148f, + 0.059963f, 0.001503f, 0.010144f, -0.032881f, -0.075191f, -0.032237f, 0.037420f, + 0.001029f, 0.060923f, 0.060398f, 0.030673f, 0.012808f, -0.006748f, -0.047749f, + 0.000415f, 0.060475f, 0.069737f, -0.008651f, 0.004705f, -0.012828f, -0.077261f, + -0.017083f, 0.051994f, 0.003326f, 0.062779f, 0.048019f, 0.008298f, -0.012594f, + -0.007749f, -0.055491f, -0.012014f, 0.053954f, 0.045582f, -0.010534f, 0.030729f, + -0.036889f, -0.063309f, -0.032229f, 0.049988f, 0.004904f, 0.070313f, 0.069882f, + 0.033285f, 0.018283f, -0.009560f, -0.056328f, -0.007101f, 0.047559f, 0.067232f, + -0.013676f, 0.019708f, -0.032811f, -0.078113f, -0.040424f, 0.039800f, 0.003230f, + 0.060881f, 0.069153f, 0.049097f, 0.012857f, -0.003914f, -0.063199f, 0.001035f, + 0.065549f, 0.052037f, -0.002653f, -0.013828f, -0.048785f, -0.080286f, -0.041294f, + 0.059457f, 0.014830f, 0.082938f, 0.054519f, 0.019383f, 0.025542f, 0.001185f, + -0.064716f, -0.015948f, 0.052071f, 0.032986f, -0.014907f, 0.051420f, -0.044499f, + -0.053381f, -0.017821f, 0.042237f, 0.002952f, 0.031287f, 0.084531f, 0.017001f, + -0.008584f, -0.010784f, -0.064312f, -0.024903f, 0.052547f, 0.063267f, -0.024236f, + 0.046386f, -0.025896f, -0.068553f, -0.006001f, 0.044032f, 0.006031f, 0.043641f, + 0.056054f, 0.016689f, 0.004116f, 0.014393f, -0.058293f, -0.004851f, 0.058634f, + 0.027928f, 0.008397f, 0.033760f, -0.046834f, -0.072747f, -0.025939f, 0.024793f, + -0.008613f, 0.026162f, 0.088906f, 0.032530f, 0.011598f, 0.010774f, -0.087746f, + -0.002402f, 0.076286f, 0.052772f, -0.007808f, 0.042321f, -0.044525f, -0.074307f, + -0.020356f, 0.050978f, 0.005467f, 0.041848f, 0.067021f, -0.013176f, 0.016990f, + -0.018131f, -0.073032f, -0.014444f, 0.052988f, 0.066205f, -0.028847f, 0.041022f, + -0.028227f, -0.053479f, -0.012696f, 0.059475f, 0.020471f, 0.064025f, 0.053843f, + 0.002226f, -0.009378f, -0.006675f, -0.061330f, -0.016546f, 0.045374f, 0.038021f, + -0.019298f, 0.049954f, -0.040340f, -0.044663f, -0.022905f, 0.044510f, 0.000977f, + 0.038488f, 0.082866f, 0.025464f, -0.019278f, -0.009946f, -0.056392f, -0.003774f, + 0.051014f, 0.046133f, -0.009736f, 0.021107f, -0.040785f, -0.057193f, -0.047951f, + 0.055886f, 0.003465f, 0.078724f, 0.075681f, 0.040318f, 0.006164f, -0.009899f, + -0.067255f, -0.012504f, 0.061307f, 0.063530f, -0.014937f, 0.016265f, -0.035016f, + -0.074253f, -0.016603f, 0.052519f, 0.019856f, 0.065436f, 0.046476f, 0.014571f, + 0.015569f, -0.005469f, -0.070110f, 0.003504f, 0.058781f, 0.054405f, -0.013541f, + 0.035046f, -0.035151f, -0.061428f, -0.041955f, 0.064034f, 0.004731f, 0.079533f, + 0.069533f, 0.006321f, 0.009739f, 0.009868f, -0.046759f, 0.003892f, 0.060610f, + 0.044778f, 0.004380f, -0.013117f, -0.035925f, -0.088403f, -0.036423f, 0.046171f, + -0.005440f, 0.057470f, 0.064779f, 0.022364f, 0.000553f, 0.014907f, -0.062145f, + 0.003694f, 0.063011f, 0.053007f, 0.000731f, 0.003884f, -0.046303f, -0.090317f, + -0.042867f, 0.037300f, -0.004294f, 0.044668f, 0.074411f, 0.030016f, 0.013970f, + 0.002469f, -0.050964f, -0.006501f, 0.059326f, 0.037477f, -0.004060f, 0.006490f, + -0.050532f, -0.076494f, -0.042087f, 0.054995f, -0.000966f, 0.067863f, 0.072168f, + 0.032234f, 0.017786f, -0.011112f, -0.075392f, -0.003143f, 0.052040f, 0.047606f, + -0.019149f, 0.046299f, -0.035092f, -0.041081f, -0.022936f, 0.065448f, 0.005120f, + 0.065054f, 0.074648f, -0.008866f, -0.023949f, 0.005304f, -0.069631f, 0.009495f, + 0.062978f, 0.044818f, 0.007730f, -0.001488f, -0.040640f, -0.066867f, -0.031884f, + 0.052568f, 0.003658f, 0.061925f, 0.062329f, 0.004855f, -0.003895f, 0.114486f, + 0.079661f, -0.115023f, 0.025315f, -0.000117f, -0.070439f, -0.009776f, 0.115430f, + 0.047095f, -0.020249f, 0.001512f, -0.006185f, -0.036645f, 0.003067f, -0.048612f, + -0.035854f, 0.100041f, 0.077085f, -0.109820f, 0.015464f, -0.021206f, -0.063925f, + -0.009368f, 0.121258f, 0.055209f, -0.034103f, 0.008018f, -0.008480f, -0.026955f, + -0.004989f, -0.046626f, -0.017247f, 0.099377f, 0.074604f, -0.113369f, 0.007508f, + -0.004265f, -0.095276f, -0.026419f, 0.115797f, 0.092064f, -0.031276f, 0.007216f, + 0.010462f, -0.008152f, -0.001692f, -0.045870f, -0.039668f, 0.090610f, 0.070008f, + -0.095893f, 0.036339f, -0.001674f, -0.076347f, -0.010227f, 0.120200f, 0.066155f, + -0.008440f, 0.010495f, -0.005206f, -0.039893f, -0.013893f, -0.045189f, -0.045747f, + 0.101430f, 0.064898f, -0.104166f, 0.004913f, 0.012145f, -0.097956f, -0.028537f, + 0.107966f, 0.079144f, -0.029408f, 0.001147f, 0.011118f, 0.002440f, 0.012128f, + -0.048582f, -0.051894f, 0.096293f, 0.093928f, -0.130731f, 0.027540f, -0.020008f, + -0.071251f, -0.015406f, 0.119832f, 0.084302f, -0.018117f, 0.013128f, 0.001765f, + -0.034212f, -0.008191f, -0.050701f, -0.026755f, 0.091574f, 0.058934f, -0.109406f, + 0.031684f, 0.013173f, -0.073829f, -0.022562f, 0.122142f, 0.038862f, -0.031264f, + 0.031566f, -0.011584f, -0.034398f, 0.001449f, -0.050027f, -0.034705f, 0.093171f, + 0.092271f, -0.107576f, 0.039219f, -0.015123f, -0.054276f, -0.009520f, 0.109212f, + 0.061468f, 0.000427f, -0.008970f, -0.002040f, -0.047295f, -0.001064f, -0.047281f, + -0.044790f, 0.096935f, 0.078937f, -0.092781f, 0.036182f, 0.018153f, -0.056738f, + -0.020583f, 0.109120f, 0.059436f, 0.001769f, -0.000911f, -0.003321f, -0.044719f, + 0.010452f, -0.055386f, -0.059634f, 0.102150f, 0.071935f, -0.123576f, 0.025914f, + -0.014051f, -0.072845f, -0.011868f, 0.121021f, 0.055033f, -0.033752f, 0.019387f, + -0.010922f, -0.028995f, -0.004246f, -0.047819f, -0.017015f, 0.105048f, 0.077451f, + -0.111607f, 0.034564f, -0.009339f, -0.068584f, -0.006664f, 0.115148f, 0.050124f, + -0.015425f, 0.001799f, -0.009177f, -0.041747f, -0.004707f, -0.044247f, -0.034380f, + 0.089478f, 0.095989f, -0.120383f, 0.017656f, -0.012592f, -0.064598f, -0.025977f, + 0.108387f, 0.079686f, -0.023188f, -0.005437f, 0.007509f, -0.017324f, 0.016442f, + -0.047355f, -0.041292f, 0.088404f, 0.096468f, -0.106369f, 0.030468f, -0.002639f, + -0.071193f, -0.031953f, 0.110465f, 0.079732f, -0.007768f, -0.008830f, 0.009252f, + -0.039151f, 0.001257f, -0.033481f, -0.059676f, 0.097944f, 0.077180f, -0.121401f, + 0.012640f, 0.007468f, -0.074098f, -0.035356f, 0.119044f, 0.065360f, -0.039596f, + 0.019601f, 0.003659f, -0.011478f, 0.017890f, -0.054258f, -0.035000f, 0.084231f, + 0.097995f, -0.112847f, 0.040351f, -0.010664f, -0.064514f, -0.018276f, 0.105636f, + 0.089613f, 0.009693f, -0.009866f, 0.010465f, -0.039897f, -0.002313f, -0.049904f, + -0.058403f, 0.070641f, 0.071798f, -0.100271f, 0.039201f, -0.008508f, -0.083737f, + -0.027235f, 0.118331f, 0.089070f, -0.008805f, 0.014908f, 0.001467f, -0.036324f, + -0.016604f, -0.039218f, -0.046911f, 0.098816f, 0.076214f, -0.100256f, 0.028171f, + 0.004775f, -0.077152f, -0.019596f, 0.110189f, 0.066893f, -0.010315f, -0.006468f, + 0.000782f, -0.030826f, 0.003657f, -0.043033f, -0.054865f, 0.086425f, 0.084144f, + -0.118041f, 0.027978f, -0.008248f, -0.070637f, -0.022207f, 0.122238f, 0.085020f, + -0.018110f, 0.021525f, 0.001915f, -0.029623f, -0.005897f, -0.053233f, -0.031485f, + 0.087198f, 0.088724f, -0.105858f, 0.035310f, 0.000956f, -0.058345f, -0.024886f, + 0.109864f, 0.074043f, -0.003251f, -0.000581f, 0.002690f, -0.039096f, 0.008599f, + -0.051477f, -0.051773f, 0.091308f, 0.072021f, -0.109828f, 0.022452f, 0.006161f, + -0.081970f, -0.037871f, 0.119763f, 0.065898f, -0.032224f, 0.013211f, 0.000856f, + -0.025350f, 0.007624f, -0.039163f, -0.044446f, 0.118941f, 0.081076f, -0.135352f, + 0.014605f, -0.005394f, -0.076476f, -0.015260f, 0.130009f, 0.054056f, -0.041352f, + 0.026264f, -0.004701f, -0.031809f, -0.001052f, -0.052770f, -0.014205f, 0.108131f, + 0.074984f, -0.117471f, 0.021073f, -0.014417f, -0.085287f, -0.012516f, 0.116266f, + 0.060459f, -0.033259f, 0.000532f, -0.005881f, -0.027825f, -0.007172f, -0.034112f, + -0.028664f, 0.093271f, 0.087344f, -0.111297f, 0.019828f, 0.012828f, -0.075017f, + -0.041729f, 0.122014f, 0.080034f, -0.025663f, 0.017053f, 0.010779f, -0.029446f, + 0.010609f, -0.048868f, -0.050144f, 0.108120f, 0.065707f, -0.121392f, 0.001176f, + 0.013212f, -0.079736f, -0.031379f, 0.120274f, 0.045055f, -0.054592f, 0.025725f, + 0.000130f, 0.001714f, 0.019266f, -0.056784f, -0.029767f, + }; + + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + std::vector o_ref(o_size); + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_data, k_data, v_data, o_ref, ref_params); + + EXPECT(allclose(o_ref, numpy_expected, 0.0001, 0.0001)); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + const auto make_device_buff = [&](const std::vector& data) { + rtc::buffer host(data.size()); + std::transform( + data.begin(), data.end(), host.begin(), [](float val) { return half(val); }); + return to_gpu(host); + }; + auto q_device = make_device_buff(q_data); + auto k_device = make_device_buff(k_data); + auto v_device = make_device_buff(v_data); + + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + CHECK(allclose(result, o_ref, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_large_dimensions) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 128; // seqlen_q + prob.N = 256; // seqlen_k + prob.K = 64; // hdim_q + prob.O = 64; // hdim_v + prob.batch = 4; + prob.nhead = 8; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Number of solutions: " << solutions.size() << std::endl; + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size); + + auto fill_buffers = [&](auto& host, auto& ref) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref); + fill_buffers(k_host, k_ref); + fill_buffers(v_host, v_ref); + + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(o_ref, result, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_512x512_hdim32) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 512; // seqlen_q + prob.N = 512; // seqlen_k + prob.K = 32; // hdim_q + prob.O = 32; // hdim_v + prob.batch = 2; + prob.nhead = 4; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Number of solutions: " << solutions.size() << std::endl; + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + + std::mt19937 rng(44); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size); + + auto fill_buffers = [&](auto& host, auto& ref) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref); + fill_buffers(k_host, k_ref); + fill_buffers(v_host, v_ref); + + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(o_ref, result, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_with_bias) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 64; // seqlen_q + prob.N = 128; // seqlen_k + prob.K = 32; // hdim_q + prob.O = 32; // hdim_v + prob.batch = 2; + prob.nhead = 4; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = true; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Number of solutions: " << solutions.size() << std::endl; + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + const std::size_t bias_size = prob.M * prob.N; // Only [M, N], broadcast across batch/nhead + + std::mt19937 rng(43); + std::uniform_real_distribution dist(-0.5f, 0.5f); + std::uniform_real_distribution bias_dist(-0.1f, 0.1f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size), bias_host(bias_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), bias_ref(bias_size), + o_ref(o_size); + auto fill_buffers = [&](auto& host, auto& ref, auto& dist) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref, dist); + fill_buffers(k_host, k_ref, dist); + fill_buffers(v_host, v_ref, dist); + fill_buffers(bias_host, bias_ref, bias_dist); + + auto ref_params = make_ref_params(prob, scale_s); + ref_params.bias_stride_m = prob.N; + ref_params.bias_stride_nhead = 0; + ref_params.bias_stride_batch = 0; + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, &bias_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + auto bias_device = to_gpu(bias_host); + kernel.launch(nullptr, grid, block)( + q_device.data(), k_device.data(), v_device.data(), bias_device.data(), o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(result, o_ref, 0.0001, 0.0001)); + } +} + +TEST_CASE(benchmark_fmha_fwd) +{ + // Benchmark configuration - matches common example settings + ck::host::device_fmha_fwd::Problem prob; + prob.M = 1024; // seqlen_q + prob.N = 512; // seqlen_k + prob.K = 128; // hdim_q + prob.O = 256; // hdim_v + prob.batch = 2; + prob.nhead = 4; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + constexpr int warmup_iters = 1; + constexpr int bench_iters = 1000; + + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Number of solutions: " << solutions.size() << std::endl; + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + + // Initialize with random data and create reference buffers + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size); + + auto fill_buffers = [&](auto& host, auto& ref) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref); + fill_buffers(k_host, k_ref); + fill_buffers(v_host, v_ref); + + auto ref_params = make_ref_params(prob, scale_s); + + // Compute reference output + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params); + + // Calculate FLOPs for FMHA: + // Gemm0: Q @ K^T = [batch, nhead, M, K] @ [batch, nhead, K, N] -> [batch, nhead, M, N] + // FLOPs = 2 * batch * nhead * M * N * K + // Gemm1: softmax(Gemm0) @ V = [batch, nhead, M, N] @ [batch, nhead, N, O] -> [batch, nhead, M, O] + // FLOPs = 2 * batch * nhead * M * N * O + const double flops = + 2.0 * prob.batch * prob.nhead * prob.M * prob.N * prob.K + + 2.0 * prob.batch * prob.nhead * prob.M * prob.N * prob.O; + + std::cout << "\n=== FMHA Forward Benchmark ===" << std::endl; + std::cout << "Problem: batch=" << prob.batch << ", nhead=" << prob.nhead + << ", M=" << prob.M << ", N=" << prob.N + << ", K=" << prob.K << ", O=" << prob.O << std::endl; + std::cout << "Warmup: " << warmup_iters << ", Iterations: " << bench_iters << std::endl; + std::cout << "FLOPs per forward: " << flops / 1e9 << " GFLOPs\n" << std::endl; + + // Create HIP events for timing + hipEvent_t start_evt, stop_evt; + (void)hipEventCreate(&start_evt); + (void)hipEventCreate(&stop_evt); + + std::vector timing_results(solutions.size(), std::numeric_limits::max()); + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Solution " << (sol_idx + 1) << "/" << solutions.size() << ": " + << solution.ToTemplateString() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + + try { + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + + // Warmup + for(int i = 0; i < warmup_iters; ++i) + { + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + } + (void)hipDeviceSynchronize(); + + // Validate result after warmup + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + bool valid = allclose(o_ref, result, 0.0001, 0.0001); + + // Benchmark + (void)hipEventRecord(start_evt, nullptr); + for(int i = 0; i < bench_iters; ++i) + { + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + } + (void)hipEventRecord(stop_evt, nullptr); + (void)hipEventSynchronize(stop_evt); + + float total_ms = 0.0f; + (void)hipEventElapsedTime(&total_ms, start_evt, stop_evt); + float avg_ms = total_ms / bench_iters; + if(valid) { + timing_results[sol_idx] = avg_ms; + } + double tflops = flops / (avg_ms * 1e-3) / 1e12; + + std::cout << " Time: " << avg_ms << " ms (avg over " << bench_iters << " iters)" + << std::endl; + // std::cout << " Throughput: " << tflops << " TFLOPs/s" << std::endl; + std::cout << " Valid: " << (valid ? "yes" : "NO") << std::endl; + std::cout << std::endl; + + CHECK(valid); + } catch (...) { + std::cout << "COMPILE ERROR" << std::endl; + CHECK(false); + continue; + } + } + auto it = std::min_element(timing_results.begin(), timing_results.end()); + std::size_t best_idx = std::distance(timing_results.begin(), it); + std::cout << "Best solution: " << *it << "ms, " << solutions[best_idx].ToTemplateString() << std::endl; + + (void)hipEventDestroy(start_evt); + (void)hipEventDestroy(stop_evt); +} + +TEST_CASE(sweep_fmha_fwd) +{ + std::vector seqlens_q{512, 1024, 2048, 4096}; + std::vector seqlens_k{512, 1024, 2048, 4096}; + std::vector hdims_q{32, 48, 64, 80, 96, 128, 192, 256}; + std::vector hdims_v{32, 48, 64, 80, 96, 128, 192, 256}; + + constexpr int batch_size = 2; + constexpr int num_heads = 4; + + int total_configs = 0; + int total_solutions = 0; + int total_passed = 0; + int total_failed = 0; + int total_compile_errors = 0; + int seed_counter = 0; + + struct FailureInfo + { + std::size_t M, N, K, O; + std::string solution; + std::string reason; + }; + std::vector failures; + + for(std::size_t M : seqlens_q) + { + for(std::size_t N : seqlens_k) + { + for(std::size_t K : hdims_q) + { + for(std::size_t O : hdims_v) + { + total_configs++; + + ck::host::device_fmha_fwd::Problem prob; + prob.M = M; + prob.N = N; + prob.K = K; + prob.O = O; + prob.batch = batch_size; + prob.nhead = num_heads; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + auto solutions = prob.GetSolutions("gfx90a"); + if(solutions.empty()) + { + std::cout << "Config M=" << M << ", N=" << N << ", K=" << K << ", O=" << O + << ": No solutions available" << std::endl; + continue; + } + + std::cout << "\n=== Config M=" << M << ", N=" << N << ", K=" << K << ", O=" << O + << " (" << solutions.size() << " solutions) ===" << std::endl; + + const float scale_s = 1.0f / std::sqrt(static_cast(K)); + + const std::size_t q_size = batch_size * num_heads * M * K; + const std::size_t k_size = batch_size * num_heads * N * K; + const std::size_t v_size = batch_size * num_heads * N * O; + const std::size_t o_size = batch_size * num_heads * M * O; + + std::mt19937 rng(42 + seed_counter++); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size); + + auto fill_buffers = [&](auto& host, auto& ref) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref); + fill_buffers(k_host, k_ref); + fill_buffers(v_host, v_ref); + + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params); + + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + total_solutions++; + auto&& solution = solutions[sol_idx]; + std::string sol_str = solution.ToTemplateString(); + + std::cout << " [" << (sol_idx + 1) << "/" << solutions.size() << "] "; + + try + { + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + (void)hipDeviceSynchronize(); + + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + bool valid = allclose(o_ref, result, 0.0001, 0.0001); + if(valid) + { + std::cout << "PASS" << std::endl; + total_passed++; + } + else + { + std::cout << "FAIL (incorrect result)" << std::endl; + total_failed++; + failures.push_back({M, N, K, O, sol_str, "incorrect result"}); + } + } + catch(const std::exception& e) + { + std::cout << "COMPILE ERROR: " << e.what() << std::endl; + total_compile_errors++; + failures.push_back({M, N, K, O, sol_str, std::string("compile error: ") + e.what()}); + } + catch(...) + { + std::cout << "COMPILE ERROR (unknown)" << std::endl; + total_compile_errors++; + failures.push_back({M, N, K, O, sol_str, "compile error: unknown"}); + } + } + } + } + } + } + + std::cout << "\n========================================" << std::endl; + std::cout << " SWEEP SUMMARY" << std::endl; + std::cout << "========================================" << std::endl; + std::cout << "Total configs tested: " << total_configs << std::endl; + std::cout << "Total solutions tested: " << total_solutions << std::endl; + std::cout << "Passed: " << total_passed << std::endl; + std::cout << "Failed (incorrect result): " << total_failed << std::endl; + std::cout << "Compile errors: " << total_compile_errors << std::endl; + + if(!failures.empty()) + { + std::cout << "\n========================================" << std::endl; + std::cout << " FAILURES" << std::endl; + std::cout << "========================================" << std::endl; + for(const auto& f : failures) + { + std::cout << "M=" << f.M << ", N=" << f.N << ", K=" << f.K << ", O=" << f.O << std::endl; + std::cout << " Solution: " << f.solution << std::endl; + std::cout << " Reason: " << f.reason << std::endl; + } + } + + EXPECT(total_failed == 0 && total_compile_errors == 0); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/fmha_fwd_ref.hpp b/codegen/test/fmha_fwd_ref.hpp new file mode 100644 index 00000000000..25644e5f5dd --- /dev/null +++ b/codegen/test/fmha_fwd_ref.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +struct FmhaFwdRefParams +{ + std::size_t batch; + std::size_t nhead; + std::size_t M; // seqlen_q + std::size_t N; // seqlen_k + std::size_t K; // hdim_q + std::size_t O; // hdim_v + + float scale_s; + + std::size_t q_stride_batch; + std::size_t q_stride_nhead; + std::size_t q_stride_m; + + std::size_t k_stride_batch; + std::size_t k_stride_nhead; + std::size_t k_stride_n; + + std::size_t v_stride_batch; + std::size_t v_stride_nhead; + std::size_t v_stride_n; + + std::size_t o_stride_batch; + std::size_t o_stride_nhead; + std::size_t o_stride_m; + + std::size_t bias_stride_batch = 0; + std::size_t bias_stride_nhead = 0; + std::size_t bias_stride_m = 0; +}; + +// O = softmax(Q @ K^T * scale_s + bias) @ V +// bias is optional (nullptr = no bias) +inline void cpu_attention_ref(const std::vector& q, + const std::vector& k, + const std::vector& v, + std::vector& o, + const std::vector* bias, + const FmhaFwdRefParams& p) +{ + for(std::size_t b = 0; b < p.batch; ++b) + { + for(std::size_t h = 0; h < p.nhead; ++h) + { + const float* q_ptr = q.data() + b * p.q_stride_batch + h * p.q_stride_nhead; + const float* k_ptr = k.data() + b * p.k_stride_batch + h * p.k_stride_nhead; + const float* v_ptr = v.data() + b * p.v_stride_batch + h * p.v_stride_nhead; + const float* bias_ptr = + bias ? (bias->data() + b * p.bias_stride_batch + h * p.bias_stride_nhead) : nullptr; + float* o_ptr = o.data() + b * p.o_stride_batch + h * p.o_stride_nhead; + + for(std::size_t m = 0; m < p.M; ++m) + { + // Q[m,:] @ K^T -> [N] + std::vector scores(p.N); + for(std::size_t n = 0; n < p.N; ++n) + { + float dot = 0.0f; + for(std::size_t kk = 0; kk < p.K; ++kk) + { + dot += q_ptr[m * p.q_stride_m + kk] * k_ptr[n * p.k_stride_n + kk]; + } + scores[n] = dot * p.scale_s; + + if(bias_ptr) + { + scores[n] += bias_ptr[m * p.bias_stride_m + n]; + } + } + + // Softmax + float max_score = *std::max_element(scores.begin(), scores.end()); + float sum_exp = 0.0f; + for(std::size_t n = 0; n < p.N; ++n) + { + scores[n] = std::exp(scores[n] - max_score); + sum_exp += scores[n]; + } + for(std::size_t n = 0; n < p.N; ++n) + { + scores[n] /= sum_exp; + } + + // Output: attn @ V -> [O] + for(std::size_t oo = 0; oo < p.O; ++oo) + { + float val = 0.0f; + for(std::size_t n = 0; n < p.N; ++n) + { + val += scores[n] * v_ptr[n * p.v_stride_n + oo]; + } + o_ptr[m * p.o_stride_m + oo] = val; + } + } + } + } +} + +inline void cpu_attention_ref(const std::vector& q, + const std::vector& k, + const std::vector& v, + std::vector& o, + const FmhaFwdRefParams& p) +{ + cpu_attention_ref(q, k, v, o, nullptr, p); +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/test/gen_fmha_ref.py b/codegen/test/gen_fmha_ref.py new file mode 100644 index 00000000000..be54b4765aa --- /dev/null +++ b/codegen/test/gen_fmha_ref.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +"""Generate reference values for FMHA forward test.""" + +import numpy as np +from scipy.special import softmax + +# Test parameters +BATCH = 2 # batch size +NHEAD = 1 # number of heads +M = 24 # seqlen_q +N = 32 # seqlen_k +K = 8 # hdim_q +O = 16 # hdim_v (different from K) +scale_s = 1.0 +SEED = 42 # Fixed seed for reproducibility + +# Generate random inputs with fixed seed +# Layout: [batch, nhead, seqlen, hdim] +np.random.seed(SEED) +q = np.random.uniform(-0.5, 0.5, (BATCH, NHEAD, M, K)).astype(np.float32) +k = np.random.uniform(-0.5, 0.5, (BATCH, NHEAD, N, K)).astype(np.float32) +v = np.random.uniform(-0.5, 0.5, (BATCH, NHEAD, N, O)).astype(np.float32) + +print("Q shape:", q.shape) +print("K shape:", k.shape) +print("V shape:", v.shape) + +# Compute attention per batch/head: O = softmax(Q @ K^T * scale_s) @ V +output = np.zeros((BATCH, NHEAD, M, O), dtype=np.float32) +for b in range(BATCH): + for h in range(NHEAD): + scores = np.matmul(q[b, h], k[b, h].T) * scale_s # [M, N] + attn = softmax(scores, axis=-1) # [M, N] + output[b, h] = np.matmul(attn, v[b, h]) # [M, O] + +def print_cpp_array(name, arr): + flat = arr.flatten() + print(f"\n// C++ array ({len(flat)} values):") + print(f"const float {name}[] = {{") + for i in range(0, len(flat), 8): + line = ", ".join(f"{flat[i+j]:.6f}f" for j in range(min(8, len(flat) - i))) + print(f" {line},") + print("};") + +# Print input arrays for C++ +print_cpp_array("q_data", q) +print_cpp_array("k_data", k) +print_cpp_array("v_data", v) +print_cpp_array("numpy_expected", output) + +# Print first/last for quick comparison +print("\nFirst 5 outputs:", output.flatten()[:5]) +print("Last 5 outputs:", output.flatten()[-5:]) + +# Print test parameters for reference +print(f"\n// Problem dimensions:") +print(f"// BATCH={BATCH}, NHEAD={NHEAD}, M={M}, N={N}, K={K}, O={O}") diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 2cf8bec430b..1e802e879e1 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -34,6 +34,26 @@ inline const std::vector& get_headers_for_test() return headers; } +inline std::vector create_tile_headers_for_test() +{ + auto headers = ck::host::GetTileHeaders(); + std::vector result; + std::transform(headers.begin(), headers.end(), std::back_inserter(result), [](auto& p) { + std::string content; + content.reserve(p.second.size() + 1); + content.push_back(' '); // We need a whitespace before the content for hipRTC to work + content.append(p.second.data(), p.second.size()); + return rtc::src_file{p.first, std::move(content)}; + }); + return result; +} + +inline const std::vector& get_tile_headers_for_test() +{ + static const std::vector headers = create_tile_headers_for_test(); + return headers; +} + template std::size_t GetSize(V mLens, V mStrides) { diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 68b43d0dd93..c8c40ee7b46 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,15 +1,22 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -find_package(hip) +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) + +find_package(hip REQUIRED) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + find_package(hiprtc REQUIRED) +endif() + file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) -target_link_libraries(ck_rtc PUBLIC hip::host) -target_link_libraries(ck_rtc PUBLIC -lstdc++fs) +target_link_libraries(ck_rtc PUBLIC hip::host -lstdc++fs) -option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_link_libraries(ck_rtc PUBLIC hiprtc::hiprtc) target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) - message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") -endif() + message(STATUS "CK codegen tests: hipRTC enabled") +else() + message(STATUS "CK codegen tests: hipRTC disabled") +endif() \ No newline at end of file diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9fcb0501091..f980092e3de 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -49,6 +49,11 @@ struct kernel std::size_t local, std::vector args) const; + void launch(hipStream_t stream, + dim3 grid, + dim3 block, + const std::vector& args) const; + template auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const { @@ -57,6 +62,14 @@ struct kernel }; } + template + auto launch(hipStream_t stream, dim3 grid, dim3 block, Ts... zs) const + { + return [=, this](auto&&... xs) { + launch(stream, grid, block, std::vector{xs...}, zs...); + }; + } + private: std::shared_ptr impl; }; diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index a779409278f..5389b5685a0 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -95,6 +95,8 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o tmp_dir td{"compile"}; options.flags += " -I. -O3"; options.flags += " -std=c++20"; + // options.flags += " -DCK_TILE_FMHA_FWD_FAST_EXP2=1"; + // options.flags += " -fgpu-flush-denormals-to-zero"; options.flags += " --offload-arch=" + get_device_name(); std::string out; @@ -280,6 +282,8 @@ static kernel hiprtc_compile_kernel(const std::vector& srcs, compile_o options.flags += " -I. -O3"; options.flags += " -std=c++20"; options.flags += " -DCK_CODE_GEN_RTC"; + // options.flags += " -DCK_TILE_FMHA_FWD_FAST_EXP2=1"; + // options.flags += " -fgpu-flush-denormals-to-zero"; options.flags += " --offload-arch=" + get_device_name(); auto cos = compile_hip_src_with_hiprtc(srcs, options); if(cos.size() != 1) diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 1dbd677a86d..a92839ebb3f 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -122,4 +122,44 @@ void kernel::launch(hipStream_t stream, launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); } +static void launch_kernel_3d( + hipFunction_t fun, hipStream_t stream, dim3 grid, dim3 block, void* kernargs, std::size_t size) +{ + assert(grid.x > 0 && grid.y > 0 && grid.z > 0); + assert(block.x > 0 && block.y > 0 && block.z > 0); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END}; + + auto status = hipExtModuleLaunchKernel(fun, + grid.x * block.x, + grid.y * block.y, + grid.z * block.z, + block.x, + block.y, + block.z, + 0, + stream, + nullptr, + reinterpret_cast(&config), + nullptr, + nullptr); + if(status != hipSuccess) + throw std::runtime_error("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + dim3 grid, + dim3 block, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel_3d(impl->fun, stream, grid, block, kernargs.data(), size); +} + } // namespace rtc