Skip to content

Commit 8cd06ac

Browse files
committed
update
1 parent 47f1386 commit 8cd06ac

File tree

1 file changed

+109
-9
lines changed

1 file changed

+109
-9
lines changed

workloads/gromacs/mpi_cxl_shim.c

Lines changed: 109 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,66 @@ static inline bool is_cxl_ptr(const void *ptr) {
585585
(uintptr_t)ptr < (uintptr_t)g_cxl.base + g_cxl.size);
586586
}
587587

588+
// ============================================================================
589+
// Global memset/memcpy/memmove interceptors
590+
// ============================================================================
591+
// Open MPI (and other libraries) may call libc memset/memcpy on CXL memory.
592+
// Glibc's AVX-512 optimized versions trigger SIGILL on CXL Type-3 device pages.
593+
// We intercept these globally and route CXL-targeted calls to safe (non-SIMD)
594+
// implementations. The overhead for non-CXL calls is a single pointer range check.
595+
596+
static typeof(memset) *__real_memset = NULL;
597+
static typeof(memcpy) *__real_memcpy = NULL;
598+
static typeof(memmove) *__real_memmove = NULL;
599+
static __thread int __in_dlsym = 0; // Guard against dlsym calling memset/memcpy
600+
601+
// Lazily resolve the real libc memset/memcpy/memmove via RTLD_NEXT so the
// interposers below can forward non-CXL calls to the optimized versions.
// __in_dlsym is raised around the lookups because dlsym itself may call
// memset/memcpy; the interposers short-circuit to the safe scalar path while
// it is set, preventing infinite recursion.
// NOTE(review): this lazy init is not thread-safe — two threads may race on
// the __real_* stores. Both would write identical values, but it is formally
// a data race; confirm first use happens before worker threads spawn, or
// make the pointers _Atomic.
static void __resolve_mem_originals(void) {
    if (__real_memset) return;  // already resolved
    __in_dlsym = 1;
    __real_memset = dlsym(RTLD_NEXT, "memset");
    __real_memcpy = dlsym(RTLD_NEXT, "memcpy");
    __real_memmove = dlsym(RTLD_NEXT, "memmove");
    __in_dlsym = 0;
}
609+
610+
void *memset(void *s, int c, size_t n) {
611+
if (__in_dlsym || is_cxl_ptr(s)) {
612+
cxl_safe_memset(s, c, n);
613+
return s;
614+
}
615+
if (__builtin_expect(!__real_memset, 0))
616+
__resolve_mem_originals();
617+
return __real_memset(s, c, n);
618+
}
619+
620+
void *memcpy(void *dst, const void *src, size_t n) {
621+
if (__in_dlsym || is_cxl_ptr(dst) || is_cxl_ptr(src)) {
622+
cxl_safe_memcpy(dst, src, n);
623+
return dst;
624+
}
625+
if (__builtin_expect(!__real_memcpy, 0))
626+
__resolve_mem_originals();
627+
return __real_memcpy(dst, src, n);
628+
}
629+
630+
void *memmove(void *dst, const void *src, size_t n) {
631+
if (__in_dlsym || is_cxl_ptr(dst) || is_cxl_ptr(src)) {
632+
// Safe byte-by-byte with overlap handling
633+
if (dst < src || (char *)dst >= (char *)src + n) {
634+
cxl_safe_memcpy(dst, src, n);
635+
} else {
636+
// Overlapping, copy backwards
637+
volatile unsigned char *d = (volatile unsigned char *)dst + n;
638+
const volatile unsigned char *s = (const volatile unsigned char *)src + n;
639+
while (n--) *--d = *--s;
640+
}
641+
return dst;
642+
}
643+
if (__builtin_expect(!__real_memmove, 0))
644+
__resolve_mem_originals();
645+
return __real_memmove(dst, src, n);
646+
}
647+
588648
// Signal handler for debugging
589649
static void signal_handler(int sig) {
590650
void *array[20];
@@ -1204,18 +1264,29 @@ static int cxl_send(const void *buf, size_t data_size, int dest, int tag, int so
12041264

12051265
cxl_rank_mailbox_t *dest_mailbox = &g_cxl.mailboxes[dest];
12061266

1207-
// Atomically claim a slot in destination's queue
1208-
uint64_t head = atomic_fetch_add(&dest_mailbox->head, 1);
1267+
// Check if queue is full BEFORE claiming a slot.
1268+
// Previously the head was advanced unconditionally, causing unbounded growth
1269+
// when the receiver can't drain (e.g. cross-VM cache incoherence).
1270+
#ifdef CXL_CACHE_COHERENCE
1271+
// Invalidate tail to see if receiver has drained the queue.
1272+
cxl_invalidate_range(&dest_mailbox->tail, sizeof(uint64_t));
1273+
#endif
12091274
uint64_t tail = atomic_load(&dest_mailbox->tail);
1275+
uint64_t head = atomic_load(&dest_mailbox->head);
12101276

1211-
// Check if queue is full (undo if so)
12121277
if ((head - tail) >= CXL_MSG_QUEUE_SIZE) {
1213-
LOG_WARN("CXL send: destination %d queue full (head=%lu, tail=%lu)\n", dest, head, tail);
1214-
// Cannot easily undo atomic increment, but the slot will be in EMPTY state
1215-
// and the receiver will skip it, so this is safe (just wastes a slot)
1278+
static _Atomic int queue_full_warns = 0;
1279+
int warns = atomic_fetch_add(&queue_full_warns, 1);
1280+
if (warns < 5)
1281+
LOG_WARN("CXL send: destination %d queue full (head=%lu, tail=%lu)\n", dest, head, tail);
1282+
else if (warns == 5)
1283+
LOG_WARN("CXL send: suppressing further queue-full warnings\n");
12161284
return -1;
12171285
}
12181286

1287+
// Atomically claim a slot in destination's queue
1288+
head = atomic_fetch_add(&dest_mailbox->head, 1);
1289+
12191290
uint64_t slot = head % CXL_MSG_QUEUE_SIZE;
12201291
cxl_msg_t *msg = &dest_mailbox->messages[slot];
12211292

@@ -1269,6 +1340,12 @@ static int cxl_send(const void *buf, size_t data_size, int dest, int tag, int so
12691340
// Mark message as ready (head was already advanced atomically above)
12701341
atomic_store(&msg->state, CXL_MSG_READY);
12711342

1343+
#ifdef CXL_CACHE_COHERENCE
1344+
// Flush the entire message (including state) so the receiver on another VM
1345+
// can see it via cache invalidation.
1346+
cxl_flush_range(msg, sizeof(*msg));
1347+
#endif
1348+
12721349
LOG_DEBUG("CXL send: sent %zu bytes from rank %d to rank %d (tag=%d, slot=%lu, inline=%d)\n",
12731350
data_size, source_rank, dest, tag, slot, msg->is_inline);
12741351

@@ -1283,6 +1360,12 @@ static int cxl_recv(void *buf, size_t max_size, int source, int tag, size_t *act
12831360

12841361
cxl_rank_mailbox_t *my_mailbox = &g_cxl.mailboxes[g_cxl.my_rank];
12851362

1363+
#ifdef CXL_CACHE_COHERENCE
1364+
// Invalidate head/tail so we see the latest values from the sender VM.
1365+
cxl_invalidate_range(&my_mailbox->head, sizeof(uint64_t));
1366+
cxl_invalidate_range(&my_mailbox->tail, sizeof(uint64_t));
1367+
#endif
1368+
12861369
uint64_t tail = atomic_load(&my_mailbox->tail);
12871370
uint64_t head = atomic_load(&my_mailbox->head);
12881371

@@ -1296,6 +1379,11 @@ static int cxl_recv(void *buf, size_t max_size, int source, int tag, size_t *act
12961379
uint64_t slot = i % CXL_MSG_QUEUE_SIZE;
12971380
cxl_msg_t *msg = &my_mailbox->messages[slot];
12981381

1382+
#ifdef CXL_CACHE_COHERENCE
1383+
// Invalidate message state to see writes from sender VM.
1384+
cxl_invalidate_range(&msg->state, CACHELINE_SIZE);
1385+
#endif
1386+
12991387
// Check state
13001388
if (atomic_load(&msg->state) != CXL_MSG_READY) {
13011389
continue;
@@ -1364,6 +1452,10 @@ static int cxl_recv(void *buf, size_t max_size, int source, int tag, size_t *act
13641452
}
13651453
}
13661454
atomic_store(&my_mailbox->tail, tail);
1455+
#ifdef CXL_CACHE_COHERENCE
1456+
// Flush tail so the sender VM can see the queue has been drained.
1457+
cxl_flush_range(&my_mailbox->tail, sizeof(uint64_t));
1458+
#endif
13671459
}
13681460

13691461
return 0;
@@ -1644,7 +1736,11 @@ int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, int ta
16441736
// the CXL message for fast-path data and drains the MPI message for
16451737
// correctness. This avoids channel mismatch regardless of receiver path
16461738
// (MPI_Recv, MPI_Irecv, MPI_ANY_SOURCE, sub-communicators).
1647-
if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && cxl_rank_available(dest)) {
1739+
// Only use CXL mailbox when all ranks share the same DAX device; cross-VM
1740+
// cache incoherence on CXL Type-3 prevents the receiver from seeing queue
1741+
// state updates, causing the mailbox to fill up and never drain.
1742+
if (g_cxl.cxl_comm_enabled && g_cxl.all_ranks_local &&
1743+
comm == MPI_COMM_WORLD && cxl_rank_available(dest)) {
16481744
if (cxl_send(buf, total_size, dest, tag, g_cxl.my_rank) == 0) {
16491745
atomic_fetch_add(&cxl_send_count, 1);
16501746
atomic_fetch_add(&g_stats.send_cxl, 1);
@@ -1689,7 +1785,9 @@ int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
16891785
// This keeps the CXL mailbox clean and tracks CXL recv stats.
16901786
// The sender always ALSO sends via MPI, so we always drain via
16911787
// orig_MPI_Recv below for guaranteed correctness.
1692-
if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && g_cxl.my_rank >= 0) {
1788+
// Only when all ranks are local (cross-VM mailbox lacks cache coherence).
1789+
if (g_cxl.cxl_comm_enabled && g_cxl.all_ranks_local &&
1790+
comm == MPI_COMM_WORLD && g_cxl.my_rank >= 0) {
16931791
size_t actual_size = 0;
16941792
int cxl_source = source; // communicator rank == world rank for COMM_WORLD
16951793
if (cxl_recv(buf, max_size, cxl_source, tag, &actual_size) == 0) {
@@ -1726,7 +1824,9 @@ int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int t
17261824
LOAD_ORIGINAL(MPI_Isend);
17271825

17281826
// Dual-send: CXL bonus (best-effort) + always orig_MPI_Isend
1729-
if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && cxl_rank_available(dest)) {
1827+
// Only for same-node peers (cross-VM mailbox lacks cache coherence).
1828+
if (g_cxl.cxl_comm_enabled && g_cxl.all_ranks_local &&
1829+
comm == MPI_COMM_WORLD && cxl_rank_available(dest)) {
17301830
if (cxl_send(buf, total_size, dest, tag, g_cxl.my_rank) == 0) {
17311831
atomic_fetch_add(&cxl_isend_count, 1);
17321832
atomic_fetch_add(&g_stats.isend_cxl, 1);

0 commit comments

Comments (0)