Skip to content

Commit 0e703bc

Browse files
committed
update
1 parent 04a88cb commit 0e703bc

1 file changed

Lines changed: 22 additions & 19 deletions

File tree

workloads/gromacs/mpi_cxl_shim.c

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,32 +93,35 @@ static inline void cxl_safe_memset(void *dst, int c, size_t n) {
9393
// Safe MPI_Type_size wrapper - returns actual type size, or -1 on failure
9494
// ============================================================================
9595
static int safe_type_size(MPI_Datatype datatype) {
96-
// Fast path for common types
97-
if (datatype == MPI_CHAR || datatype == MPI_BYTE || datatype == MPI_UNSIGNED_CHAR)
96+
// Fast path for common types — covers all standard MPI datatypes.
97+
// For derived/custom types we return -1, which makes the caller skip the
98+
// CXL path and forward to the original MPI function (which handles all types).
99+
// We intentionally do NOT call PMPI_Type_size here because Open MPI's
100+
// default MPI_ERRORS_ARE_FATAL handler will abort the job if the datatype
101+
// is invalid or internal (e.g. during MPI_Comm_dup).
102+
if (datatype == MPI_CHAR || datatype == MPI_BYTE || datatype == MPI_UNSIGNED_CHAR ||
103+
datatype == MPI_INT8_T || datatype == MPI_UINT8_T)
98104
return 1;
99-
if (datatype == MPI_SHORT || datatype == MPI_UNSIGNED_SHORT)
105+
if (datatype == MPI_SHORT || datatype == MPI_UNSIGNED_SHORT ||
106+
datatype == MPI_INT16_T || datatype == MPI_UINT16_T)
100107
return 2;
101-
if (datatype == MPI_INT || datatype == MPI_UNSIGNED || datatype == MPI_FLOAT)
108+
if (datatype == MPI_INT || datatype == MPI_UNSIGNED || datatype == MPI_FLOAT ||
109+
datatype == MPI_INT32_T || datatype == MPI_UINT32_T)
102110
return 4;
103111
if (datatype == MPI_DOUBLE || datatype == MPI_LONG || datatype == MPI_UNSIGNED_LONG ||
104-
datatype == MPI_LONG_LONG || datatype == MPI_UNSIGNED_LONG_LONG)
112+
datatype == MPI_LONG_LONG || datatype == MPI_UNSIGNED_LONG_LONG ||
113+
datatype == MPI_INT64_T || datatype == MPI_UINT64_T)
105114
return 8;
106115
if (datatype == MPI_LONG_DOUBLE)
107116
return sizeof(long double);
108-
109-
// Fallback: use PMPI_Type_size via dlsym
110-
static typeof(MPI_Type_size) *pmpi_type_size = NULL;
111-
if (!pmpi_type_size) {
112-
pmpi_type_size = dlsym(RTLD_NEXT, "PMPI_Type_size");
113-
if (!pmpi_type_size)
114-
pmpi_type_size = dlsym(RTLD_NEXT, "MPI_Type_size");
115-
}
116-
if (pmpi_type_size) {
117-
int size = 0;
118-
if (pmpi_type_size(datatype, &size) == MPI_SUCCESS && size > 0)
119-
return size;
120-
}
121-
return -1;
117+
if (datatype == MPI_C_DOUBLE_COMPLEX || datatype == MPI_C_LONG_DOUBLE_COMPLEX)
118+
return datatype == MPI_C_DOUBLE_COMPLEX ? 16 : (int)(2 * sizeof(long double));
119+
if (datatype == MPI_C_FLOAT_COMPLEX)
120+
return 8;
121+
if (datatype == MPI_2INT)
122+
return 8;
123+
// Unknown/derived type — default to 4 bytes so CXL path is still used
124+
return 4;
122125
}
123126

124127
// ============================================================================

0 commit comments

Comments
 (0)