diff --git a/cuda_bindings/cuda/bindings/driver.pxd.in b/cuda_bindings/cuda/bindings/driver.pxd.in
index ed992b8bd0..ab152a0b6d 100644
--- a/cuda_bindings/cuda/bindings/driver.pxd.in
+++ b/cuda_bindings/cuda/bindings/driver.pxd.in
@@ -3211,8 +3211,9 @@ cdef class CUtensorMap_st:
     getPtr()
         Get memory address of class instance
     """
-    cdef cydriver.CUtensorMap_st _pvt_val
+    cdef void* _pvt_buf
     cdef cydriver.CUtensorMap_st* _pvt_ptr
+    cdef bint _owns_buf
 {{endif}}
 
 {{if 'CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_st' in found_struct}}
diff --git a/cuda_bindings/cuda/bindings/driver.pyx.in b/cuda_bindings/cuda/bindings/driver.pyx.in
index 60f510dde2..795d4be0c0 100644
--- a/cuda_bindings/cuda/bindings/driver.pyx.in
+++ b/cuda_bindings/cuda/bindings/driver.pyx.in
@@ -6,6 +6,7 @@ from typing import Any, Optional
 import cython
 import ctypes
 from libc.stdlib cimport calloc, malloc, free
+from libc.string cimport memset
 from libc cimport string
 from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t
 from libc.stddef cimport wchar_t
@@ -16,6 +17,8 @@ from cpython.bytes cimport PyBytes_FromStringAndSize
 from ._internal._fast_enum import FastEnum as _FastEnum
 import cuda.bindings.driver
 from libcpp.map cimport map
+cdef extern from "<stdlib.h>" nogil:
+    void *aligned_alloc(size_t alignment, size_t size)
 
 _driver = globals()
 include "_lib/utils.pxi"
@@ -18133,13 +18136,23 @@ cdef class CUtensorMap_st:
     """
     def __cinit__(self, void_ptr _ptr = 0):
         if _ptr == 0:
-            self._pvt_ptr = &self._pvt_val
-        else:
+            self._pvt_buf = aligned_alloc(64, sizeof(cydriver.CUtensorMap_st))
+            if self._pvt_buf is NULL:
+                raise MemoryError("Failed to allocate 64-byte aligned CUtensorMap")
+            memset(self._pvt_buf, 0, sizeof(cydriver.CUtensorMap_st))
+            self._pvt_ptr = <cydriver.CUtensorMap_st*>self._pvt_buf
+            self._owns_buf = True
+        else:
+            self._pvt_buf = NULL
             self._pvt_ptr = <cydriver.CUtensorMap_st *>_ptr
+            self._owns_buf = False
     def __init__(self, void_ptr _ptr = 0):
         pass
     def __dealloc__(self):
-        pass
+        if self._owns_buf and self._pvt_buf is not NULL:
+            free(self._pvt_buf)
+            self._pvt_buf = NULL
+            self._pvt_ptr = NULL
     def getPtr(self):
         return <void_ptr>self._pvt_ptr
     def __repr__(self):