Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ target_include_directories(mlxc
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>)
set_property(TARGET mlxc PROPERTY POSITION_INDEPENDENT_CODE ON)

# Windows DLL symbol exports
if(WIN32)
set_target_properties(mlxc PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()

if(MLX_C_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples)
endif()
Expand Down
11 changes: 7 additions & 4 deletions mlx/c/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,10 @@ extern "C" int mlx_array_item_float64(double* res, const mlx_array arr) {
return 0;
}
extern "C" int mlx_array_item_complex64(
float _Complex* res,
mlx_complex64_t* res,
const mlx_array arr) {
try {
*res = mlx_array_get_(arr).item<float _Complex>();
*res = mlx_array_get_(arr).item<mlx_complex64_t>();
} catch (std::exception& e) {
mlx_error(e.what());
return 1;
Expand Down Expand Up @@ -568,9 +568,12 @@ extern "C" const double* mlx_array_data_float64(const mlx_array arr) {
return nullptr;
}
}
extern "C" const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
extern "C" const mlx_complex64_t* mlx_array_data_complex64(
const mlx_array arr) {
try {
return mlx_array_get_(arr).data<float _Complex>();
// std::complex<float> and mlx_complex64_t have the same memory layout
return reinterpret_cast<const mlx_complex64_t*>(
mlx_array_get_(arr).data<std::complex<float>>());
} catch (std::exception& e) {
mlx_error(e.what());
return nullptr;
Expand Down
14 changes: 12 additions & 2 deletions mlx/c/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
#include <stdint.h>
#include <stdlib.h>

// Complex number support
#ifdef _MSC_VER
#define _CRT_USE_C_COMPLEX_H
#include <complex.h>
typedef _Fcomplex mlx_complex64_t;
#else
#include <complex.h>
typedef float _Complex mlx_complex64_t;
#endif

#include "half.h"

#ifdef __cplusplus
Expand Down Expand Up @@ -261,7 +271,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);

#ifdef HAS_FLOAT16
/**
Expand Down Expand Up @@ -336,7 +346,7 @@ const double* mlx_array_data_float64(const mlx_array arr);
* Returns a pointer to the array data, cast to `_Complex*`.
* Array must be evaluated, otherwise returns NULL.
*/
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);

#ifdef HAS_FLOAT16
/**
Expand Down