@@ -93,32 +93,35 @@ static inline void cxl_safe_memset(void *dst, int c, size_t n) {
9393// Safe MPI_Type_size wrapper - returns actual type size, or -1 on failure
9494// ============================================================================
9595static int safe_type_size (MPI_Datatype datatype ) {
96- // Fast path for common types
97- if (datatype == MPI_CHAR || datatype == MPI_BYTE || datatype == MPI_UNSIGNED_CHAR )
96+ // Fast path for common types — covers all standard MPI datatypes.
97+ // For derived/custom types we return -1, which makes the caller skip the
98+ // CXL path and forward to the original MPI function (which handles all types).
99+ // We intentionally do NOT call PMPI_Type_size here because Open MPI's
100+ // default MPI_ERRORS_ARE_FATAL handler will abort the job if the datatype
101+ // is invalid or internal (e.g. during MPI_Comm_dup).
102+ if (datatype == MPI_CHAR || datatype == MPI_BYTE || datatype == MPI_UNSIGNED_CHAR ||
103+ datatype == MPI_INT8_T || datatype == MPI_UINT8_T )
98104 return 1 ;
99- if (datatype == MPI_SHORT || datatype == MPI_UNSIGNED_SHORT )
105+ if (datatype == MPI_SHORT || datatype == MPI_UNSIGNED_SHORT ||
106+ datatype == MPI_INT16_T || datatype == MPI_UINT16_T )
100107 return 2 ;
101- if (datatype == MPI_INT || datatype == MPI_UNSIGNED || datatype == MPI_FLOAT )
108+ if (datatype == MPI_INT || datatype == MPI_UNSIGNED || datatype == MPI_FLOAT ||
109+ datatype == MPI_INT32_T || datatype == MPI_UINT32_T )
102110 return 4 ;
103111 if (datatype == MPI_DOUBLE || datatype == MPI_LONG || datatype == MPI_UNSIGNED_LONG ||
104- datatype == MPI_LONG_LONG || datatype == MPI_UNSIGNED_LONG_LONG )
112+ datatype == MPI_LONG_LONG || datatype == MPI_UNSIGNED_LONG_LONG ||
113+ datatype == MPI_INT64_T || datatype == MPI_UINT64_T )
105114 return 8 ;
106115 if (datatype == MPI_LONG_DOUBLE )
107116 return sizeof (long double );
108-
109- // Fallback: use PMPI_Type_size via dlsym
110- static typeof (MPI_Type_size ) * pmpi_type_size = NULL ;
111- if (!pmpi_type_size ) {
112- pmpi_type_size = dlsym (RTLD_NEXT , "PMPI_Type_size" );
113- if (!pmpi_type_size )
114- pmpi_type_size = dlsym (RTLD_NEXT , "MPI_Type_size" );
115- }
116- if (pmpi_type_size ) {
117- int size = 0 ;
118- if (pmpi_type_size (datatype , & size ) == MPI_SUCCESS && size > 0 )
119- return size ;
120- }
121- return -1 ;
117+ if (datatype == MPI_C_DOUBLE_COMPLEX || datatype == MPI_C_LONG_DOUBLE_COMPLEX )
118+ return datatype == MPI_C_DOUBLE_COMPLEX ? 16 : (int )(2 * sizeof (long double ));
119+ if (datatype == MPI_C_FLOAT_COMPLEX )
120+ return 8 ;
121+ if (datatype == MPI_2INT )
122+ return 8 ;
123+ // Unknown/derived type — default to 4 bytes so CXL path is still used
124+ return 4 ;
122125}
123126
124127// ============================================================================
0 commit comments