@@ -363,14 +363,14 @@ static void shim_log(const char *level, const char *color, const char *format, .
363363 if (!getenv ("CXL_SHIM_QUIET" )) {
364364 char hostname [256 ];
365365 gethostname (hostname , sizeof (hostname ));
366-
366+
367367 fprintf (stderr , "%s[CXL_SHIM:%s:%d:%s] " , color , hostname , getpid (), level );
368-
368+
369369 va_list args ;
370370 va_start (args , format );
371371 vfprintf (stderr , format , args );
372372 va_end (args );
373-
373+
374374 fprintf (stderr , "%s" , RESET );
375375 fflush (stderr );
376376 }
@@ -414,13 +414,13 @@ static inline bool is_cxl_ptr(const void *ptr) {
/*
 * Signal handler: reports the signal number, dumps up to 20 stack frames to
 * stderr, and terminates the process with status 1.
 *
 * sig: the signal number delivered to the process.
 *
 * Everything after the initial log line sticks to calls that are safe in
 * signal context: write(2) instead of fprintf(), and _exit() instead of
 * exit(). exit() would run atexit handlers and library destructors from
 * inside the handler, which can deadlock or re-enter code that was
 * interrupted mid-update.
 *
 * NOTE(review): LOG_ERROR presumably formats through stdio, which is not
 * async-signal-safe either; it is kept so the message retains the shim's
 * usual log prefix. Confirm whether that trade-off is acceptable.
 */
static void signal_handler(int sig) {
    enum { MAX_FRAMES = 20 };
    static const char backtrace_header[] = "Backtrace:\n";
    void *frames[MAX_FRAMES];

    LOG_ERROR("Caught signal %d\n", sig);

    // backtrace() returns an int frame count, and backtrace_symbols_fd()
    // takes an int, so keep the count as int rather than size_t.
    int frame_count = backtrace(frames, MAX_FRAMES);

    // Best-effort output: there is nothing useful to do if the write fails.
    ssize_t written = write(STDERR_FILENO, backtrace_header,
                            sizeof backtrace_header - 1);
    (void)written;

    // Writes symbol names straight to the descriptor, without malloc.
    backtrace_symbols_fd(frames, frame_count, STDERR_FILENO);

    // Terminate immediately without running atexit handlers or destructors.
    _exit(1);
}
426426
@@ -711,6 +711,7 @@ static void cxl_register_rank(int rank, int world_size) {
711711
712712 g_cxl .my_rank = rank ;
713713 g_cxl .world_size = world_size ;
714+ bool reg_timed_out = false;
714715
715716 // Rank 0 resets collective synchronization state for new run
716717 if (rank == 0 ) {
@@ -739,9 +740,13 @@ static void cxl_register_rank(int rank, int world_size) {
739740 atomic_store (& g_cxl .header -> coll_max_ranks , (uint32_t )world_size );
740741 __atomic_thread_fence (__ATOMIC_SEQ_CST );
741742 } else {
742- // Non-rank-0 processes wait for rank 0 to complete reset.
743- // Use generation counter to detect stale coll_max_ranks from previous runs.
743+ // Non-rank-0 processes wait for rank 0 to complete reset, with a timeout.
744+ // CXL device memory may not propagate atomic writes reliably across nodes,
745+ // so we use a bounded wait and fall back to MPI-only if it times out.
744746 int wait_count = 0 ;
747+ int total_warn_count = 0 ;
748+ const int max_warn_count = 5 ; // ~5 seconds total timeout
749+ // reg_timed_out is tracked in outer scope
745750 uint32_t cur ;
746751
747752 // If coll_max_ranks already matches world_size, it may be stale from a
@@ -763,6 +768,12 @@ static void cxl_register_rank(int rank, int world_size) {
763768 LOG_INFO ("Rank %d: rank 0 already registered, proceeding\n" , rank );
764769 break ;
765770 }
771+ if (++ total_warn_count > max_warn_count ) {
772+ LOG_WARN ("Rank %d: timed out waiting for rank 0 generation bump, "
773+ "disabling CXL collectives\n" , rank );
774+ reg_timed_out = true;
775+ break ;
776+ }
766777 LOG_WARN ("Rank %d waiting for rank 0 generation bump (gen=%u)\n" ,
767778 rank , start_gen );
768779 wait_count = 0 ;
@@ -775,16 +786,35 @@ static void cxl_register_rank(int rank, int world_size) {
775786 }
776787
777788 // Now wait for coll_max_ranks == world_size (reset complete)
778- wait_count = 0 ;
779- while (atomic_load (& g_cxl .header -> coll_max_ranks ) != (uint32_t )world_size ) {
780- __asm__ volatile ("pause" ::: "memory" );
781- if (++ wait_count > 10000000 ) {
782- LOG_WARN ("Rank %d waiting for rank 0 to reset collective state (coll_max_ranks=%u, expected=%d)\n" ,
783- rank , atomic_load (& g_cxl .header -> coll_max_ranks ), world_size );
784- wait_count = 0 ;
789+ if (!reg_timed_out ) {
790+ wait_count = 0 ;
791+ total_warn_count = 0 ;
792+ while (atomic_load (& g_cxl .header -> coll_max_ranks ) != (uint32_t )world_size ) {
793+ __asm__ volatile ("pause" ::: "memory" );
794+ if (++ wait_count > 10000000 ) {
795+ if (++ total_warn_count > max_warn_count ) {
796+ LOG_WARN ("Rank %d: timed out waiting for coll_max_ranks "
797+ "(got %u, expected %d), disabling CXL collectives\n" ,
798+ rank , atomic_load (& g_cxl .header -> coll_max_ranks ), world_size );
799+ reg_timed_out = true;
800+ break ;
801+ }
802+ LOG_WARN ("Rank %d waiting for rank 0 to reset collective state (coll_max_ranks=%u, expected=%d)\n" ,
803+ rank , atomic_load (& g_cxl .header -> coll_max_ranks ), world_size );
804+ wait_count = 0 ;
805+ }
806+ }
807+ if (!reg_timed_out ) {
808+ __atomic_thread_fence (__ATOMIC_ACQUIRE );
785809 }
786810 }
787- __atomic_thread_fence (__ATOMIC_ACQUIRE );
811+
812+ if (reg_timed_out ) {
813+ // Force coll_max_ranks to world_size so this rank can proceed.
814+ // CXL collectives will be disabled for this rank below.
815+ atomic_store (& g_cxl .header -> coll_max_ranks , (uint32_t )world_size );
816+ __atomic_thread_fence (__ATOMIC_SEQ_CST );
817+ }
788818 }
789819
790820 cxl_rank_mailbox_t * my_mailbox = & g_cxl .mailboxes [rank ];
@@ -800,8 +830,14 @@ static void cxl_register_rank(int rank, int world_size) {
800830 // Increment active rank count
801831 atomic_fetch_add (& g_cxl .header -> num_ranks , 1 );
802832
803- // Enable CXL communication if CXL_DIRECT env is set or by default for DAX
804- g_cxl .cxl_comm_enabled = getenv ("CXL_DIRECT" ) || (strcmp (g_cxl .type , "dax" ) == 0 );
833+ // Enable CXL communication if CXL_DIRECT env is set or by default for DAX.
834+ // Disable if registration timed out (CXL memory not coherent across nodes).
835+ if (reg_timed_out ) {
836+ g_cxl .cxl_comm_enabled = false;
837+ LOG_WARN ("Rank %d: CXL collectives disabled due to registration timeout\n" , rank );
838+ } else {
839+ g_cxl .cxl_comm_enabled = getenv ("CXL_DIRECT" ) || (strcmp (g_cxl .type , "dax" ) == 0 );
840+ }
805841
806842 LOG_INFO ("Registered rank %d/%d in CXL shared memory (pid=%d, host=%s, cxl_comm=%s)\n" ,
807843 rank , world_size , my_mailbox -> pid , my_mailbox -> hostname ,
@@ -1003,23 +1039,23 @@ static void cxl_collective_clear_buffers(int num_ranks) {
10031039// Mapping management
10041040static void register_mapping (void * cxl_addr , void * orig_addr , size_t size ) {
10051041 pthread_mutex_lock (& g_mappings_lock );
1006-
1042+
10071043 mem_mapping_t * mapping = malloc (sizeof (mem_mapping_t ));
10081044 mapping -> cxl_addr = cxl_addr ;
10091045 mapping -> orig_addr = orig_addr ;
10101046 mapping -> size = size ;
10111047 mapping -> ref_count = 1 ;
10121048 mapping -> next = g_mappings ;
10131049 g_mappings = mapping ;
1014-
1050+
10151051 LOG_TRACE ("Registered mapping: orig=%p -> cxl=%p (size=%zu)\n" , orig_addr , cxl_addr , size );
1016-
1052+
10171053 pthread_mutex_unlock (& g_mappings_lock );
10181054}
10191055
10201056static void * find_cxl_mapping (const void * orig_addr ) {
10211057 pthread_mutex_lock (& g_mappings_lock );
1022-
1058+
10231059 mem_mapping_t * curr = g_mappings ;
10241060 while (curr ) {
10251061 if (curr -> orig_addr == orig_addr ) {
@@ -1029,7 +1065,7 @@ static void *find_cxl_mapping(const void *orig_addr) {
10291065 }
10301066 curr = curr -> next ;
10311067 }
1032-
1068+
10331069 pthread_mutex_unlock (& g_mappings_lock );
10341070 return NULL ;
10351071}
@@ -1285,6 +1321,17 @@ int MPI_Init(int *argc, char ***argv) {
12851321 cxl_register_rank (rank , size );
12861322 }
12871323
1324+ // Barrier to ensure ALL ranks complete CXL registration before any
1325+ // rank returns from MPI_Init. Without this, fast ranks (e.g. rank 0,1
1326+ // on the same node as the DAX device) can return from MPI_Init and
1327+ // enter MPI_Comm_dup (a collective) while remote ranks are still
1328+ // stuck in cxl_register_rank's spin-wait, causing a deadlock.
1329+ LOAD_ORIGINAL (MPI_Barrier );
1330+ if (orig_MPI_Barrier ) {
1331+ LOG_DEBUG ("Post-registration barrier (rank %d)\n" , rank );
1332+ orig_MPI_Barrier (MPI_COMM_WORLD );
1333+ }
1334+
12881335 LOG_INFO ("MPI_Init completed: rank=%d/%d, CXL=%s, CXL_DIRECT=%s\n" ,
12891336 rank , size ,
12901337 g_cxl .initialized ? "initialized" : "not initialized" ,
@@ -2848,9 +2895,9 @@ static void shim_init(void) {
28482895__attribute__((destructor ))
28492896static void shim_cleanup (void ) {
28502897 LOG_INFO ("CXL MPI Shim unloading (total hooks: %d)\n" , g_hook_count );
2851-
2898+
28522899 if (g_cxl .initialized ) {
28532900 LOG_INFO ("Final CXL memory usage: %zu/%zu bytes (%.1f%%)\n" ,
28542901 g_cxl .used , g_cxl .size , 100.0 * g_cxl .used / g_cxl .size );
28552902 }
2856- }
2903+ }
0 commit comments