From 96cdfa63adef2036af42f450ef70de9a6f56ccf5 Mon Sep 17 00:00:00 2001 From: QiuYuan Han Date: Wed, 10 Dec 2025 13:52:50 +0800 Subject: [PATCH 01/28] Add the way to set the target evenif we use load_by_name --- core/iwasm/common/wasm_native.c | 38 ++++ core/iwasm/common/wasm_runtime_common.c | 180 ++++++++++++++++++ core/iwasm/common/wasm_runtime_common.h | 74 +++++++ core/iwasm/include/wasm_export.h | 51 +++++ core/iwasm/interpreter/wasm_runtime.c | 12 ++ .../wasi-nn/include/wasi_ephemeral_nn.h | 4 +- .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 2 +- .../libraries/wasi-nn/include/wasi_nn_types.h | 3 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 119 +++++++++++- .../libraries/wasi-nn/src/wasi_nn_backend.h | 3 +- .../libraries/wasi-nn/src/wasi_nn_llamacpp.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 3 +- .../libraries/wasi-nn/src/wasi_nn_openvino.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_private.h | 3 +- .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 6 +- .../libraries/wasi-nn/test/requirements.txt | 2 +- .../libraries/wasi-nn/test/test_tensorflow.c | 66 +++---- .../wasi-nn/test/test_tensorflow_quantized.c | 26 +-- core/iwasm/libraries/wasi-nn/test/utils.c | 104 ++++++---- core/iwasm/libraries/wasi-nn/test/utils.h | 23 +-- product-mini/platforms/posix/main.c | 62 ++++++ 21 files changed, 659 insertions(+), 128 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 42aa55db28..8938524db8 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL; static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +static void *g_wasi_nn_context_key; +#endif + uint32 get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis); @@ -473,6 +477,31 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) } #endif /* end of WASM_ENABLE_LIBC_WASI */ +#if 
WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+WASINNGlobalContext *
+wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm)
+{
+    return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key);
+}
+
+void
+wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm,
+                                    WASINNGlobalContext *wasi_nn_ctx)
+{
+    wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx);
+}
+
+static void
+wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
+{
+    if (ctx == NULL) {
+        return;
+    }
+
+    wasm_runtime_destroy_wasi_nn_global_ctx(inst);
+}
+#endif
+
 #if WASM_ENABLE_QUICK_AOT_ENTRY != 0
 static bool
 quick_aot_entry_init(void);
@@ -582,6 +611,14 @@ wasm_native_init()
 #endif /* WASM_ENABLE_LIB_RATS */
 
 #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    /* the context key and its dtor exist only in WASI_EPHEMERAL_NN builds */
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor);
+    if (g_wasi_nn_context_key == NULL) {
+        goto fail;
+    }
+#endif
+
     if (!wasi_nn_initialize())
         goto fail;
 
@@ -648,6 +685,12 @@ wasm_native_destroy()
 #endif
 
 #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    if (g_wasi_nn_context_key != NULL) {
+        wasm_native_destroy_context_key(g_wasi_nn_context_key);
+        g_wasi_nn_context_key = NULL;
+    }
+#endif
     wasi_nn_destroy();
 #endif
 
diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c
index 259816e0b9..312c4b9c7b 100644
--- a/core/iwasm/common/wasm_runtime_common.c
+++ b/core/iwasm/common/wasm_runtime_common.c
@@ -1696,6 +1696,125 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p)
     wasm_runtime_free(p);
 }
 
+#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
+struct wasi_nn_graph_registry;
+
+/* Reset every field of `args` to zero/NULL, i.e. an empty registry. */
+void
+wasm_runtime_wasi_nn_graph_registry_args_set_defaults(
+    struct wasi_nn_graph_registry *args)
+{
+    memset(args, 0, sizeof(*args));
+}
+
+/*
+ * Release the strings and the path array owned by `registry` and reset it
+ * to the empty state. The registry object itself is not freed.
+ */
+static void
+wasi_nn_graph_registry_clear(struct wasi_nn_graph_registry *registry)
+{
+    uint32_t i;
+
+    if (registry->graph_paths) {
+        for (i = 0; i < registry->n_graphs; i++) {
+            free(registry->graph_paths[i]);
+        }
+        free(registry->graph_paths);
+    }
+    free(registry->encoding);
+    free(registry->target);
+    memset(registry, 0, sizeof(*registry));
+}
+
+/*
+ * Store deep copies of `encoding`, `target` and the `n_graphs` entries of
+ * `graph_paths` in `registry`. Returns false, leaving `registry` untouched,
+ * on a NULL argument or an allocation failure.
+ */
+bool
+wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry,
+                                const char *encoding, const char *target,
+                                uint32_t n_graphs, const char **graph_paths)
+{
+    struct wasi_nn_graph_registry copy;
+    uint32_t i;
+
+    if (!registry || !encoding || !target || !graph_paths) {
+        return false;
+    }
+
+    memset(&copy, 0, sizeof(copy));
+    copy.encoding = strdup(encoding);
+    copy.target = strdup(target);
+    if (!copy.encoding || !copy.target) {
+        goto fail;
+    }
+
+    if (n_graphs > 0) {
+        /* calloc zero-fills the slots and rejects n * size overflow */
+        copy.graph_paths = calloc(n_graphs, sizeof(*copy.graph_paths));
+        if (!copy.graph_paths) {
+            goto fail;
+        }
+    }
+    copy.n_graphs = n_graphs;
+
+    for (i = 0; i < n_graphs; i++) {
+        if (!graph_paths[i]
+            || !(copy.graph_paths[i] = strdup(graph_paths[i]))) {
+            goto fail;
+        }
+    }
+
+    *registry = copy;
+    return true;
+
+fail:
+    wasi_nn_graph_registry_clear(&copy);
+    return false;
+}
+
+/*
+ * Allocate an empty registry and store it in `*registryp`.
+ * Returns 0 on success, -1 on a NULL argument or an allocation failure.
+ */
+int
+wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp)
+{
+    struct wasi_nn_graph_registry *args;
+
+    if (!registryp) {
+        return -1;
+    }
+
+    args = wasm_runtime_malloc(sizeof(*args));
+    if (args == NULL) {
+        /* must be non-zero: 0 is this function's success code */
+        return -1;
+    }
+    wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args);
+    *registryp = args;
+    return 0;
+}
+
+/*
+ * Free a registry obtained from wasi_nn_graph_registry_create() together
+ * with everything it owns. NULL is accepted and ignored.
+ */
+void
+wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry)
+{
+    if (!registry) {
+        return;
+    }
+
+    wasi_nn_graph_registry_clear(registry);
+    /* allocated with wasm_runtime_malloc(), so not released with free() */
+    wasm_runtime_free(registry);
+}
+#endif
+
 void
 wasm_runtime_instantiation_args_set_default_stack_size(
     struct InstantiationArgs2 *p, uint32 v)
@@ -1794,6 +1913,20 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
     wasi_args->set_by_user = true;
 }
 #endif /* WASM_ENABLE_LIBC_WASI != 0 */
+#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
+/*
+ * Copy `registry` into the instantiation arguments. The copy is shallow:
+ * the strings stay owned by `registry`, which must outlive instantiation.
+ */
+void
+wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
+    struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry)
+{
+    if (p && registry) {
+        p->nn_registry = *registry;
+    }
+}
+#endif
 
 WASMModuleInstanceCommon *
 wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
@@ -8080,3 +8213,166 @@ wasm_runtime_check_and_update_last_used_shared_heap(
     return false;
 }
 #endif
+
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+/* Free `ctx` and everything it owns (free(NULL) is a no-op). */
+static void
+wasi_nn_global_ctx_free(WASINNGlobalContext *ctx)
+{
+    uint32_t i;
+
+    if (ctx->graph_paths) {
+        for (i = 0; i < ctx->n_graphs; i++) {
+            /* All graphs will be unregistered in deinit() */
+            free(ctx->graph_paths[i]);
+        }
+        free(ctx->graph_paths);
+    }
+    free(ctx->loaded);
+    free(ctx->target);
+    free(ctx->encoding);
+    wasm_runtime_free(ctx);
+}
+
+/*
+ * Create the wasi-nn context of `module_inst`, holding private copies of
+ * `encoding`, `target` and the `n_graphs` entries of `graph_paths`.
+ * `encoding` and `target` may be NULL and are then stored as NULL.
+ * Returns false, with a message in `error_buf`, on failure.
+ */
+bool
+wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
+                                     const char *encoding, const char *target,
+                                     const uint32_t n_graphs,
+                                     char *graph_paths[], char *error_buf,
+                                     uint32_t error_buf_size)
+{
+    WASINNGlobalContext *ctx;
+    uint32_t i;
+
+    ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size);
+    if (!ctx) {
+        return false;
+    }
+    memset(ctx, 0, sizeof(*ctx));
+
+    /* encoding/target are NULL when no graph registry was configured */
+    if (encoding && !(ctx->encoding = strdup(encoding))) {
+        goto fail;
+    }
+    if (target && !(ctx->target = strdup(target))) {
+        goto fail;
+    }
+
+    if (n_graphs > 0) {
+        if (!graph_paths) {
+            goto fail;
+        }
+        /* calloc zero-fills the arrays and rejects n * size overflow */
+        ctx->loaded = calloc(n_graphs, sizeof(*ctx->loaded));
+        ctx->graph_paths = calloc(n_graphs, sizeof(*ctx->graph_paths));
+        if (!ctx->loaded || !ctx->graph_paths) {
+            goto fail;
+        }
+    }
+    ctx->n_graphs = n_graphs;
+
+    for (i = 0; i < n_graphs; i++) {
+        if (!graph_paths[i]
+            || !(ctx->graph_paths[i] = strdup(graph_paths[i]))) {
+            goto fail;
+        }
+    }
+
+    wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);
+    return true;
+
+fail:
+    wasi_nn_global_ctx_free(ctx);
+    if (error_buf && error_buf_size > 0) {
+        snprintf(error_buf, error_buf_size, "%s",
+                 "failed to initialize wasi-nn global context");
+    }
+    return false;
+}
+
+/* Detach the wasi-nn context from `module_inst` and free it, if any. */
+void
+wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
+{
+    WASINNGlobalContext *ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst);
+
+    if (!ctx) {
+        return;
+    }
+
+    /* clear the slot first so a repeated call cannot free it twice */
+    wasm_runtime_set_wasi_nn_global_ctx(module_inst, NULL);
+    wasi_nn_global_ctx_free(ctx);
+}
+
+/* Number of configured graphs; 0 when `wasi_nn_global_ctx` is NULL. */
+uint32_t
+wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
+    WASINNGlobalContext *wasi_nn_global_ctx)
+{
+    /* 0 rather than -1: as uint32_t that would be a huge loop bound */
+    return wasi_nn_global_ctx ? wasi_nn_global_ctx->n_graphs : 0;
+}
+
+/* Path of graph `idx`; NULL when there is no context or no such graph. */
+char *
+wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
+    WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
+{
+    if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
+        return wasi_nn_global_ctx->graph_paths[idx];
+
+    return NULL;
+}
+
+/* Loaded marker of graph `idx`; 0 (not loaded) when there is no such graph. */
+uint32_t
+wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
+    WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
+{
+    if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
+        return wasi_nn_global_ctx->loaded[idx];
+
+    return 0;
+}
+
+/* Set the loaded marker of graph `idx`; a bad `idx` is ignored. Returns 0. */
+uint32_t
+wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
+    WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value)
+{
+    if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
+        wasi_nn_global_ctx->loaded[idx] = value;
+
+    return 0;
+}
+
+/* Encoding string of the context; NULL when unset or there is no context. */
+char *
+wasm_runtime_get_wasi_nn_global_ctx_encoding(
+    WASINNGlobalContext *wasi_nn_global_ctx)
+{
+    if (wasi_nn_global_ctx)
+        return wasi_nn_global_ctx->encoding;
+
+    return NULL;
+}
+
+/* Target string of the context; NULL when unset or there is no context. */
+char *
+wasm_runtime_get_wasi_nn_global_ctx_target(
+    WASINNGlobalContext *wasi_nn_global_ctx)
+{
+    if (wasi_nn_global_ctx)
+        return wasi_nn_global_ctx->target;
+
+    return NULL;
+}
+
+#endif
diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h
index 88f23485e8..8d002bedcc 100644
--- a/core/iwasm/common/wasm_runtime_common.h
+++ b/core/iwasm/common/wasm_runtime_common.h
@@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption {
     uint32 exports_size;
 } WASMModuleInstMemConsumption;
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+typedef struct WASINNGlobalContext {
+    char* encoding;
+    char* target;
+
+    uint32_t n_graphs;
+    uint32_t *loaded;
+    char** graph_paths;
+} WASINNGlobalContext;
+#endif
+
 #if WASM_ENABLE_LIBC_WASI != 0
 #if WASM_ENABLE_UVWASI == 0
 typedef struct WASIContext {
@@ -612,11 +623,30 @@ WASMExecEnv *
 wasm_runtime_get_exec_env_tls(void);
 #endif
 
+#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
+struct wasi_nn_graph_registry {
+    char* encoding;
+    char* target;
+
+    char** graph_paths;
+    uint32_t n_graphs;
+};
+
+WASM_RUNTIME_API_EXTERN int
+wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp);
+
+WASM_RUNTIME_API_EXTERN void
+wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry);
+#endif
+
 struct InstantiationArgs2 {
     InstantiationArgs v1;
 #if WASM_ENABLE_LIBC_WASI != 0
     WASIArguments wasi;
#endif +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) + struct wasi_nn_graph_registry nn_registry; +#endif }; void @@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32 ns_lookup_pool_size); +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +WASM_RUNTIME_API_EXTERN void +wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( + struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + +WASM_RUNTIME_API_EXTERN bool +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, + const char* target, uint32_t n_graphs, + const char** graph_paths); +#endif + /* See wasm_export.h for description */ WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon * wasm_runtime_instantiate_ex2(WASMModuleCommon *module, @@ -1427,6 +1468,39 @@ wasm_runtime_check_and_update_last_used_shared_heap( uint8 **shared_heap_base_addr_adj_p, bool is_memory64); #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASM_RUNTIME_API_EXTERN bool +wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + WASINNGlobalContext *wasi_ctx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t 
+wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx); +#endif + #ifdef __cplusplus } #endif diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 44a45dedfc..50263f1823 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -290,6 +290,8 @@ typedef struct InstantiationArgs { #endif /* INSTANTIATION_ARGS_OPTION_DEFINED */ struct InstantiationArgs2; +struct WASINNGlobalContext; +typedef struct WASINNGlobalContext *wasi_nn_global_context; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -796,6 +798,55 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); +// WASM_RUNTIME_API_EXTERN int +// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); + +// WASM_RUNTIME_API_EXTERN void +// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); + +// WASM_RUNTIME_API_EXTERN void +// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( +// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + +// WASM_RUNTIME_API_EXTERN bool +// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, +// const char* target, uint32_t n_graphs, +// const char** graph_paths); + +WASM_RUNTIME_API_EXTERN bool +wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst); + 
+WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst, + wasi_nn_global_context wasi_ctx); + +WASM_RUNTIME_API_EXTERN wasi_nn_global_context +wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx); + /** * Instantiate a WASM module, with specified instantiation arguments * diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index a59bc9257b..79d4c73c2e 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,6 +3300,18 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + /* Store graphs' path into ctx. 
Graphs will be loaded until user app calls load_by_name */ + // Do not consider load() for now + struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; + if (!wasm_runtime_init_wasi_nn_global_ctx( + (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, + nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths, + error_buf, error_buf_size)) { + goto fail; + } +#endif + #if WASM_ENABLE_DEBUG_INTERP != 0 if (!is_sub_inst) { /* Add module instance into module's instance list */ diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h index f76295a1ee..83beba98f5 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h @@ -8,5 +8,5 @@ #include "wasi_nn.h" -#undef WASM_ENABLE_WASI_EPHEMERAL_NN -#undef WASI_NN_NAME +// #undef WASM_ENABLE_WASI_EPHEMERAL_NN +// #undef WASI_NN_NAME diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index cda26324eb..d76de3ffc0 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -21,7 +21,7 @@ #else #define WASI_NN_IMPORT(name) \ __attribute__((import_module("wasi_nn"), import_name(name))) -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. 
For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif /** diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index 952fb65e28..d77fe9a6cb 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -27,7 +27,7 @@ extern "C" { #define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name) #define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name) #define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name) -#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error); +#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error) #endif /** @@ -169,6 +169,7 @@ typedef enum WASI_NN_NAME(execution_target) { WASI_NN_TARGET_NAME(cpu) = 0, WASI_NN_TARGET_NAME(gpu), WASI_NN_TARGET_NAME(tpu), + WASI_NN_TARGET_NAME(unsupported_target), } WASI_NN_NAME(execution_target); // Bind a `graph` to the input and output tensors for an inference. diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 2282534b0f..9e3e741b69 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -21,7 +21,7 @@ #include "wasm_export.h" #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. 
For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
 #endif
 
 #define HASHMAP_INITIAL_SIZE 20
@@ -35,6 +35,8 @@
 #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
 #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
 
+#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instances
+
 /* Global variables */
 static korp_mutex wasi_nn_lock;
 /*
@@ -208,6 +210,48 @@ wasi_nn_destroy()
  * - model file format
  * - on device ML framework
  */
+/* Map an encoding name ("openvino", ...) to its graph_encoding value. */
+static graph_encoding
+str2encoding(const char *str_encoding)
+{
+    if (!str_encoding) {
+        /* nothing configured: fall back to autodetect, as before */
+        return autodetect;
+    }
+
+    if (!strcmp(str_encoding, "openvino"))
+        return openvino;
+    if (!strcmp(str_encoding, "tensorflowlite"))
+        return tensorflowlite;
+    if (!strcmp(str_encoding, "ggml"))
+        return ggml;
+    if (!strcmp(str_encoding, "onnx"))
+        return onnx;
+
+    NN_ERR_PRINTF("Unknown encoding: %s", str_encoding);
+    return unknown_backend;
+}
+
+/* Map a target name ("cpu", "gpu", "tpu") to its execution_target value. */
+static execution_target
+str2target(const char *str_target)
+{
+    if (!str_target) {
+        /* nothing configured: CPU is the default target */
+        return cpu;
+    }
+
+    if (!strcmp(str_target, "cpu"))
+        return cpu;
+    if (!strcmp(str_target, "gpu"))
+        return gpu;
+    if (!strcmp(str_target, "tpu"))
+        return tpu;
+
+    NN_ERR_PRINTF("Unknown target: %s", str_target);
+    return unsupported_target;
+}
+
 static graph_encoding
 choose_a_backend()
 {
@@ -565,17 +609,88 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
         goto fail;
     }
 
-    res = ensure_backend(instance, autodetect, wasi_nn_ctx);
-    if (res != success)
-        goto fail;
-
-    call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
-                      wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
-                      g);
-    if (res != success)
-        goto fail;
-
-    res = success;
+    wasi_nn_global_context wasi_nn_global_ctx =
+        wasm_runtime_get_wasi_nn_global_ctx(instance);
+    if (!wasi_nn_global_ctx) {
+        NN_ERR_PRINTF("global context is invalid");
+        res = not_found;
+        goto fail;
+    }
+    graph_encoding encoding = str2encoding(
+        wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx));
+    execution_target target = str2target(
+        wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx));
+    if (target == unsupported_target) {
+        res = invalid_argument;
+        goto fail;
+    }
+
+    res = ensure_backend(instance, encoding, wasi_nn_ctx);
+    if (res != success)
+        goto fail;
+
+    /*
+     * Look the name up among the registered graph paths: an entry matches
+     * when its file name, without directory and extension, equals the name,
+     * e.g. "/your/path1/max.tflite" is found as "max".
+     */
+    uint32_t n_graphs =
+        wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx);
+    size_t wanted_len = strlen(nul_terminated_name);
+    char *model_path = NULL;
+    uint32_t model_idx;
+
+    for (model_idx = 0; model_idx < n_graphs; model_idx++) {
+        char *path = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
+            wasi_nn_global_ctx, model_idx);
+        const char *slash, *file_name, *dot;
+
+        if (!path)
+            continue;
+        slash = strrchr(path, '/');
+        file_name = slash ? slash + 1 : path;
+        dot = strrchr(file_name, '.');
+        /* compare in place: no temporary copy to allocate or leak */
+        if (dot && (size_t)(dot - file_name) == wanted_len
+            && !strncmp(file_name, nul_terminated_name, wanted_len)) {
+            model_path = path;
+            break;
+        }
+    }
+
+    if (!model_path) {
+        NN_ERR_PRINTF("Model %s is not registered", nul_terminated_name);
+        res = not_found;
+        goto fail;
+    }
+
+    /* A non-zero marker is the graph handle + 1 recorded at first load */
+    uint32_t loaded_marker = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
+        wasi_nn_global_ctx, model_idx);
+    if (loaded_marker != 0) {
+        NN_DBG_PRINTF("Model is already loaded");
+        *g = (graph)(loaded_marker - 1);
+        res = success;
+        goto fail;
+    }
+
+    if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) {
+        // No enlarge for now
+        NN_ERR_PRINTF("No enough space for new model");
+        res = too_large;
+        goto fail;
+    }
+
+    NN_DBG_PRINTF("Model is not yet loaded, will add to global context");
+    call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
+                      wasi_nn_ctx->backend_ctx, model_path,
+                      (uint32_t)strlen(model_path), encoding, target, g);
+    if (res != success)
+        goto fail;
+
+    wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx,
+                                                 (uint32_t)*g + 1);
+    res = success;
 fail:
     if (nul_terminated_name != NULL) {
         wasm_runtime_free(nul_terminated_name);
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h
index 8cd03f1214..3108f2eef0 100644
--- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h
+++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h
@@ -17,7 +17,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding,
      execution_target target, graph *g);
 
 __attribute__((visibility("default"))) wasi_nn_error
-load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, graph *g);
+load_by_name(void *tflite_ctx, const char *name, uint32_t namelen,
+             graph_encoding encoding, execution_target target, graph *g);
 
 __attribute__((visibility("default"))) wasi_nn_error
 load_by_name_with_config(void *ctx, const char *name, uint32_t namelen,
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c
index 2e1e649365..fd09c2be08 100644
--- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c
+++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c
@@ -338,7 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
 }
 
 __attribute__((visibility("default"))) wasi_nn_error
-load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g)
+load_by_name(void *ctx, const char *filename, uint32_t filename_len,
+             graph_encoding encoding, execution_target target, graph *g)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp
index 88587f68bc..e2283df0f3 100644
---
a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -334,7 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) +load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { if (!onnx_ctx) { return runtime_error; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c index 899e06ee39..eec4f8190b 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c @@ -306,7 +306,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g) +load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx; struct OpenVINOGraph *graph; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 1bff2c514d..5dcb173f42 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -21,7 +21,8 @@ typedef struct { typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, execution_target, graph *); -typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph *); +typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding, + execution_target, graph *); typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *, uint32_t, void *, uint32_t, graph *); diff --git 
a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 9ac54e6644..eb56a42f23 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, - graph *g) +load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target,graph *g) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; @@ -183,7 +183,7 @@ load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, } // Use CPU as default - tfl_ctx->models[*g].target = cpu; + tfl_ctx->models[*g].target = target; return success; } diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 1643b91b00..0c80fd6b12 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.12.1 +tensorflow==2.14.0 numpy==1.24.4 diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index 6a9e20702f..b3d6ba8037 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -13,16 +13,16 @@ #include "logger.h" void -test_sum(execution_target target) +test_sum() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/sum.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "sum", 1); - assert(output_size == 1); + assert((output_size / 
sizeof(float)) == 1); assert(fabs(output[0] - 300.0) < EPSILON); free(input.dim); @@ -31,16 +31,16 @@ test_sum(execution_target target) } void -test_max(execution_target target) +test_max() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/max.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "max", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 24.0) < EPSILON); NN_INFO_PRINTF("Result: max is %f", output[0]); @@ -50,16 +50,16 @@ test_max(execution_target target) } void -test_average(execution_target target) +test_average() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/average.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "average", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 12.0) < EPSILON); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -69,16 +69,16 @@ test_average(execution_target target) } void -test_mult_dimensions(execution_target target) +test_mult_dimensions() { int dims[] = { 1, 3, 3, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/mult_dim.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "mult_dim", 1); - assert(output_size == 9); + assert((output_size / sizeof(float)) == 9); for (int i = 0; i < 9; i++) assert(fabs(output[i] - i) < EPSILON); @@ -88,16 +88,16 @@ test_mult_dimensions(execution_target target) } void -test_mult_outputs(execution_target target) +test_mult_outputs() { int dims[] = { 1, 4, 4, 1 
}; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/mult_out.tflite", 2); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "mult_out", 2); - assert(output_size == 8); + assert((output_size / sizeof(float)) == 8); // first tensor check for (int i = 0; i < 4; i++) assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); @@ -113,30 +113,18 @@ test_mult_outputs(execution_target target) int main() { - char *env = getenv("TARGET"); - if (env == NULL) { - NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\""); - return 1; - } - execution_target target; - if (strcmp(env, "cpu") == 0) - target = cpu; - else if (strcmp(env, "gpu") == 0) - target = gpu; - else { - NN_ERR_PRINTF("Wrong target!"); - return 1; - } + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("################### Testing sum..."); - test_sum(target); + test_sum(); NN_INFO_PRINTF("################### Testing max..."); - test_max(target); + test_max(); NN_INFO_PRINTF("################### Testing average..."); - test_average(target); + test_average(); NN_INFO_PRINTF("################### Testing multiple dimensions..."); - test_mult_dimensions(target); + test_mult_dimensions(); NN_INFO_PRINTF("################### Testing multiple outputs..."); - test_mult_outputs(target); + test_mult_outputs(); NN_INFO_PRINTF("Tests: passed!"); return 0; diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index 3ed7c751e3..0898c7ae2a 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -16,15 +16,15 @@ #define EPSILON 1e-2 void -test_average_quantized(execution_target target) 
+test_average_quantized() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; float *output = - run_inference(target, input.input_tensor, input.dim, &output_size, - "./models/quantized_model.tflite", 1); + run_inference(input.input_tensor, input.dim, &output_size, + "quantized_model", 1); NN_INFO_PRINTF("Output size: %d", output_size); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -39,24 +39,10 @@ test_average_quantized(execution_target target) int main() { - char *env = getenv("TARGET"); - if (env == NULL) { - NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\""); - return 1; - } - execution_target target; - if (strcmp(env, "cpu") == 0) - target = cpu; - else if (strcmp(env, "gpu") == 0) - target = gpu; - else if (strcmp(env, "tpu") == 0) - target = tpu; - else { - NN_ERR_PRINTF("Wrong target!"); - return 1; - } + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("################### Testing quantized model..."); - test_average_quantized(target); + test_average_quantized(); NN_INFO_PRINTF("Tests: passed!"); return 0; diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 690c37f0e7..97ed08378e 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -5,17 +5,15 @@ #include "utils.h" #include "logger.h" -#include "wasi_nn.h" - #include #include -wasi_nn_error -wasm_load(char *model_name, graph *g, execution_target target) +WASI_NN_ERROR_TYPE +wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target) { FILE *pFile = fopen(model_name, "r"); if (pFile == NULL) - return invalid_argument; + return WASI_NN_ERROR_NAME(invalid_argument); uint8_t *buffer; size_t result; @@ -24,20 +22,29 @@ wasm_load(char *model_name, graph *g, execution_target target) buffer = 
(uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); if (buffer == NULL) { fclose(pFile); - return too_large; + return WASI_NN_ERROR_NAME(too_large); } result = fread(buffer, 1, MAX_MODEL_SIZE, pFile); if (result <= 0) { fclose(pFile); free(buffer); - return too_large; + return WASI_NN_ERROR_NAME(too_large); } - graph_builder_array arr; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + WASI_NN_NAME(graph_builder) arr; + + arr.buf = buffer; + arr.size = result; + + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); + // WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); +#else + WASI_NN_NAME(graph_builder_array) arr; arr.size = 1; - arr.buf = (graph_builder *)malloc(sizeof(graph_builder)); + arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder))); if (arr.buf == NULL) { fclose(pFile); free(buffer); @@ -47,7 +54,8 @@ wasm_load(char *model_name, graph *g, execution_target target) arr.buf[0].size = result; arr.buf[0].buf = buffer; - wasi_nn_error res = load(&arr, tensorflowlite, target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); +#endif fclose(pFile); free(buffer); @@ -55,77 +63,97 @@ wasm_load(char *model_name, graph *g, execution_target target) return res; } -wasi_nn_error -wasm_load_by_name(const char *model_name, graph *g) +WASI_NN_ERROR_TYPE +wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g) { - wasi_nn_error res = load_by_name(model_name, strlen(model_name), g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); return res; } -wasi_nn_error -wasm_init_execution_context(graph g, graph_execution_context *ctx) +WASI_NN_ERROR_TYPE +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx) { - return init_execution_context(g, ctx); + return WASI_NN_NAME(init_execution_context)(g, ctx); 
} -wasi_nn_error -wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim) +WASI_NN_ERROR_TYPE +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim) { - tensor_dimensions dims; + WASI_NN_NAME(tensor_dimensions) dims; dims.size = INPUT_TENSOR_DIMS; dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t)); if (dims.buf == NULL) - return too_large; - - tensor tensor; + return WASI_NN_ERROR_NAME(too_large); + + WASI_NN_NAME(tensor) tensor; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + tensor.dimensions = dims; + for (int i = 0; i < tensor.dimensions.size; ++i) + tensor.dimensions.buf[i] = dim[i]; + tensor.type = WASI_NN_TYPE_NAME(fp32); + tensor.data.buf = (uint8_t *)input_tensor; + + uint32_t tmp_size = 1; + if (dim) + for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) + tmp_size *= dim[i]; + + tensor.data.size = (tmp_size * sizeof(float)); +#else tensor.dimensions = &dims; for (int i = 0; i < tensor.dimensions->size; ++i) tensor.dimensions->buf[i] = dim[i]; - tensor.type = fp32; + tensor.type = WASI_NN_TYPE_NAME(fp32); tensor.data = (uint8_t *)input_tensor; - wasi_nn_error err = set_input(ctx, 0, &tensor); +#endif + + WASI_NN_ERROR_TYPE err = WASI_NN_NAME(set_input)(ctx, 0, &tensor); free(dims.buf); return err; } -wasi_nn_error -wasm_compute(graph_execution_context ctx) +WASI_NN_ERROR_TYPE +wasm_compute(WASI_NN_NAME(graph_execution_context) ctx) { - return compute(ctx); + return WASI_NN_NAME(compute)(ctx); } -wasi_nn_error -wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, +WASI_NN_ERROR_TYPE +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size) { - return get_output(ctx, index, (uint8_t *)out_tensor, out_size); +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size); +#else + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t 
*)out_tensor, out_size); +#endif } float * -run_inference(execution_target target, float *input, uint32_t *input_size, +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors) { - graph graph; + WASI_NN_NAME(graph) graph; - if (wasm_load_by_name(model_name, &graph) != success) { + if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when loading model."); exit(1); } - graph_execution_context ctx; - if (wasm_init_execution_context(graph, &ctx) != success) { + WASI_NN_NAME(graph_execution_context) ctx; + if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when initialixing execution context."); exit(1); } - if (wasm_set_input(ctx, input, input_size) != success) { + if (wasm_set_input(ctx, input, input_size) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when setting input tensor."); exit(1); } - if (wasm_compute(ctx) != success) { + if (wasm_compute(ctx) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when running inference."); exit(1); } @@ -140,7 +168,7 @@ run_inference(execution_target target, float *input, uint32_t *input_size, for (int i = 0; i < num_output_tensors; ++i) { *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size; if (wasm_get_output(ctx, i, &out_tensor[offset], output_size) - != success) { + != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when getting index %d.", i); break; } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index e0d2417724..45ba156a0f 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -8,6 +8,7 @@ #include +#include "wasi_ephemeral_nn.h" #include "wasi_nn_types.h" #define MAX_MODEL_SIZE 85000000 @@ -23,26 +24,26 @@ typedef struct { /* wasi-nn wrappers */ -wasi_nn_error -wasm_load(char *model_name, graph *g, execution_target target); +WASI_NN_ERROR_TYPE 
+wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); -wasi_nn_error -wasm_init_execution_context(graph g, graph_execution_context *ctx); +WASI_NN_ERROR_TYPE +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); -wasi_nn_error -wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim); +WASI_NN_ERROR_TYPE +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); -wasi_nn_error -wasm_compute(graph_execution_context ctx); +WASI_NN_ERROR_TYPE +wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); -wasi_nn_error -wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, +WASI_NN_ERROR_TYPE +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(execution_target target, float *input, uint32_t *input_size, +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors); diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 2d7d3afeb8..ef99f2a842 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -18,6 +18,10 @@ #include "../common/libc_wasi.c" #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#include "wasi_ephemeral_nn.h" +#endif + #include "../common/wasm_proposal.c" #if BH_HAS_DLFCN @@ -115,6 +119,12 @@ print_help(void) #endif #if WASM_ENABLE_STATIC_PGO != 0 printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); +#endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + printf(" --wasi-nn-graph=encoding:target:::...:\n"); + printf(" Set encoding, target and model_paths for wasi-nn. 
target can be\n"); + printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); + printf(" tensorflow|pytorch|ggml|autodetect\n"); #endif printf(" --version Show version information\n"); return 1; @@ -635,6 +645,13 @@ main(int argc, char *argv[]) int timeout_ms = -1; #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + struct wasi_nn_graph_registry *nn_registry; + char *encoding, *target; + uint32_t n_models = 0; + char **model_paths; +#endif + #if WASM_ENABLE_LIBC_WASI != 0 memset(&wasi_parse_ctx, 0, sizeof(wasi_parse_ctx)); #endif @@ -825,6 +842,37 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { + char *token; + char *saveptr = NULL; + int token_count = 0; + char *tokens[12] = {0}; + + // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu + // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... + token = strtok_r(argv[0] + 16, ":", &saveptr); + while (token) { + tokens[token_count] = token; + token_count++; + token = strtok_r(NULL, ":", &saveptr); + } + + if (token_count < 2) { + return print_help(); + } + + n_models = token_count - 2; + encoding = strdup(tokens[0]); + target = strdup(tokens[1]); + model_paths = malloc(n_models * sizeof(void*)); + for (int i = 0; i < n_models; i++) { + model_paths[i] = strdup(tokens[i + 2]); + } + if (token) + free(token); + } +#endif else { #if WASM_ENABLE_LIBC_WASI != 0 libc_wasi_parse_result_t result = @@ -974,6 +1022,11 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_create(&nn_registry); + wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); + wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry); +#endif /* instantiate the module */ wasm_module_inst = 
wasm_runtime_instantiate_ex2( wasm_module, inst_args, error_buf, sizeof(error_buf)); @@ -1092,6 +1145,15 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: +#endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_destroy(nn_registry); + for (uint32_t i = 0; i < n_models; i++) + if (model_paths[i]) + free(model_paths[i]); + free(model_paths); + free(encoding); + free(target); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst); From 60a80118992cdad2ab9e1548cf0b8589acbf189f Mon Sep 17 00:00:00 2001 From: QiuYuan Han Date: Wed, 10 Dec 2025 17:19:49 +0800 Subject: [PATCH 02/28] Add a new error check for wasi_nn_load_by_name --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 9e3e741b69..1afb07df07 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -622,9 +622,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; char *global_model_path_i; + uint32_t global_n_graphs = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); // Assume filename got from user wasm app : max; sum; average; ... // Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ...... 
- for (model_idx = 0; model_idx < wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); model_idx++) + for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { // Extract filename from file path global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx); @@ -654,7 +655,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, } } - if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)) + if (!is_loaded && \ + (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) && \ + (model_idx < global_n_graphs)) { NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, @@ -679,6 +682,12 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, NN_ERR_PRINTF("No enough space for new model"); res = too_large; } + else if (model_idx >= global_n_graphs) + { + NN_ERR_PRINTF("Cannot find model %s, you should pass its path through --wasi-nn-graph", + nul_terminated_name); + res = not_found; + } goto fail; } fail: From 6dc9d01d5f8b84f80726caf32e0adaa187a6130a Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 11 Dec 2025 10:11:21 +0800 Subject: [PATCH 03/28] Use clang-format-18 to format source files --- core/iwasm/common/wasm_native.c | 8 +- core/iwasm/common/wasm_runtime_common.c | 73 ++++++++-------- core/iwasm/common/wasm_runtime_common.h | 49 ++++++----- core/iwasm/include/wasm_export.h | 36 ++++---- core/iwasm/interpreter/wasm_runtime.c | 7 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 83 ++++++++++--------- .../libraries/wasi-nn/src/wasi_nn_backend.h | 2 +- .../libraries/wasi-nn/src/wasi_nn_llamacpp.c | 4 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 4 +- .../libraries/wasi-nn/src/wasi_nn_openvino.c | 2 +- .../libraries/wasi-nn/src/wasi_nn_private.h | 5 +- .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 4 +- .../libraries/wasi-nn/test/test_tensorflow.c | 24 +++--- 
.../wasi-nn/test/test_tensorflow_quantized.c | 9 +- core/iwasm/libraries/wasi-nn/test/utils.c | 44 ++++++---- core/iwasm/libraries/wasi-nn/test/utils.h | 18 ++-- product-mini/platforms/posix/main.c | 16 ++-- 17 files changed, 215 insertions(+), 173 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 8938524db8..7781843914 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -486,9 +486,10 @@ wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) void wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, - WASINNGlobalContext *wasi_nn_ctx) + WASINNGlobalContext *wasi_nn_ctx) { - wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx); + wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, + wasi_nn_ctx); } static void @@ -611,7 +612,8 @@ wasm_native_init() #endif /* WASM_ENABLE_LIB_RATS */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor); + g_wasi_nn_context_key = + wasm_native_create_context_key(wasi_nn_context_dtor); if (g_wasi_nn_context_key == NULL) { goto fail; } diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 312c4b9c7b..685c7de045 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1700,25 +1700,25 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) struct wasi_nn_graph_registry; void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args) +wasm_runtime_wasi_nn_graph_registry_args_set_defaults( + struct wasi_nn_graph_registry *args) { memset(args, 0, sizeof(*args)); } bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, - const char* target, uint32_t n_graphs, - const char** graph_paths) 
+wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, + const char *encoding, const char *target, + uint32_t n_graphs, const char **graph_paths) { - if (!registry || !encoding || !target || !graph_paths) - { + if (!registry || !encoding || !target || !graph_paths) { return false; } registry->encoding = strdup(encoding); registry->target = strdup(target); registry->n_graphs = n_graphs; - registry->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); - memset(registry->graph_paths, 0, sizeof(uint32_t*) * n_graphs); + registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); for (uint32_t i = 0; i < registry->n_graphs; i++) registry->graph_paths[i] = strdup(graph_paths[i]); @@ -1740,12 +1740,11 @@ wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) void wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry) { - if (registry) - { + if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) - if (registry->graph_paths[i]) - { - // wasi_nn_graph_registry_unregister_graph(registry, registry->name[i]); + if (registry->graph_paths[i]) { + // wasi_nn_graph_registry_unregister_graph(registry, + // registry->name[i]); free(registry->graph_paths[i]); } if (registry->encoding) @@ -8153,9 +8152,10 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char* encoding, const char* target, - const uint32_t n_graphs, char* graph_paths[], - char *error_buf, uint32_t error_buf_size) + const char *encoding, const char *target, + const uint32_t n_graphs, + char *graph_paths[], char *error_buf, + uint32_t error_buf_size) { WASINNGlobalContext *ctx; bool ret = false; @@ -8163,17 +8163,16 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, ctx = runtime_malloc(sizeof(*ctx), 
module_inst, error_buf, error_buf_size); if (!ctx) return false; - + ctx->encoding = strdup(encoding); ctx->target = strdup(target); ctx->n_graphs = n_graphs; - ctx->loaded = (uint32_t*)malloc(sizeof(uint32_t) * n_graphs); + ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); - - ctx->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); - memset(ctx->graph_paths, 0, sizeof(uint32_t*) * n_graphs); - for (uint32_t i = 0; i < n_graphs; i++) - { + + ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs); + for (uint32_t i = 0; i < n_graphs; i++) { ctx->graph_paths[i] = strdup(graph_paths[i]); } @@ -8187,10 +8186,10 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, void wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) { - WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst); + WASINNGlobalContext *wasi_nn_global_ctx = + wasm_runtime_get_wasi_nn_global_ctx(module_inst); - for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) - { + for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) { // All graphs will be unregistered in deinit() if (wasi_nn_global_ctx->graph_paths[i]) free(wasi_nn_global_ctx->graph_paths[i]); @@ -8206,7 +8205,8 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) } uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx) +wasm_runtime_get_wasi_nn_global_ctx_ngraphs( + WASINNGlobalContext *wasi_nn_global_ctx) { if (wasi_nn_global_ctx) return wasi_nn_global_ctx->n_graphs; @@ -8215,7 +8215,8 @@ wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ } char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) 
+wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) { if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) return wasi_nn_global_ctx->graph_paths[idx]; @@ -8224,7 +8225,8 @@ wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_g } uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) { if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) return wasi_nn_global_ctx->loaded[idx]; @@ -8233,7 +8235,8 @@ wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global } uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) +wasm_runtime_set_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) { if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) wasi_nn_global_ctx->loaded[idx] = value; @@ -8241,8 +8244,9 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global return 0; } -char* -wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx) +char * +wasm_runtime_get_wasi_nn_global_ctx_encoding( + WASINNGlobalContext *wasi_nn_global_ctx) { if (wasi_nn_global_ctx) return wasi_nn_global_ctx->encoding; @@ -8250,8 +8254,9 @@ wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global return NULL; } -char* -wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx) +char * +wasm_runtime_get_wasi_nn_global_ctx_target( + WASINNGlobalContext *wasi_nn_global_ctx) { if (wasi_nn_global_ctx) return wasi_nn_global_ctx->target; diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 8d002bedcc..98ea3b68e3 100644 --- 
a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -547,12 +547,12 @@ typedef struct WASMModuleInstMemConsumption { #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNGlobalContext { - char* encoding; - char* target; + char *encoding; + char *target; uint32_t n_graphs; uint32_t *loaded; - char** graph_paths; + char **graph_paths; } WASINNGlobalContext; #endif @@ -625,10 +625,10 @@ wasm_runtime_get_exec_env_tls(void); #if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) struct wasi_nn_graph_registry { - char* encoding; - char* target; + char *encoding; + char *target; - char** graph_paths; + char **graph_paths; uint32_t n_graphs; }; @@ -811,9 +811,9 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, - const char* target, uint32_t n_graphs, - const char** graph_paths); +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, + const char *encoding, const char *target, + uint32_t n_graphs, const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1471,34 +1471,41 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char* encoding, const char* target, - const uint32_t n_graphs, char* graph_paths[], - char *error_buf, uint32_t error_buf_size); + const char *encoding, const char *target, + const uint32_t n_graphs, + char *graph_paths[], char *error_buf, + uint32_t error_buf_size); WASM_RUNTIME_API_EXTERN void wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - WASINNGlobalContext *wasi_ctx); + 
WASINNGlobalContext *wasi_ctx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_ngraphs( + WASINNGlobalContext *wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_encoding( + WASINNGlobalContext *wasi_nn_global_ctx); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_target( + WASINNGlobalContext *wasi_nn_global_ctx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 50263f1823..16a9ad54bc 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -291,7 +291,7 @@ typedef struct InstantiationArgs { struct InstantiationArgs2; struct WASINNGlobalContext; -typedef struct WASINNGlobalContext *wasi_nn_global_context; +typedef struct WASINNGlobalContext *wasi_nn_global_context; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ 
-809,43 +809,51 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( // struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); // WASM_RUNTIME_API_EXTERN bool -// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, +// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, +// const char* encoding, // const char* target, uint32_t n_graphs, // const char** graph_paths); WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, - const char* encoding, const char* target, - const uint32_t n_graphs, char* graph_paths[], - char *error_buf, uint32_t error_buf_size); + const char *encoding, const char *target, + const uint32_t n_graphs, + char *graph_paths[], char *error_buf, + uint32_t error_buf_size); WASM_RUNTIME_API_EXTERN void wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst, - wasi_nn_global_context wasi_ctx); + wasi_nn_global_context wasi_ctx); WASM_RUNTIME_API_EXTERN wasi_nn_global_context wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_ngraphs( + wasi_nn_global_context wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t 
-wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_encoding( + wasi_nn_global_context wasi_nn_global_ctx); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_target( + wasi_nn_global_context wasi_nn_global_ctx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index 79d4c73c2e..6c8f92975c 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3301,13 +3301,14 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, #endif #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - /* Store graphs' path into ctx. Graphs will be loaded until user app calls load_by_name */ + /* Store graphs' path into ctx. 
Graphs will be loaded until user app calls + * load_by_name */ // Do not consider load() for now struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; if (!wasm_runtime_init_wasi_nn_global_ctx( (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, - nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths, - error_buf, error_buf_size)) { + nn_registry->target, nn_registry->n_graphs, + nn_registry->graph_paths, error_buf, error_buf_size)) { goto fail; } #endif diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 1afb07df07..519c799454 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -35,7 +35,7 @@ #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION -#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instances +#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instance /* Global variables */ static korp_mutex wasi_nn_lock; @@ -210,7 +210,8 @@ wasi_nn_destroy() * - model file format * - on device ML framework */ -static graph_encoding str2encoding(char* str_encoding) +static graph_encoding +str2encoding(char *str_encoding) { if (!str_encoding) { NN_ERR_PRINTF("Got empty string encoding"); @@ -227,10 +228,11 @@ static graph_encoding str2encoding(char* str_encoding) return onnx; else return unknown_backend; - // return autodetect; + // return autodetect; } -static execution_target str2target(char* str_target) +static execution_target +str2target(char *str_target) { if (!str_target) { NN_ERR_PRINTF("Got empty string target"); @@ -245,7 +247,7 @@ static execution_target str2target(char* str_target) return tpu; else return unsupported_target; - // return autodetect; + // return autodetect; } static graph_encoding @@ -605,87 +607,88 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; 
} - wasi_nn_global_context wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance); + wasi_nn_global_context wasi_nn_global_ctx = + wasm_runtime_get_wasi_nn_global_ctx(instance); if (!wasi_nn_global_ctx) { NN_ERR_PRINTF("global context is invalid"); res = not_found; goto fail; } - graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); - execution_target target = str2target(wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); + graph_encoding encoding = str2encoding( + wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); + execution_target target = str2target( + wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; - + bool is_loaded = false; uint32 model_idx = 0; char *global_model_path_i; - uint32_t global_n_graphs = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); + uint32_t global_n_graphs = + wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); // Assume filename got from user wasm app : max; sum; average; ... - // Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ...... - for (model_idx = 0; model_idx < global_n_graphs; model_idx++) - { + // Assume file path got from user cmd opt: /your/path1/max.tflite; + // /your/path2/sum.tflite; ...... 
+ for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { // Extract filename from file path - global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx); + global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + wasi_nn_global_ctx, model_idx); char *model_file_name; - const char *slash = strrchr(global_model_path_i, '/'); + const char *slash = strrchr(global_model_path_i, '/'); if (slash != NULL) { - model_file_name = (char*)(slash + 1); + model_file_name = (char *)(slash + 1); } else model_file_name = global_model_path_i; // Extract modelname from filename - char* model_name = NULL; + char *model_name = NULL; size_t model_name_len = 0; - char* dot = strrchr(model_file_name, '.'); - if (dot) - { + char *dot = strrchr(model_file_name, '.'); + if (dot) { model_name_len = dot - model_file_name; model_name = malloc(model_name_len + 1); strncpy(model_name, model_file_name, model_name_len); model_name[model_name_len] = '\0'; } - + if (model_name && strcmp(nul_terminated_name, model_name) == 0) { - is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx); + is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_ctx, model_idx); break; } } - if (!is_loaded && \ - (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) && \ - (model_idx < global_n_graphs)) - { + if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) + && (model_idx < global_n_graphs)) { NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, global_model_path_i, strlen(global_model_path_i), - encoding, target, g); + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), encoding, target, g); if (res != success) goto fail; - - wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx, 1); + + wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, + 
model_idx, 1); res = success; } - else - { - if (is_loaded) - { + else { + if (is_loaded) { NN_DBG_PRINTF("Model is already loaded"); res = success; } - else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) - { + else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) { // No enlarge for now NN_ERR_PRINTF("No enough space for new model"); res = too_large; } - else if (model_idx >= global_n_graphs) - { - NN_ERR_PRINTF("Cannot find model %s, you should pass its path through --wasi-nn-graph", - nul_terminated_name); + else if (model_idx >= global_n_graphs) { + NN_ERR_PRINTF("Cannot find model %s, you should pass its path " + "through --wasi-nn-graph", + nul_terminated_name); res = not_found; } goto fail; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 3108f2eef0..344e66550b 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -18,7 +18,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, - graph_encoding encoding, execution_target target, graph *g); + graph_encoding encoding, execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c index fd09c2be08..7042affa70 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c @@ -338,8 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) 
+load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index e2283df0f3..947fa558e3 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -334,8 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) +load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { if (!onnx_ctx) { return runtime_error; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c index eec4f8190b..4739953605 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c @@ -307,7 +307,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) + graph_encoding encoding, execution_target target, graph *g) { OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx; struct OpenVINOGraph *graph; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 5dcb173f42..7ea76eddb1 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -21,8 +21,9 @@ typedef struct { typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, 
execution_target, graph *); -typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding, - execution_target, graph *); +typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, + graph_encoding, execution_target, + graph *); typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *, uint32_t, void *, uint32_t, graph *); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index eb56a42f23..2b4832dc41 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target,graph *g) +load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index b3d6ba8037..a34c3be4bc 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -19,8 +19,8 @@ test_sum() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "sum", 1); + float *output = + run_inference(input.input_tensor, input.dim, &output_size, "sum", 1); assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 300.0) < EPSILON); @@ -37,8 +37,8 @@ test_max() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "max", 1); + float *output = + 
run_inference(input.input_tensor, input.dim, &output_size, "max", 1); assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 24.0) < EPSILON); @@ -56,8 +56,8 @@ test_average() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "average", 1); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "average", 1); assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 12.0) < EPSILON); @@ -75,8 +75,8 @@ test_mult_dimensions() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "mult_dim", 1); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "mult_dim", 1); assert((output_size / sizeof(float)) == 9); for (int i = 0; i < 9; i++) @@ -94,8 +94,8 @@ test_mult_outputs() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "mult_out", 2); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "mult_out", 2); assert((output_size / sizeof(float)) == 8); // first tensor check @@ -113,7 +113,9 @@ test_mult_outputs() int main() { - NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " + "--wasi-nn-graph=encoding:target:model_path1:model_path2:..." 
+ ":model_pathN test_tensorflow.wasm\""); NN_INFO_PRINTF("################### Testing sum..."); test_sum(); diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index 0898c7ae2a..a55dc9bb89 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -22,9 +22,8 @@ test_average_quantized() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = - run_inference(input.input_tensor, input.dim, &output_size, - "quantized_model", 1); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "quantized_model", 1); NN_INFO_PRINTF("Output size: %d", output_size); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -39,7 +38,9 @@ test_average_quantized() int main() { - NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " + "--wasi-nn-graph=encoding:target:model_path1:model_path2:..." 
+ ":model_pathN test_tensorflow.wasm\""); NN_INFO_PRINTF("################### Testing quantized model..."); test_average_quantized(); diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 97ed08378e..0a99a95e50 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -9,7 +9,8 @@ #include WASI_NN_ERROR_TYPE -wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target) +wasm_load(char *model_name, WASI_NN_NAME(graph) * g, + WASI_NN_NAME(execution_target) target) { FILE *pFile = fopen(model_name, "r"); if (pFile == NULL) @@ -38,13 +39,14 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_targe arr.buf = buffer; arr.size = result; - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); - // WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)( + &arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); #else WASI_NN_NAME(graph_builder_array) arr; arr.size = 1; - arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder))); + arr.buf = (WASI_NN_NAME(graph_builder) *)malloc( + sizeof(WASI_NN_NAME(graph_builder))); if (arr.buf == NULL) { fclose(pFile); free(buffer); @@ -54,7 +56,8 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_targe arr.buf[0].size = result; arr.buf[0].buf = buffer; - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)( + &arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); #endif fclose(pFile); @@ -64,20 +67,23 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_targe } WASI_NN_ERROR_TYPE -wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g) 
+wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) * g) { - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); + WASI_NN_ERROR_TYPE res = + WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); return res; } WASI_NN_ERROR_TYPE -wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx) +wasm_init_execution_context(WASI_NN_NAME(graph) g, + WASI_NN_NAME(graph_execution_context) * ctx) { return WASI_NN_NAME(init_execution_context)(g, ctx); } WASI_NN_ERROR_TYPE -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim) +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, + uint32_t *dim) { WASI_NN_NAME(tensor_dimensions) dims; dims.size = INPUT_TENSOR_DIMS; @@ -103,7 +109,7 @@ wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, u tensor.dimensions = &dims; for (int i = 0; i < tensor.dimensions->size; ++i) tensor.dimensions->buf[i] = dim[i]; - tensor.type = WASI_NN_TYPE_NAME(fp32); + tensor.type = WASI_NN_TYPE_NAME(fp32); tensor.data = (uint8_t *)input_tensor; #endif @@ -120,20 +126,21 @@ wasm_compute(WASI_NN_NAME(graph_execution_context) ctx) } WASI_NN_ERROR_TYPE -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, - uint32_t *out_size) +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, + float *out_tensor, uint32_t *out_size) { #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size); + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, + MAX_OUTPUT_TENSOR_SIZE, out_size); #else - return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, out_size); + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, + out_size); #endif } float * -run_inference(float *input, uint32_t *input_size, - uint32_t 
*output_size, char *model_name, - uint32_t num_output_tensors) +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, + char *model_name, uint32_t num_output_tensors) { WASI_NN_NAME(graph) graph; @@ -143,7 +150,8 @@ run_inference(float *input, uint32_t *input_size, } WASI_NN_NAME(graph_execution_context) ctx; - if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) { + if (wasm_init_execution_context(graph, &ctx) + != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when initialixing execution context."); exit(1); } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index 45ba156a0f..5a5c03c3d7 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -25,27 +25,29 @@ typedef struct { /* wasi-nn wrappers */ WASI_NN_ERROR_TYPE -wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); +wasm_load(char *model_name, WASI_NN_NAME(graph) * g, + WASI_NN_NAME(execution_target) target); WASI_NN_ERROR_TYPE -wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); +wasm_init_execution_context(WASI_NN_NAME(graph) g, + WASI_NN_NAME(graph_execution_context) * ctx); WASI_NN_ERROR_TYPE -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, + uint32_t *dim); WASI_NN_ERROR_TYPE wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); WASI_NN_ERROR_TYPE -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, - uint32_t *out_size); +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, + float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(float *input, uint32_t *input_size, - uint32_t *output_size, char *model_name, - uint32_t num_output_tensors); +run_inference(float *input, 
uint32_t *input_size, uint32_t *output_size, + char *model_name, uint32_t num_output_tensors); input_info create_input(int *dims); diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index ef99f2a842..7226f5b507 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -847,9 +847,9 @@ main(int argc, char *argv[]) char *token; char *saveptr = NULL; int token_count = 0; - char *tokens[12] = {0}; + char *tokens[12] = { 0 }; - // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu + // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... token = strtok_r(argv[0] + 16, ":", &saveptr); while (token) { @@ -865,11 +865,11 @@ main(int argc, char *argv[]) n_models = token_count - 2; encoding = strdup(tokens[0]); target = strdup(tokens[1]); - model_paths = malloc(n_models * sizeof(void*)); + model_paths = malloc(n_models * sizeof(void *)); for (int i = 0; i < n_models; i++) { model_paths[i] = strdup(tokens[i + 2]); } - if (token) + if (token) free(token); } #endif @@ -1024,8 +1024,10 @@ main(int argc, char *argv[]) #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_create(&nn_registry); - wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); - wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry); + wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, + model_paths); + wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, + nn_registry); #endif /* instantiate the module */ wasm_module_inst = wasm_runtime_instantiate_ex2( @@ -1149,7 +1151,7 @@ main(int argc, char *argv[]) #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_destroy(nn_registry); for (uint32_t i = 0; i < n_models; i++) - if (model_paths[i]) + if (model_paths[i]) free(model_paths[i]); free(model_paths); free(encoding); From 
d42581bb0b92e11b0ea2e25abccf17fac0c99476 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 11 Dec 2025 17:31:46 +0800 Subject: [PATCH 04/28] Free model_name --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 519c799454..7b39ca541c 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -658,8 +658,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, if (model_name && strcmp(nul_terminated_name, model_name) == 0) { is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( wasi_nn_global_ctx, model_idx); + free(model_name); break; } + free(model_name); } if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) From 1cca0b8789479f8c2749ad4faf94fa8f7aa98f33 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Tue, 16 Dec 2025 10:07:36 +0800 Subject: [PATCH 05/28] Add new errno for new test cases --- .../libraries/wasi-nn/include/wasi_nn_types.h | 1 + core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 4 +- .../libraries/wasi-nn/test/test_tensorflow.c | 46 +++++++++++-------- core/iwasm/libraries/wasi-nn/test/utils.c | 10 +++- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index d77fe9a6cb..aea6554b8d 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -48,6 +48,7 @@ typedef enum { WASI_NN_ERROR_NAME(unsupported_operation), WASI_NN_ERROR_NAME(too_large), WASI_NN_ERROR_NAME(not_found), + WASI_NN_ERROR_NAME(not_loaded), // for WasmEdge-wasi-nn WASI_NN_ERROR_NAME(end_of_sequence) = 100, // End of Sequence Found. 
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 7b39ca541c..8bb7e57cd6 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -688,10 +688,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, res = too_large; } else if (model_idx >= global_n_graphs) { - NN_ERR_PRINTF("Cannot find model %s, you should pass its path " + NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", nul_terminated_name); - res = not_found; + res = not_loaded; } goto fail; } diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index a34c3be4bc..d5147b2fbd 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -22,8 +22,10 @@ test_sum() float *output = run_inference(input.input_tensor, input.dim, &output_size, "sum", 1); - assert((output_size / sizeof(float)) == 1); - assert(fabs(output[0] - 300.0) < EPSILON); + if (output) { + assert((output_size / sizeof(float)) == 1); + assert(fabs(output[0] - 300.0) < EPSILON); + } free(input.dim); free(input.input_tensor); @@ -40,9 +42,11 @@ test_max() float *output = run_inference(input.input_tensor, input.dim, &output_size, "max", 1); - assert((output_size / sizeof(float)) == 1); - assert(fabs(output[0] - 24.0) < EPSILON); - NN_INFO_PRINTF("Result: max is %f", output[0]); + if (output) { + assert((output_size / sizeof(float)) == 1); + assert(fabs(output[0] - 24.0) < EPSILON); + NN_INFO_PRINTF("Result: max is %f", output[0]); + } free(input.dim); free(input.input_tensor); @@ -59,9 +63,11 @@ test_average() float *output = run_inference(input.input_tensor, input.dim, &output_size, "average", 1); - assert((output_size / sizeof(float)) == 1); - assert(fabs(output[0] - 12.0) < EPSILON); - NN_INFO_PRINTF("Result: average is %f", output[0]); + if 
(output) { + assert((output_size / sizeof(float)) == 1); + assert(fabs(output[0] - 12.0) < EPSILON); + NN_INFO_PRINTF("Result: average is %f", output[0]); + } free(input.dim); free(input.input_tensor); @@ -78,9 +84,11 @@ test_mult_dimensions() float *output = run_inference(input.input_tensor, input.dim, &output_size, "mult_dim", 1); - assert((output_size / sizeof(float)) == 9); - for (int i = 0; i < 9; i++) - assert(fabs(output[i] - i) < EPSILON); + if (output) { + assert((output_size / sizeof(float)) == 9); + for (int i = 0; i < 9; i++) + assert(fabs(output[i] - i) < EPSILON); + } free(input.dim); free(input.input_tensor); @@ -97,13 +105,15 @@ test_mult_outputs() float *output = run_inference(input.input_tensor, input.dim, &output_size, "mult_out", 2); - assert((output_size / sizeof(float)) == 8); - // first tensor check - for (int i = 0; i < 4; i++) - assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); - // second tensor check - for (int i = 0; i < 4; i++) - assert(fabs(output[i + 4] - (i + 6)) < EPSILON); + if (output) { + assert((output_size / sizeof(float)) == 8); + // first tensor check + for (int i = 0; i < 4; i++) + assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); + // second tensor check + for (int i = 0; i < 4; i++) + assert(fabs(output[i + 4] - (i + 6)) < EPSILON); + } free(input.dim); free(input.input_tensor); diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 0a99a95e50..f6f9bd0961 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -144,7 +144,15 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size, { WASI_NN_NAME(graph) graph; - if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) { + WASI_NN_ERROR_TYPE res = wasm_load_by_name(model_name, &graph); + + if (res == WASI_NN_ERROR_NAME(not_loaded)) { + NN_INFO_PRINTF("Model %s is not loaded, you should pass its path " + "through --wasi-nn-graph", + 
model_name); + return NULL; + } + else if (res != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when loading model."); exit(1); } From 9e039707c2b2ea9e040704e126ae25f5e2c845c6 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Wed, 17 Dec 2025 14:59:26 +0800 Subject: [PATCH 06/28] Fix bugs --- core/iwasm/common/wasm_runtime_common.c | 4 ++-- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 2 +- core/iwasm/libraries/wasi-nn/test/requirements.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 685c7de045..51ecf23f95 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1725,12 +1725,12 @@ wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, return true; } -int +static int wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) { struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args)); if (args == NULL) { - return false; + return -1; } wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args); *registryp = args; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 8bb7e57cd6..892ca5dd79 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -35,7 +35,7 @@ #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION -#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instance +#define MAX_GLOBAL_GRAPHS_PER_INST 4 /* Global variables */ static korp_mutex wasi_nn_lock; diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 0c80fd6b12..2145e3736a 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.14.0 
+tensorflow==2.12.0 numpy==1.24.4 From cd1c7f9134989ec7075e1e5b6261489a9c18deed Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 19 Dec 2025 14:47:38 +0800 Subject: [PATCH 07/28] Rename some parameters --- core/iwasm/common/wasm_native.c | 4 +-- core/iwasm/common/wasm_runtime_common.c | 27 +++++++++--------- core/iwasm/common/wasm_runtime_common.h | 28 +++++++++---------- core/iwasm/include/wasm_export.h | 16 ----------- core/iwasm/interpreter/wasm_runtime.c | 5 ++-- .../wasi-nn/include/wasi_ephemeral_nn.h | 3 -- core/iwasm/libraries/wasi-nn/test/build.sh | 2 ++ .../libraries/wasi-nn/test/test_tensorflow.c | 20 +++++++++++++ core/iwasm/libraries/wasi-nn/test/utils.h | 4 +++ product-mini/platforms/posix/main.c | 14 ++++++---- 10 files changed, 65 insertions(+), 58 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 7781843914..b8430520af 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -25,7 +25,7 @@ static NativeSymbolsList g_native_symbols_list = NULL; static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 static void *g_wasi_nn_context_key; #endif @@ -477,7 +477,7 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) } #endif /* end of WASM_ENABLE_LIBC_WASI */ -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASINNGlobalContext * wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) { diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 51ecf23f95..88c5b18e0b 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1696,20 +1696,19 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) wasm_runtime_free(p); } -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) 
-struct wasi_nn_graph_registry; +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASINNArguments; void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults( - struct wasi_nn_graph_registry *args) +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) { memset(args, 0, sizeof(*args)); } bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, - const char *encoding, const char *target, - uint32_t n_graphs, const char **graph_paths) +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, + const char *target, uint32_t n_graphs, + const char **graph_paths) { if (!registry || !encoding || !target || !graph_paths) { return false; @@ -1725,10 +1724,10 @@ wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, return true; } -static int -wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) +int +wasi_nn_graph_registry_create(WASINNArguments **registryp) { - struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args)); + WASINNArguments *args = wasm_runtime_malloc(sizeof(*args)); if (args == NULL) { return -1; } @@ -1738,7 +1737,7 @@ wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) } void -wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry) +wasi_nn_graph_registry_destroy(WASINNArguments *registry) { if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) @@ -1854,10 +1853,10 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( wasi_args->set_by_user = true; } #endif /* WASM_ENABLE_LIBC_WASI != 0 */ -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 void wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry) + struct InstantiationArgs2 *p, WASINNArguments *registry) { p->nn_registry = *registry; } @@ -8149,7 +8148,7 @@ 
wasm_runtime_check_and_update_last_used_shared_heap( } #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, const char *encoding, const char *target, diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 98ea3b68e3..f06cca5da6 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -545,7 +545,7 @@ typedef struct WASMModuleInstMemConsumption { uint32 exports_size; } WASMModuleInstMemConsumption; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNGlobalContext { char *encoding; char *target; @@ -623,20 +623,20 @@ WASMExecEnv * wasm_runtime_get_exec_env_tls(void); #endif -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) -struct wasi_nn_graph_registry { +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef struct WASINNArguments { char *encoding; char *target; char **graph_paths; uint32_t n_graphs; -}; +} WASINNArguments; WASM_RUNTIME_API_EXTERN int -wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); +wasi_nn_graph_registry_create(WASINNArguments **registryp); WASM_RUNTIME_API_EXTERN void -wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); +wasi_nn_graph_registry_destroy(WASINNArguments *registry); #endif struct InstantiationArgs2 { @@ -644,8 +644,8 @@ struct InstantiationArgs2 { #if WASM_ENABLE_LIBC_WASI != 0 WASIArguments wasi; #endif -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) - struct wasi_nn_graph_registry nn_registry; +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + WASINNArguments nn_registry; #endif }; @@ -805,15 +805,15 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32 ns_lookup_pool_size); -#if 
(WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN void wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + struct InstantiationArgs2 *p, WASINNArguments *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, - const char *encoding, const char *target, - uint32_t n_graphs, const char **graph_paths); +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, + const char *target, uint32_t n_graphs, + const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1468,7 +1468,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( uint8 **shared_heap_base_addr_adj_p, bool is_memory64); #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, const char *encoding, const char *target, diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 16a9ad54bc..17a15688cf 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -798,22 +798,6 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); -// WASM_RUNTIME_API_EXTERN int -// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); - -// WASM_RUNTIME_API_EXTERN void -// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); - -// WASM_RUNTIME_API_EXTERN void -// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( -// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); - -// WASM_RUNTIME_API_EXTERN bool -// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, -// const 
char* encoding, -// const char* target, uint32_t n_graphs, -// const char** graph_paths); - WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, const char *encoding, const char *target, diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index 6c8f92975c..4dc6bd2537 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,11 +3300,10 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 /* Store graphs' path into ctx. Graphs will be loaded until user app calls * load_by_name */ - // Do not consider load() for now - struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; + WASINNArguments *nn_registry = &args->nn_registry; if (!wasm_runtime_init_wasi_nn_global_ctx( (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, nn_registry->target, nn_registry->n_graphs, diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h index 83beba98f5..86afc42674 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h @@ -7,6 +7,3 @@ #define WASI_NN_NAME(name) wasi_ephemeral_nn_##name #include "wasi_nn.h" - -// #undef WASM_ENABLE_WASI_EPHEMERAL_NN -// #undef WASI_NN_NAME diff --git a/core/iwasm/libraries/wasi-nn/test/build.sh b/core/iwasm/libraries/wasi-nn/test/build.sh index 79d65d730c..5c99706d04 100755 --- a/core/iwasm/libraries/wasi-nn/test/build.sh +++ b/core/iwasm/libraries/wasi-nn/test/build.sh @@ -22,6 +22,8 @@ CURR_PATH=$(cd $(dirname $0) && pwd -P) /opt/wasi-sdk/bin/clang \ --target=wasm32-wasi \ + -DWASM_ENABLE_WASI_NN=1 \ + -DWASM_ENABLE_WASI_EPHEMERAL_NN=1 \ -DNN_LOG_LEVEL=1 \ -Wl,--allow-undefined \ -I../include -I../src/utils \ diff --git 
a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index d5147b2fbd..d276dd0ac8 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -23,7 +23,11 @@ test_sum() run_inference(input.input_tensor, input.dim, &output_size, "sum", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 1); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 1); +#endif assert(fabs(output[0] - 300.0) < EPSILON); } @@ -43,7 +47,11 @@ test_max() run_inference(input.input_tensor, input.dim, &output_size, "max", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 1); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 1); +#endif assert(fabs(output[0] - 24.0) < EPSILON); NN_INFO_PRINTF("Result: max is %f", output[0]); } @@ -64,7 +72,11 @@ test_average() "average", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 1); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 1); +#endif assert(fabs(output[0] - 12.0) < EPSILON); NN_INFO_PRINTF("Result: average is %f", output[0]); } @@ -85,7 +97,11 @@ test_mult_dimensions() "mult_dim", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 9); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 9); +#endif for (int i = 0; i < 9; i++) assert(fabs(output[i] - i) < EPSILON); } @@ -106,7 +122,11 @@ test_mult_outputs() "mult_out", 2); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 8); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 8); +#endif // first tensor check for (int i = 0; i < 4; i++) assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index 5a5c03c3d7..8d2683fff4 100644 --- 
a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -8,7 +8,11 @@ #include +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #include "wasi_ephemeral_nn.h" +#elif WASM_ENABLE_WASI_NN != 0 +#include "wasi_nn.h" +#endif #include "wasi_nn_types.h" #define MAX_MODEL_SIZE 85000000 diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 7226f5b507..7765bbef54 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -20,6 +20,8 @@ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #include "wasi_ephemeral_nn.h" +#elif WASM_ENABLE_WASI_NN != 0 +#include "wasi_nn.h" #endif #include "../common/wasm_proposal.c" @@ -120,7 +122,7 @@ print_help(void) #if WASM_ENABLE_STATIC_PGO != 0 printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 printf(" --wasi-nn-graph=encoding:target:::...:\n"); printf(" Set encoding, target and model_paths for wasi-nn. 
target can be\n"); printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); @@ -645,8 +647,8 @@ main(int argc, char *argv[]) int timeout_ms = -1; #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - struct wasi_nn_graph_registry *nn_registry; +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + struct WASINNArguments *nn_registry; char *encoding, *target; uint32_t n_models = 0; char **model_paths; @@ -842,7 +844,7 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { char *token; char *saveptr = NULL; @@ -1022,7 +1024,7 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_create(&nn_registry); wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); @@ -1148,7 +1150,7 @@ main(int argc, char *argv[]) #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_destroy(nn_registry); for (uint32_t i = 0; i < n_models; i++) if (model_paths[i]) From c73e4aa43259a8db0438852ce2756a58ad78a26c Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 19 Dec 2025 14:55:42 +0800 Subject: [PATCH 08/28] Revert tensorflow version --- core/iwasm/libraries/wasi-nn/test/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 2145e3736a..1643b91b00 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.12.0 
+tensorflow==2.12.1 numpy==1.24.4 From 0234fe08b9842b5e27123e4ad75a65d8ad6a4c42 Mon Sep 17 00:00:00 2001 From: qinzh Date: Tue, 6 Jan 2026 10:02:52 +0800 Subject: [PATCH 09/28] CICD: retrigger checks From 4b93110b560889d796d06e859debc04abcee2e4c Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 23 Jan 2026 13:54:01 +0800 Subject: [PATCH 10/28] Remove internal function declarations from wasm_export.h --- core/iwasm/common/wasm_runtime_common.c | 2 +- core/iwasm/common/wasm_runtime_common.h | 7 +++++ core/iwasm/include/wasm_export.h | 30 +++++-------------- .../libraries/wasi-nn/include/wasi_nn_types.h | 1 - core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 4 +-- core/iwasm/libraries/wasi-nn/test/utils.c | 2 +- 6 files changed, 19 insertions(+), 27 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 88c5b18e0b..bebbaf3bc4 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1697,7 +1697,7 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) } #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -WASINNArguments; +typedef struct WASINNArguments WASINNArguments; void wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index f06cca5da6..e15c813ae4 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -1483,6 +1483,13 @@ WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, WASINNGlobalContext *wasi_ctx); +WASM_RUNTIME_API_EXTERN WASINNGlobalContext * +wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, + WASINNGlobalContext *wasi_nn_ctx); + WASM_RUNTIME_API_EXTERN uint32_t 
wasm_runtime_get_wasi_nn_global_ctx_ngraphs( WASINNGlobalContext *wasi_nn_global_ctx); diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 17a15688cf..e7220f870b 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -291,7 +291,7 @@ typedef struct InstantiationArgs { struct InstantiationArgs2; struct WASINNGlobalContext; -typedef struct WASINNGlobalContext *wasi_nn_global_context; +typedef struct WASINNGlobalContext WASINNGlobalContext; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -798,46 +798,32 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); -WASM_RUNTIME_API_EXTERN bool -wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, - const char *encoding, const char *target, - const uint32_t n_graphs, - char *graph_paths[], char *error_buf, - uint32_t error_buf_size); - -WASM_RUNTIME_API_EXTERN void -wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst); - -WASM_RUNTIME_API_EXTERN void -wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst, - wasi_nn_global_context wasi_ctx); - -WASM_RUNTIME_API_EXTERN wasi_nn_global_context +WASM_RUNTIME_API_EXTERN WASINNGlobalContext * wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - wasi_nn_global_context wasi_nn_global_ctx); + WASINNGlobalContext * wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t 
wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); + WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_encoding( - wasi_nn_global_context wasi_nn_global_ctx); + WASINNGlobalContext * wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_target( - wasi_nn_global_context wasi_nn_global_ctx); + WASINNGlobalContext * wasi_nn_global_ctx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index aea6554b8d..d77fe9a6cb 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -48,7 +48,6 @@ typedef enum { WASI_NN_ERROR_NAME(unsupported_operation), WASI_NN_ERROR_NAME(too_large), WASI_NN_ERROR_NAME(not_found), - WASI_NN_ERROR_NAME(not_loaded), // for WasmEdge-wasi-nn WASI_NN_ERROR_NAME(end_of_sequence) = 100, // End of Sequence Found. 
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 892ca5dd79..01f0b6bab3 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -607,7 +607,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - wasi_nn_global_context wasi_nn_global_ctx = + WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance); if (!wasi_nn_global_ctx) { NN_ERR_PRINTF("global context is invalid"); @@ -691,7 +691,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", nul_terminated_name); - res = not_loaded; + res = not_found; } goto fail; } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index f6f9bd0961..b54d2f7479 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -146,7 +146,7 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size, WASI_NN_ERROR_TYPE res = wasm_load_by_name(model_name, &graph); - if (res == WASI_NN_ERROR_NAME(not_loaded)) { + if (res == WASI_NN_ERROR_NAME(not_found)) { NN_INFO_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", model_name); From 6bc75126bed8e4612d6aa0a14371d9b3f7588c11 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 23 Jan 2026 15:32:41 +0800 Subject: [PATCH 11/28] Move wasi_nn parameters parse logic into libc_wasi.c --- core/iwasm/common/wasm_runtime_common.c | 117 +++++++++++----------- product-mini/platforms/common/libc_wasi.c | 62 ++++++++++++ product-mini/platforms/posix/main.c | 58 +++-------- 3 files changed, 134 insertions(+), 103 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index bebbaf3bc4..17c8e6fedf 100644 --- 
a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1696,65 +1696,6 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) wasm_runtime_free(p); } -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNArguments WASINNArguments; - -void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) -{ - memset(args, 0, sizeof(*args)); -} - -bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, - const char *target, uint32_t n_graphs, - const char **graph_paths) -{ - if (!registry || !encoding || !target || !graph_paths) { - return false; - } - registry->encoding = strdup(encoding); - registry->target = strdup(target); - registry->n_graphs = n_graphs; - registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); - for (uint32_t i = 0; i < registry->n_graphs; i++) - registry->graph_paths[i] = strdup(graph_paths[i]); - - return true; -} - -int -wasi_nn_graph_registry_create(WASINNArguments **registryp) -{ - WASINNArguments *args = wasm_runtime_malloc(sizeof(*args)); - if (args == NULL) { - return -1; - } - wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args); - *registryp = args; - return 0; -} - -void -wasi_nn_graph_registry_destroy(WASINNArguments *registry) -{ - if (registry) { - for (uint32_t i = 0; i < registry->n_graphs; i++) - if (registry->graph_paths[i]) { - // wasi_nn_graph_registry_unregister_graph(registry, - // registry->name[i]); - free(registry->graph_paths[i]); - } - if (registry->encoding) - free(registry->encoding); - if (registry->target) - free(registry->target); - free(registry); - } -} -#endif - void wasm_runtime_instantiation_args_set_default_stack_size( struct InstantiationArgs2 *p, uint32 v) @@ -1853,7 +1794,65 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( wasi_args->set_by_user = true; } #endif /* 
WASM_ENABLE_LIBC_WASI != 0 */ + #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef struct WASINNArguments WASINNArguments; + +void +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) +{ + memset(args, 0, sizeof(*args)); +} + +bool +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, + const char *target, uint32_t n_graphs, + const char **graph_paths) +{ + if (!registry || !encoding || !target || !graph_paths) { + return false; + } + registry->encoding = strdup(encoding); + registry->target = strdup(target); + registry->n_graphs = n_graphs; + registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); + for (uint32_t i = 0; i < registry->n_graphs; i++) + registry->graph_paths[i] = strdup(graph_paths[i]); + + return true; +} + +int +wasi_nn_graph_registry_create(WASINNArguments **registryp) +{ + WASINNArguments *args = wasm_runtime_malloc(sizeof(*args)); + if (args == NULL) { + return -1; + } + wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args); + *registryp = args; + return 0; +} + +void +wasi_nn_graph_registry_destroy(WASINNArguments *registry) +{ + if (registry) { + for (uint32_t i = 0; i < registry->n_graphs; i++) + if (registry->graph_paths[i]) { + // wasi_nn_graph_registry_unregister_graph(registry, + // registry->name[i]); + free(registry->graph_paths[i]); + } + if (registry->encoding) + free(registry->encoding); + if (registry->target) + free(registry->target); + free(registry); + } +} + void wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( struct InstantiationArgs2 *p, WASINNArguments *registry) diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 8e45d7329c..118fce7664 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -21,6 +21,13 @@ typedef struct { uint32 
ns_lookup_pool_size; } libc_wasi_parse_context_t; +typedef struct { + const char *encoding; + const char *target; + const char *graph_paths[10]; + uint32 n_graphs; +} wasi_nn_parse_context_t; + typedef enum { LIBC_WASI_PARSE_RESULT_OK = 0, LIBC_WASI_PARSE_RESULT_NEED_HELP, @@ -177,3 +184,58 @@ libc_wasi_set_init_args(struct InstantiationArgs2 *args, int argc, char **argv, wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( args, ctx->ns_lookup_pool, ctx->ns_lookup_pool_size); } + +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +libc_wasi_parse_result_t +wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) +{ + char *token; + char *saveptr = NULL; + int token_count = 0, ret = 0; + char *tokens[12] = { 0 }; + + // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu + // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... + token = strtok_r(argv[0] + 16, ":", &saveptr); + while (token) { + tokens[token_count] = token; + token_count++; + token = strtok_r(NULL, ":", &saveptr); + } + + if (token_count < 2) { + ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; + goto fail; + } + + ctx->n_graphs = token_count - 2; + ctx->encoding = strdup(tokens[0]); + ctx->target = strdup(tokens[1]); + for (int i = 0; i < ctx->n_graphs; i++) { + ctx->graph_paths[i] = strdup(tokens[i + 2]); + } + +fail: + if (token) + free(token); + + + return ret; +} + +static void +wasi_nn_set_init_args(struct InstantiationArgs2 *args, struct WASINNArguments *nn_registry, wasi_nn_parse_context_t *ctx) +{ + wasi_nn_graph_registry_create(&nn_registry); + wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target, ctx->n_graphs, + ctx->graph_paths); + wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args, + nn_registry); + + for (uint32_t i = 0; i < ctx->n_graphs; i++) + if (ctx->graph_paths[i]) + free(ctx->graph_paths[i]); + free(ctx->encoding); + free(ctx->target); +} +#endif \ No newline at end of file diff --git 
a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 7765bbef54..81e5e49468 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -648,10 +648,10 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_parse_context_t wasi_nn_parse_ctx; struct WASINNArguments *nn_registry; - char *encoding, *target; - uint32_t n_models = 0; - char **model_paths; + + memset(&wasi_nn_parse_ctx, 0, sizeof(wasi_nn_parse_ctx)); #endif #if WASM_ENABLE_LIBC_WASI != 0 @@ -844,35 +844,18 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_LIBC_WASI != 0 && (WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0) else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { - char *token; - char *saveptr = NULL; - int token_count = 0; - char *tokens[12] = { 0 }; - - // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu - // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... 
- token = strtok_r(argv[0] + 16, ":", &saveptr); - while (token) { - tokens[token_count] = token; - token_count++; - token = strtok_r(NULL, ":", &saveptr); - } - - if (token_count < 2) { - return print_help(); - } - - n_models = token_count - 2; - encoding = strdup(tokens[0]); - target = strdup(tokens[1]); - model_paths = malloc(n_models * sizeof(void *)); - for (int i = 0; i < n_models; i++) { - model_paths[i] = strdup(tokens[i + 2]); + libc_wasi_parse_result_t result = + wasi_nn_parse(argv, &wasi_nn_parse_ctx); + switch (result) { + case LIBC_WASI_PARSE_RESULT_OK: + continue; + case LIBC_WASI_PARSE_RESULT_NEED_HELP: + return print_help(); + case LIBC_WASI_PARSE_RESULT_BAD_PARAM: + return 1; } - if (token) - free(token); } #endif else { @@ -1025,11 +1008,7 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - wasi_nn_graph_registry_create(&nn_registry); - wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, - model_paths); - wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, - nn_registry); + wasi_nn_set_init_args(inst_args, nn_registry, &wasi_nn_parse_ctx); #endif /* instantiate the module */ wasm_module_inst = wasm_runtime_instantiate_ex2( @@ -1149,15 +1128,6 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: -#endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - wasi_nn_graph_registry_destroy(nn_registry); - for (uint32_t i = 0; i < n_models; i++) - if (model_paths[i]) - free(model_paths[i]); - free(model_paths); - free(encoding); - free(target); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst); From 12d8a18813374889981a944a9aa2341387a730e0 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Mon, 26 Jan 2026 11:17:39 +0800 Subject: [PATCH 12/28] Allow multiple --wasi-nn-graphs --- core/iwasm/common/wasm_runtime_common.c | 59 ++++++++----- core/iwasm/common/wasm_runtime_common.h | 22 
++--- core/iwasm/include/wasm_export.h | 8 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 84 +++++++++---------- .../libraries/wasi-nn/test/test_tensorflow.c | 5 +- .../wasi-nn/test/test_tensorflow_quantized.c | 5 +- product-mini/platforms/common/libc_wasi.c | 37 ++++---- product-mini/platforms/posix/main.c | 7 +- 8 files changed, 129 insertions(+), 98 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 17c8e6fedf..4410b66ac5 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1805,20 +1805,28 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) } bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, - const char *target, uint32_t n_graphs, +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char **encoding, + const char **target, uint32_t n_graphs, const char **graph_paths) { if (!registry || !encoding || !target || !graph_paths) { return false; } - registry->encoding = strdup(encoding); - registry->target = strdup(target); + registry->n_graphs = n_graphs; + registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); + memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); + for (uint32_t i = 0; i < registry->n_graphs; i++) + { registry->graph_paths[i] = strdup(graph_paths[i]); + registry->encoding[i] = strdup(encoding[i]); + registry->target[i] = strdup(target[i]); + } return true; } @@ -1841,14 +1849,12 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry) if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { - // 
wasi_nn_graph_registry_unregister_graph(registry, - // registry->name[i]); free(registry->graph_paths[i]); + if (registry->encoding[i]) + free(registry->encoding[i]); + if (registry->target[i]) + free(registry->target[i]); } - if (registry->encoding) - free(registry->encoding); - if (registry->target) - free(registry->target); free(registry); } } @@ -8150,7 +8156,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char *encoding, const char *target, + const char **encoding, const char **target, const uint32_t n_graphs, char *graph_paths[], char *error_buf, uint32_t error_buf_size) @@ -8162,16 +8168,21 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, if (!ctx) return false; - ctx->encoding = strdup(encoding); - ctx->target = strdup(target); ctx->n_graphs = n_graphs; + + ctx->encoding = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); + memset(ctx->encoding, 0, sizeof(uint32_t) * n_graphs); + ctx->target = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); + memset(ctx->target, 0, sizeof(uint32_t) * n_graphs); ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); - ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs); + for (uint32_t i = 0; i < n_graphs; i++) { ctx->graph_paths[i] = strdup(graph_paths[i]); + ctx->target[i] = strdup(target[i]); + ctx->encoding[i] = strdup(encoding[i]); } wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx); @@ -8191,6 +8202,10 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) // All graphs will be unregistered in deinit() if (wasi_nn_global_ctx->graph_paths[i]) free(wasi_nn_global_ctx->graph_paths[i]); + if (wasi_nn_global_ctx->encoding[i]) + 
free(wasi_nn_global_ctx->encoding[i]); + if (wasi_nn_global_ctx->encoding[i]) + free(wasi_nn_global_ctx->target[i]); } free(wasi_nn_global_ctx->encoding); free(wasi_nn_global_ctx->target); @@ -8243,21 +8258,21 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i( } char * -wasm_runtime_get_wasi_nn_global_ctx_encoding( - WASINNGlobalContext *wasi_nn_global_ctx) +wasm_runtime_get_wasi_nn_global_ctx_encoding_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) { - if (wasi_nn_global_ctx) - return wasi_nn_global_ctx->encoding; + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->encoding[idx]; return NULL; } char * -wasm_runtime_get_wasi_nn_global_ctx_target( - WASINNGlobalContext *wasi_nn_global_ctx) +wasm_runtime_get_wasi_nn_global_ctx_target_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) { - if (wasi_nn_global_ctx) - return wasi_nn_global_ctx->target; + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->target[idx]; return NULL; } diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index e15c813ae4..1bf6701ea2 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -547,8 +547,8 @@ typedef struct WASMModuleInstMemConsumption { #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNGlobalContext { - char *encoding; - char *target; + char **encoding; + char **target; uint32_t n_graphs; uint32_t *loaded; @@ -625,8 +625,8 @@ wasm_runtime_get_exec_env_tls(void); #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNArguments { - char *encoding; - char *target; + char **encoding; + char **target; char **graph_paths; uint32_t n_graphs; @@ -811,8 +811,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( struct InstantiationArgs2 *p, WASINNArguments *registry); WASM_RUNTIME_API_EXTERN bool 
-wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, - const char *target, uint32_t n_graphs, +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char **encoding, + const char **target, uint32_t n_graphs, const char **graph_paths); #endif @@ -1471,7 +1471,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char *encoding, const char *target, + const char **encoding, const char **target, const uint32_t n_graphs, char *graph_paths[], char *error_buf, uint32_t error_buf_size); @@ -1507,12 +1507,12 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i( WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_encoding( - WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_encoding_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_target( - WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_target_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index e7220f870b..f6e6bc533b 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -818,12 +818,12 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i( WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_encoding( - WASINNGlobalContext * wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_encoding_i( + WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_target( - WASINNGlobalContext * 
wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_target_i( + WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 01f0b6bab3..931e93eca7 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -614,24 +614,15 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, res = not_found; goto fail; } - graph_encoding encoding = str2encoding( - wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); - execution_target target = str2target( - wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); - - // res = ensure_backend(instance, autodetect, wasi_nn_ctx); - res = ensure_backend(instance, encoding, wasi_nn_ctx); - if (res != success) - goto fail; bool is_loaded = false; uint32 model_idx = 0; char *global_model_path_i; uint32_t global_n_graphs = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); - // Assume filename got from user wasm app : max; sum; average; ... - // Assume file path got from user cmd opt: /your/path1/max.tflite; - // /your/path2/sum.tflite; ...... + // Model got from user wasm app : modelA; modelB... + // Filelist got from user cmd opt: /path1/modelA.tflite; + // /path/modelB.tflite; ...... 
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { // Extract filename from file path global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( @@ -655,45 +646,54 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, model_name[model_name_len] = '\0'; } - if (model_name && strcmp(nul_terminated_name, model_name) == 0) { - is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - wasi_nn_global_ctx, model_idx); + if (model_name && strcmp(nul_terminated_name, model_name) != 0) { free(model_name); - break; + continue; } + is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_ctx, model_idx); free(model_name); - } - if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) - && (model_idx < global_n_graphs)) { - NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, global_model_path_i, - strlen(global_model_path_i), encoding, target, g); + graph_encoding encoding = str2encoding( + wasm_runtime_get_wasi_nn_global_ctx_encoding_i(wasi_nn_global_ctx, model_idx)); + execution_target target = str2target( + wasm_runtime_get_wasi_nn_global_ctx_target_i(wasi_nn_global_ctx, model_idx)); + + // res = ensure_backend(instance, autodetect, wasi_nn_ctx); + res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; - wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, - model_idx, 1); - res = success; - } - else { - if (is_loaded) { - NN_DBG_PRINTF("Model is already loaded"); + if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) + && (model_idx < global_n_graphs)) { + NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), encoding, target, g); + if (res != success) + goto fail; + + 
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, + model_idx, 1); res = success; + break; } - else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) { - // No enlarge for now - NN_ERR_PRINTF("No enough space for new model"); - res = too_large; - } - else if (model_idx >= global_n_graphs) { - NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " - "through --wasi-nn-graph", - nul_terminated_name); - res = not_found; - } - goto fail; + } + + if (is_loaded) { + NN_DBG_PRINTF("Model is already loaded"); + res = success; + } + else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) { + // No enlarge for now + NN_ERR_PRINTF("No enough space for new model"); + res = too_large; + } + else if (model_idx >= global_n_graphs) { + NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " + "through --wasi-nn-graph", + nul_terminated_name); + res = not_found; } fail: if (nul_terminated_name != NULL) { diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index d276dd0ac8..47504c9dea 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -144,8 +144,9 @@ int main() { NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " - "--wasi-nn-graph=encoding:target:model_path1:model_path2:..." - ":model_pathN test_tensorflow.wasm\""); + "--wasi-nn-graph=encodingA:targetA: " + "--wasi-nn-graph=encodingB:targetB:..." 
+ " test_tensorflow.wasm"); NN_INFO_PRINTF("################### Testing sum..."); test_sum(); diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index a55dc9bb89..2f63e9b0d4 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -39,8 +39,9 @@ int main() { NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " - "--wasi-nn-graph=encoding:target:model_path1:model_path2:..." - ":model_pathN test_tensorflow.wasm\""); + "--wasi-nn-graph=encodingA:targetA: " + "--wasi-nn-graph=encodingB:targetB:..." + " test_tensorflow_quantized.wasm"); NN_INFO_PRINTF("################### Testing quantized model..."); test_average_quantized(); diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 118fce7664..44e8b2e0c4 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -22,8 +22,8 @@ typedef struct { } libc_wasi_parse_context_t; typedef struct { - const char *encoding; - const char *target; + const char *encoding[10]; + const char *target[10]; const char *graph_paths[10]; uint32 n_graphs; } wasi_nn_parse_context_t; @@ -189,13 +189,23 @@ libc_wasi_set_init_args(struct InstantiationArgs2 *args, int argc, char **argv, libc_wasi_parse_result_t wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) { + if ('\0' == argv[16]) + return LIBC_WASI_PARSE_RESULT_NEED_HELP; + + if (ctx->n_graphs >= sizeof(ctx->graph_paths) / sizeof(char *)) { + printf("Only allow max graph number %d\n", + (int)(sizeof(ctx->graph_paths) / sizeof(char *))); + return LIBC_WASI_PARSE_RESULT_BAD_PARAM; + } + char *token; char *saveptr = NULL; int token_count = 0, ret = 0; char *tokens[12] = { 0 }; // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu - // 
--wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... + // --wasi-nn-graph=encoding1:target1:model_file_path1 + // --wasi-nn-graph=encoding2:target2:model_file_path2 ... token = strtok_r(argv[0] + 16, ":", &saveptr); while (token) { tokens[token_count] = token; @@ -203,22 +213,18 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) token = strtok_r(NULL, ":", &saveptr); } - if (token_count < 2) { + if (token_count != 3) { ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; goto fail; } - ctx->n_graphs = token_count - 2; - ctx->encoding = strdup(tokens[0]); - ctx->target = strdup(tokens[1]); - for (int i = 0; i < ctx->n_graphs; i++) { - ctx->graph_paths[i] = strdup(tokens[i + 2]); - } + ctx->encoding[ctx->n_graphs] = strdup(tokens[0]); + ctx->target[ctx->n_graphs] = strdup(tokens[1]); + ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[2]); fail: if (token) free(token); - return ret; } @@ -226,16 +232,19 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) static void wasi_nn_set_init_args(struct InstantiationArgs2 *args, struct WASINNArguments *nn_registry, wasi_nn_parse_context_t *ctx) { - wasi_nn_graph_registry_create(&nn_registry); wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target, ctx->n_graphs, ctx->graph_paths); wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args, nn_registry); for (uint32_t i = 0; i < ctx->n_graphs; i++) + { if (ctx->graph_paths[i]) free(ctx->graph_paths[i]); - free(ctx->encoding); - free(ctx->target); + if (ctx->encoding[i]) + free(ctx->encoding[i]); + if (ctx->target[i]) + free(ctx->target[i]); + } } #endif \ No newline at end of file diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 81e5e49468..6c81a365cf 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -123,7 +123,8 @@ print_help(void) printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); 
#endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - printf(" --wasi-nn-graph=encoding:target:::...:\n"); + printf(" --wasi-nn-graph=encodingA:targetB:\n"); + printf(" --wasi-nn-graph=encodingA:targetB:...\n"); printf(" Set encoding, target and model_paths for wasi-nn. target can be\n"); printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); printf(" tensorflow|pytorch|ggml|autodetect\n"); @@ -1008,6 +1009,7 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_create(&nn_registry); wasi_nn_set_init_args(inst_args, nn_registry, &wasi_nn_parse_ctx); #endif /* instantiate the module */ @@ -1128,6 +1130,9 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: +#endif +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_destroy(nn_registry); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst); From 8dbba4644a5304fc2dd98b7458fa557430e7a85f Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Mon, 26 Jan 2026 11:23:19 +0800 Subject: [PATCH 13/28] Use clang-format-18 to format files --- core/iwasm/common/wasm_runtime_common.c | 19 ++++++++------- core/iwasm/common/wasm_runtime_common.h | 6 ++--- core/iwasm/include/wasm_export.h | 12 +++++----- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 27 ++++++++++++---------- product-mini/platforms/common/libc_wasi.c | 15 ++++++------ product-mini/platforms/posix/main.c | 3 ++- 6 files changed, 43 insertions(+), 39 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 4410b66ac5..0e9621b21c 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1805,14 +1805,14 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) } bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char 
**encoding, - const char **target, uint32_t n_graphs, - const char **graph_paths) +wasi_nn_graph_registry_set_args(WASINNArguments *registry, + const char **encoding, const char **target, + uint32_t n_graphs, const char **graph_paths) { if (!registry || !encoding || !target || !graph_paths) { return false; } - + registry->n_graphs = n_graphs; registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); @@ -1821,8 +1821,7 @@ wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char **encoding memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); - for (uint32_t i = 0; i < registry->n_graphs; i++) - { + for (uint32_t i = 0; i < registry->n_graphs; i++) { registry->graph_paths[i] = strdup(graph_paths[i]); registry->encoding[i] = strdup(encoding[i]); registry->target[i] = strdup(target[i]); @@ -1850,10 +1849,10 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry) for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { free(registry->graph_paths[i]); - if (registry->encoding[i]) - free(registry->encoding[i]); - if (registry->target[i]) - free(registry->target[i]); + if (registry->encoding[i]) + free(registry->encoding[i]); + if (registry->target[i]) + free(registry->target[i]); } free(registry); } diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 1bf6701ea2..16bc1616ae 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -811,9 +811,9 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( struct InstantiationArgs2 *p, WASINNArguments *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char **encoding, - const char **target, uint32_t n_graphs, - const char **graph_paths); 
+wasi_nn_graph_registry_set_args(WASINNArguments *registry, + const char **encoding, const char **target, + uint32_t n_graphs, const char **graph_paths); #endif /* See wasm_export.h for description */ diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index f6e6bc533b..d558ce8451 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -803,27 +803,27 @@ wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext * wasi_nn_global_ctx); + WASINNGlobalContext *wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx, uint32_t value); + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx); + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 931e93eca7..cf83d6d6e8 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -651,13 +651,15 
@@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, continue; } is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - wasi_nn_global_ctx, model_idx); + wasi_nn_global_ctx, model_idx); free(model_name); - graph_encoding encoding = str2encoding( - wasm_runtime_get_wasi_nn_global_ctx_encoding_i(wasi_nn_global_ctx, model_idx)); - execution_target target = str2target( - wasm_runtime_get_wasi_nn_global_ctx_target_i(wasi_nn_global_ctx, model_idx)); + graph_encoding encoding = + str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding_i( + wasi_nn_global_ctx, model_idx)); + execution_target target = + str2target(wasm_runtime_get_wasi_nn_global_ctx_target_i( + wasi_nn_global_ctx, model_idx)); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); @@ -665,16 +667,17 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) - && (model_idx < global_n_graphs)) { - NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); + && (model_idx < global_n_graphs)) { + NN_DBG_PRINTF( + "Model is not yet loaded, will add to global context"); call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, global_model_path_i, - strlen(global_model_path_i), encoding, target, g); + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), encoding, target, g); if (res != success) goto fail; wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, - model_idx, 1); + model_idx, 1); res = success; break; } @@ -691,8 +694,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, } else if (model_idx >= global_n_graphs) { NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " - "through --wasi-nn-graph", - nul_terminated_name); + "through --wasi-nn-graph", + nul_terminated_name); res = not_found; } fail: diff 
--git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 44e8b2e0c4..84ac337dfa 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -194,7 +194,7 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) if (ctx->n_graphs >= sizeof(ctx->graph_paths) / sizeof(char *)) { printf("Only allow max graph number %d\n", - (int)(sizeof(ctx->graph_paths) / sizeof(char *))); + (int)(sizeof(ctx->graph_paths) / sizeof(char *))); return LIBC_WASI_PARSE_RESULT_BAD_PARAM; } @@ -221,7 +221,7 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) ctx->encoding[ctx->n_graphs] = strdup(tokens[0]); ctx->target[ctx->n_graphs] = strdup(tokens[1]); ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[2]); - + fail: if (token) free(token); @@ -230,15 +230,16 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) } static void -wasi_nn_set_init_args(struct InstantiationArgs2 *args, struct WASINNArguments *nn_registry, wasi_nn_parse_context_t *ctx) +wasi_nn_set_init_args(struct InstantiationArgs2 *args, + struct WASINNArguments *nn_registry, + wasi_nn_parse_context_t *ctx) { - wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target, ctx->n_graphs, - ctx->graph_paths); + wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target, + ctx->n_graphs, ctx->graph_paths); wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args, nn_registry); - for (uint32_t i = 0; i < ctx->n_graphs; i++) - { + for (uint32_t i = 0; i < ctx->n_graphs; i++) { if (ctx->graph_paths[i]) free(ctx->graph_paths[i]); if (ctx->encoding[i]) diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 6c81a365cf..a39bd05976 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -845,7 +845,8 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } -#if WASM_ENABLE_LIBC_WASI != 0 && (WASM_ENABLE_WASI_NN 
!= 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +#if WASM_ENABLE_LIBC_WASI != 0 \ + && (WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0) else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { libc_wasi_parse_result_t result = wasi_nn_parse(argv, &wasi_nn_parse_ctx); From c024c9415088540e0b97ff8a0fd16483c8a1bfe1 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Mon, 26 Jan 2026 13:35:48 +0800 Subject: [PATCH 14/28] CICD: retrigger checks From 736f357a587b9230e2c158227b9328af0de54383 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Tue, 27 Jan 2026 15:42:28 +0800 Subject: [PATCH 15/28] Add model_name option for --wasi-nn-graphs to make it more flexible and simpler --- core/iwasm/common/wasm_runtime_common.c | 31 +++++++++++++++++++--- core/iwasm/common/wasm_runtime_common.h | 12 +++++++-- core/iwasm/include/wasm_export.h | 4 +++ core/iwasm/interpreter/wasm_runtime.c | 4 +-- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 31 ++++------------------ product-mini/platforms/common/libc_wasi.c | 26 +++++++++++------- product-mini/platforms/posix/main.c | 4 +-- 7 files changed, 67 insertions(+), 45 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 0e9621b21c..536261355f 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1806,23 +1806,27 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) bool wasi_nn_graph_registry_set_args(WASINNArguments *registry, - const char **encoding, const char **target, - uint32_t n_graphs, const char **graph_paths) + const char **model_names, const char **encoding, + const char **target, uint32_t n_graphs, + const char **graph_paths) { - if (!registry || !encoding || !target || !graph_paths) { + if (!registry || !model_names || !encoding || !target || !graph_paths) { return false; } registry->n_graphs = n_graphs; registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); registry->encoding = 
(uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + registry->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); + memset(registry->model_names, 0, sizeof(uint32_t *) * n_graphs); memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); for (uint32_t i = 0; i < registry->n_graphs; i++) { registry->graph_paths[i] = strdup(graph_paths[i]); + registry->model_names[i] = strdup(model_names[i]); registry->encoding[i] = strdup(encoding[i]); registry->target[i] = strdup(target[i]); } @@ -1849,6 +1853,8 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry) for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { free(registry->graph_paths[i]); + if (registry->model_names[i]) + free(registry->model_names[i]); if (registry->encoding[i]) free(registry->encoding[i]); if (registry->target[i]) @@ -8155,6 +8161,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + const char **model_names, const char **encoding, const char **target, const uint32_t n_graphs, char *graph_paths[], char *error_buf, @@ -8175,11 +8182,14 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, memset(ctx->target, 0, sizeof(uint32_t) * n_graphs); ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); + ctx->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(ctx->model_names, 0, sizeof(uint32_t *) * n_graphs); ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs); for (uint32_t i = 0; i < n_graphs; i++) { ctx->graph_paths[i] 
= strdup(graph_paths[i]); + ctx->model_names[i] = strdup(model_names[i]); ctx->target[i] = strdup(target[i]); ctx->encoding[i] = strdup(encoding[i]); } @@ -8201,14 +8211,17 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) // All graphs will be unregistered in deinit() if (wasi_nn_global_ctx->graph_paths[i]) free(wasi_nn_global_ctx->graph_paths[i]); + if (wasi_nn_global_ctx->model_names[i]) + free(wasi_nn_global_ctx->model_names[i]); if (wasi_nn_global_ctx->encoding[i]) free(wasi_nn_global_ctx->encoding[i]); - if (wasi_nn_global_ctx->encoding[i]) + if (wasi_nn_global_ctx->target[i]) free(wasi_nn_global_ctx->target[i]); } free(wasi_nn_global_ctx->encoding); free(wasi_nn_global_ctx->target); free(wasi_nn_global_ctx->loaded); + free(wasi_nn_global_ctx->model_names); free(wasi_nn_global_ctx->graph_paths); if (wasi_nn_global_ctx) { @@ -8226,6 +8239,16 @@ wasm_runtime_get_wasi_nn_global_ctx_ngraphs( return -1; } +char * +wasm_runtime_get_wasi_nn_global_ctx_model_names_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->model_names[idx]; + + return NULL; +} + char * wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 16bc1616ae..869ac1eeb3 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -547,6 +547,7 @@ typedef struct WASMModuleInstMemConsumption { #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNGlobalContext { + char **model_names; char **encoding; char **target; @@ -625,6 +626,7 @@ wasm_runtime_get_exec_env_tls(void); #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNArguments { + char **model_names; char **encoding; char **target; @@ -812,8 +814,9 @@ 
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( WASM_RUNTIME_API_EXTERN bool wasi_nn_graph_registry_set_args(WASINNArguments *registry, - const char **encoding, const char **target, - uint32_t n_graphs, const char **graph_paths); + const char **model_names, const char **encoding, + const char **target, uint32_t n_graphs, + const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1471,6 +1474,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + const char **model_names, const char **encoding, const char **target, const uint32_t n_graphs, char *graph_paths[], char *error_buf, @@ -1494,6 +1498,10 @@ WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_get_wasi_nn_global_ctx_ngraphs( WASINNGlobalContext *wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_model_names_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index d558ce8451..0c36595285 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -805,6 +805,10 @@ WASM_RUNTIME_API_EXTERN uint32_t wasm_runtime_get_wasi_nn_global_ctx_ngraphs( WASINNGlobalContext *wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_model_names_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + WASM_RUNTIME_API_EXTERN char * wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index 4dc6bd2537..cfe3840d9c 100644 --- 
a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3305,8 +3305,8 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, * load_by_name */ WASINNArguments *nn_registry = &args->nn_registry; if (!wasm_runtime_init_wasi_nn_global_ctx( - (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, - nn_registry->target, nn_registry->n_graphs, + (WASMModuleInstanceCommon *)module_inst, nn_registry->model_names, + nn_registry->encoding, nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths, error_buf, error_buf_size)) { goto fail; } diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index cf83d6d6e8..c4e43ede89 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -617,42 +617,21 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; - char *global_model_path_i; uint32_t global_n_graphs = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); - // Model got from user wasm app : modelA; modelB... - // Filelist got from user cmd opt: /path1/modelA.tflite; - // /path/modelB.tflite; ...... 
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { - // Extract filename from file path - global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + char *model_name = wasm_runtime_get_wasi_nn_global_ctx_model_names_i( wasi_nn_global_ctx, model_idx); - char *model_file_name; - const char *slash = strrchr(global_model_path_i, '/'); - if (slash != NULL) { - model_file_name = (char *)(slash + 1); - } - else - model_file_name = global_model_path_i; - - // Extract modelname from filename - char *model_name = NULL; - size_t model_name_len = 0; - char *dot = strrchr(model_file_name, '.'); - if (dot) { - model_name_len = dot - model_file_name; - model_name = malloc(model_name_len + 1); - strncpy(model_name, model_file_name, model_name_len); - model_name[model_name_len] = '\0'; - } if (model_name && strcmp(nul_terminated_name, model_name) != 0) { - free(model_name); continue; } + is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( wasi_nn_global_ctx, model_idx); - free(model_name); + char *global_model_path_i = + wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + wasi_nn_global_ctx, model_idx); graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding_i( diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 84ac337dfa..ab22c47e8d 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -22,6 +22,7 @@ typedef struct { } libc_wasi_parse_context_t; typedef struct { + const char *model_names[10]; const char *encoding[10]; const char *target[10]; const char *graph_paths[10]; @@ -208,19 +209,23 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) // --wasi-nn-graph=encoding2:target2:model_file_path2 ... 
token = strtok_r(argv[0] + 16, ":", &saveptr); while (token) { - tokens[token_count] = token; - token_count++; - token = strtok_r(NULL, ":", &saveptr); + if (strlen(token) > 0) { + tokens[token_count] = token; + token_count++; + token = strtok_r(NULL, ":", &saveptr); + } } - if (token_count != 3) { + if (token_count != 4) { ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; + printf("4 arguments are needed for wasi-nn.\n"); goto fail; } - ctx->encoding[ctx->n_graphs] = strdup(tokens[0]); - ctx->target[ctx->n_graphs] = strdup(tokens[1]); - ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[2]); + ctx->model_names[ctx->n_graphs] = strdup(tokens[0]); + ctx->encoding[ctx->n_graphs] = strdup(tokens[1]); + ctx->target[ctx->n_graphs] = strdup(tokens[2]); + ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[3]); fail: if (token) @@ -234,12 +239,15 @@ wasi_nn_set_init_args(struct InstantiationArgs2 *args, struct WASINNArguments *nn_registry, wasi_nn_parse_context_t *ctx) { - wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target, - ctx->n_graphs, ctx->graph_paths); + wasi_nn_graph_registry_set_args(nn_registry, ctx->model_names, + ctx->encoding, ctx->target, ctx->n_graphs, + ctx->graph_paths); wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args, nn_registry); for (uint32_t i = 0; i < ctx->n_graphs; i++) { + if (ctx->model_names[i]) + free(ctx->model_names[i]); if (ctx->graph_paths[i]) free(ctx->graph_paths[i]); if (ctx->encoding[i]) diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index a39bd05976..d26565f689 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -123,8 +123,8 @@ print_help(void) printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - printf(" --wasi-nn-graph=encodingA:targetB:\n"); - printf(" --wasi-nn-graph=encodingA:targetB:...\n"); + printf(" 
--wasi-nn-graph=modelA_name:encodingA:targetA:\n"); + printf(" --wasi-nn-graph=modelB_name:encodingB:targetB:...\n"); printf(" Set encoding, target and model_paths for wasi-nn. target can be\n"); printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); printf(" tensorflow|pytorch|ggml|autodetect\n"); From 878c8ec3a35c9ac38156b89dd29fb67ffd380e1c Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Tue, 27 Jan 2026 16:15:56 +0800 Subject: [PATCH 16/28] Add help info for wasi-nn test wasm --- core/iwasm/libraries/wasi-nn/test/test_tensorflow.c | 4 ++-- core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index 47504c9dea..e665b4aa12 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -144,8 +144,8 @@ int main() { NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " - "--wasi-nn-graph=encodingA:targetA: " - "--wasi-nn-graph=encodingB:targetB:..." + "--wasi-nn-graph=modelA:encodingA:targetA: " + "--wasi-nn-graph=modelB:encodingB:targetB:..." " test_tensorflow.wasm"); NN_INFO_PRINTF("################### Testing sum..."); diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index 2f63e9b0d4..dff9dff0aa 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -39,8 +39,8 @@ int main() { NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " - "--wasi-nn-graph=encodingA:targetA: " - "--wasi-nn-graph=encodingB:targetB:..." + "--wasi-nn-graph=modelA:encodingA:targetA: " + "--wasi-nn-graph=modelB:encodingB:targetB:..." 
" test_tensorflow_quantized.wasm"); NN_INFO_PRINTF("################### Testing quantized model..."); From c43fbeb02100daad73ae4937249c656ae34ab66a Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Tue, 27 Jan 2026 16:58:29 +0800 Subject: [PATCH 17/28] CICD: retrigger checks From 6501e4b8f1b03ece9348363efa4027dbbfb0a769 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 5 Feb 2026 14:29:24 +0800 Subject: [PATCH 18/28] Use a specific error type wasi_nn_error_t instead of macro "WASI_NN_ERROR_TYPE" --- core/iwasm/libraries/wasi-nn/test/utils.c | 30 ++++++++++++-------- core/iwasm/libraries/wasi-nn/test/utils.h | 34 +++++++++++++---------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index b54d2f7479..64247e8d37 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -8,7 +8,13 @@ #include #include -WASI_NN_ERROR_TYPE +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef wasi_ephemeral_nn_error wasi_nn_error_t; +#else +typedef wasi_nn_error wasi_nn_error_t; +#endif + +wasi_nn_error_t wasm_load(char *model_name, WASI_NN_NAME(graph) * g, WASI_NN_NAME(execution_target) target) { @@ -39,7 +45,7 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) * g, arr.buf = buffer; arr.size = result; - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)( + wasi_nn_error_t res = WASI_NN_NAME(load)( &arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); #else WASI_NN_NAME(graph_builder_array) arr; @@ -56,7 +62,7 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) * g, arr.buf[0].size = result; arr.buf[0].buf = buffer; - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)( + wasi_nn_error_t res = WASI_NN_NAME(load)( &arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); #endif @@ -66,22 +72,22 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) * g, return res; } -WASI_NN_ERROR_TYPE +wasi_nn_error_t wasm_load_by_name(const char *model_name, 
WASI_NN_NAME(graph) * g) { - WASI_NN_ERROR_TYPE res = + wasi_nn_error_t res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); return res; } -WASI_NN_ERROR_TYPE +wasi_nn_error_t wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx) { return WASI_NN_NAME(init_execution_context)(g, ctx); } -WASI_NN_ERROR_TYPE +wasi_nn_error_t wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim) { @@ -113,19 +119,19 @@ wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, tensor.data = (uint8_t *)input_tensor; #endif - WASI_NN_ERROR_TYPE err = WASI_NN_NAME(set_input)(ctx, 0, &tensor); + wasi_nn_error_t err = WASI_NN_NAME(set_input)(ctx, 0, &tensor); free(dims.buf); return err; } -WASI_NN_ERROR_TYPE +wasi_nn_error_t wasm_compute(WASI_NN_NAME(graph_execution_context) ctx) { return WASI_NN_NAME(compute)(ctx); } -WASI_NN_ERROR_TYPE +wasi_nn_error_t wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size) { @@ -144,8 +150,8 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size, { WASI_NN_NAME(graph) graph; - WASI_NN_ERROR_TYPE res = wasm_load_by_name(model_name, &graph); - + wasi_nn_error_t res = wasm_load_by_name(model_name, &graph); + if (res == WASI_NN_ERROR_NAME(not_found)) { NN_INFO_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index 8d2683fff4..ac3acd3478 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -26,32 +26,36 @@ typedef struct { uint32_t elements; } input_info; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef wasi_ephemeral_nn_error wasi_nn_error_t; +#else +typedef wasi_nn_error wasi_nn_error_t; +#endif + /* wasi-nn wrappers */ -WASI_NN_ERROR_TYPE -wasm_load(char *model_name, 
WASI_NN_NAME(graph) * g, - WASI_NN_NAME(execution_target) target); +wasi_nn_error_t +wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); -WASI_NN_ERROR_TYPE -wasm_init_execution_context(WASI_NN_NAME(graph) g, - WASI_NN_NAME(graph_execution_context) * ctx); +wasi_nn_error_t +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); -WASI_NN_ERROR_TYPE -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, - uint32_t *dim); +wasi_nn_error_t +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); -WASI_NN_ERROR_TYPE +wasi_nn_error_t wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); -WASI_NN_ERROR_TYPE -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, - float *out_tensor, uint32_t *out_size); +wasi_nn_error_t +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, + uint32_t *out_size); /* Utils */ float * -run_inference(float *input, uint32_t *input_size, uint32_t *output_size, - char *model_name, uint32_t num_output_tensors); +run_inference(float *input, uint32_t *input_size, + uint32_t *output_size, char *model_name, + uint32_t num_output_tensors); input_info create_input(int *dims); From b12245198e3e9ce1a1c74f36a1fa59f03500c7cb Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 5 Feb 2026 17:15:00 +0800 Subject: [PATCH 19/28] Unify WASINNGlobalContext and WASINNArguments into a single structure named WASINNRegistry. Replace malloc/free to wasm_runtime_malloc/_free, replace strdup to bh_strdup. 
--- core/iwasm/common/wasm_native.c | 26 +-- core/iwasm/common/wasm_runtime_common.c | 200 ++++++++------------- core/iwasm/common/wasm_runtime_common.h | 76 +++----- core/iwasm/include/wasm_export.h | 35 ++-- core/iwasm/interpreter/wasm_runtime.c | 12 -- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 30 ++-- product-mini/platforms/common/libc_wasi.c | 7 +- product-mini/platforms/posix/main.c | 17 +- 8 files changed, 158 insertions(+), 245 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index b8430520af..2ba4a5778d 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -26,7 +26,7 @@ static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -static void *g_wasi_nn_context_key; +static void *g_wasi_nn_registry_key; #endif uint32 @@ -478,17 +478,17 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) #endif /* end of WASM_ENABLE_LIBC_WASI */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -WASINNGlobalContext * -wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) +WASINNRegistry * +wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm) { - return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key); + return wasm_native_get_context(module_inst_comm, g_wasi_nn_registry_key); } void -wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, - WASINNGlobalContext *wasi_nn_ctx) +wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, + WASINNRegistry *wasi_nn_ctx) { - wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, + wasm_native_set_context(module_inst_comm, g_wasi_nn_registry_key, wasi_nn_ctx); } @@ -499,7 +499,7 @@ wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) return; } - wasm_runtime_destroy_wasi_nn_global_ctx(inst); + 
wasm_runtime_wasi_nn_registry_destroy(ctx); } #endif @@ -612,9 +612,9 @@ wasm_native_init() #endif /* WASM_ENABLE_LIB_RATS */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - g_wasi_nn_context_key = + g_wasi_nn_registry_key = wasm_native_create_context_key(wasi_nn_context_dtor); - if (g_wasi_nn_context_key == NULL) { + if (g_wasi_nn_registry_key == NULL) { goto fail; } @@ -684,9 +684,9 @@ wasm_native_destroy() #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - if (g_wasi_nn_context_key != NULL) { - wasm_native_destroy_context_key(g_wasi_nn_context_key); - g_wasi_nn_context_key = NULL; + if (g_wasi_nn_registry_key != NULL) { + wasm_native_destroy_context_key(g_wasi_nn_registry_key); + g_wasi_nn_registry_key = NULL; } wasi_nn_destroy(); #endif diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 536261355f..2d011d79b2 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1796,48 +1796,48 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( #endif /* WASM_ENABLE_LIBC_WASI != 0 */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNArguments WASINNArguments; - void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args) { memset(args, 0, sizeof(*args)); } bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, - const char **model_names, const char **encoding, - const char **target, uint32_t n_graphs, - const char **graph_paths) +wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, + const char **model_names, const char **encoding, + const char **target, uint32_t n_graphs, + const char **graph_paths) { if (!registry || !model_names || !encoding || !target || !graph_paths) { return false; } registry->n_graphs = n_graphs; - registry->target = (uint32_t 
**)malloc(sizeof(uint32_t *) * n_graphs); - registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - registry->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + registry->target = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->encoding = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->loaded = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->model_names = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->graph_paths = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); + memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs); memset(registry->model_names, 0, sizeof(uint32_t *) * n_graphs); memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); for (uint32_t i = 0; i < registry->n_graphs; i++) { - registry->graph_paths[i] = strdup(graph_paths[i]); - registry->model_names[i] = strdup(model_names[i]); - registry->encoding[i] = strdup(encoding[i]); - registry->target[i] = strdup(target[i]); + registry->graph_paths[i] = bh_strdup(graph_paths[i]); + registry->model_names[i] = bh_strdup(model_names[i]); + registry->encoding[i] = bh_strdup(encoding[i]); + registry->target[i] = bh_strdup(target[i]); } return true; } int -wasi_nn_graph_registry_create(WASINNArguments **registryp) +wasm_runtime_wasi_nn_registry_create(WASINNRegistry **registryp) { - WASINNArguments *args = wasm_runtime_malloc(sizeof(*args)); + WASINNRegistry *args = wasm_runtime_malloc(sizeof(*args)); if (args == NULL) { return -1; } @@ -1847,28 +1847,45 @@ wasi_nn_graph_registry_create(WASINNArguments **registryp) } void -wasi_nn_graph_registry_destroy(WASINNArguments *registry) +wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry 
*registry) { if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { - free(registry->graph_paths[i]); - if (registry->model_names[i]) - free(registry->model_names[i]); - if (registry->encoding[i]) - free(registry->encoding[i]); - if (registry->target[i]) - free(registry->target[i]); + wasm_runtime_free(registry->graph_paths[i]); + if (registry->model_names[i]) + wasm_runtime_free(registry->model_names[i]); + if (registry->encoding[i]) + wasm_runtime_free(registry->encoding[i]); + if (registry->target[i]) + wasm_runtime_free(registry->target[i]); } - free(registry); + if (registry->loaded) + wasm_runtime_free(registry->loaded); + wasm_runtime_free(registry); } } void -wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, WASINNArguments *registry) +wasm_runtime_instantiation_args_set_wasi_nn_registry( + struct InstantiationArgs2 *p, WASINNRegistry *registry) { - p->nn_registry = *registry; + if (!registry) + return; + WASINNRegistry *wasi_nn_registry = &p->nn_registry; + + wasi_nn_registry->n_graphs = registry->n_graphs; + + if (registry->model_names) + wasi_nn_registry->model_names = bh_strdup(registry->model_names); + if (registry->encoding) + wasi_nn_registry->encoding = bh_strdup(registry->encoding); + if (registry->target) + wasi_nn_registry->target = bh_strdup(registry->target); + if (registry->loaded) + wasi_nn_registry->loaded = bh_strdup(registry->loaded); + if (registry->graph_paths) + wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths); } #endif @@ -8159,142 +8176,73 @@ wasm_runtime_check_and_update_last_used_shared_heap( #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -bool -wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char **model_names, - const char **encoding, const char **target, - const uint32_t n_graphs, - char *graph_paths[], char *error_buf, - uint32_t error_buf_size) -{ - 
WASINNGlobalContext *ctx; - bool ret = false; - - ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size); - if (!ctx) - return false; - - ctx->n_graphs = n_graphs; - - ctx->encoding = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); - memset(ctx->encoding, 0, sizeof(uint32_t) * n_graphs); - ctx->target = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); - memset(ctx->target, 0, sizeof(uint32_t) * n_graphs); - ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); - memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); - ctx->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - memset(ctx->model_names, 0, sizeof(uint32_t *) * n_graphs); - ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs); - - for (uint32_t i = 0; i < n_graphs; i++) { - ctx->graph_paths[i] = strdup(graph_paths[i]); - ctx->model_names[i] = strdup(model_names[i]); - ctx->target[i] = strdup(target[i]); - ctx->encoding[i] = strdup(encoding[i]); - } - - wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx); - - ret = true; - - return ret; -} - -void -wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) -{ - WASINNGlobalContext *wasi_nn_global_ctx = - wasm_runtime_get_wasi_nn_global_ctx(module_inst); - - for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) { - // All graphs will be unregistered in deinit() - if (wasi_nn_global_ctx->graph_paths[i]) - free(wasi_nn_global_ctx->graph_paths[i]); - if (wasi_nn_global_ctx->model_names[i]) - free(wasi_nn_global_ctx->model_names[i]); - if (wasi_nn_global_ctx->encoding[i]) - free(wasi_nn_global_ctx->encoding[i]); - if (wasi_nn_global_ctx->target[i]) - free(wasi_nn_global_ctx->target[i]); - } - free(wasi_nn_global_ctx->encoding); - free(wasi_nn_global_ctx->target); - free(wasi_nn_global_ctx->loaded); - free(wasi_nn_global_ctx->model_names); - free(wasi_nn_global_ctx->graph_paths); - - if (wasi_nn_global_ctx) { - 
wasm_runtime_free(wasi_nn_global_ctx); - } -} uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext *wasi_nn_global_ctx) +wasm_runtime_get_wasi_nn_registry_ngraphs( + WASINNRegistry *wasi_nn_registry) { - if (wasi_nn_global_ctx) - return wasi_nn_global_ctx->n_graphs; + if (wasi_nn_registry) + return wasi_nn_registry->n_graphs; return -1; } char * -wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_model_names_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->model_names[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->model_names[idx]; return NULL; } char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_graph_paths_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->graph_paths[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->graph_paths[idx]; return NULL; } uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->loaded[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->loaded[idx]; return -1; } uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) +wasm_runtime_set_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value) { - if (wasi_nn_global_ctx && (idx < 
wasi_nn_global_ctx->n_graphs)) - wasi_nn_global_ctx->loaded[idx] = value; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + wasi_nn_registry->loaded[idx] = value; return 0; } char * -wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_encoding_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->encoding[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->encoding[idx]; return NULL; } char * -wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_target_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->target[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->target[idx]; return NULL; } diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 869ac1eeb3..0f7b1fbce3 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -546,7 +546,7 @@ typedef struct WASMModuleInstMemConsumption { } WASMModuleInstMemConsumption; #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNGlobalContext { +typedef struct WASINNRegistry { char **model_names; char **encoding; char **target; @@ -554,7 +554,7 @@ typedef struct WASINNGlobalContext { uint32_t n_graphs; uint32_t *loaded; char **graph_paths; -} WASINNGlobalContext; +} WASINNRegistry; #endif #if WASM_ENABLE_LIBC_WASI != 0 @@ -625,20 +625,11 @@ wasm_runtime_get_exec_env_tls(void); #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNArguments { - char **model_names; - char **encoding; - char **target; - - 
char **graph_paths; - uint32_t n_graphs; -} WASINNArguments; - WASM_RUNTIME_API_EXTERN int -wasi_nn_graph_registry_create(WASINNArguments **registryp); +wasm_runtime_wasi_nn_registry_create(WASINNRegistry **registryp); WASM_RUNTIME_API_EXTERN void -wasi_nn_graph_registry_destroy(WASINNArguments *registry); +wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry); #endif struct InstantiationArgs2 { @@ -647,7 +638,7 @@ struct InstantiationArgs2 { WASIArguments wasi; #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - WASINNArguments nn_registry; + WASINNRegistry nn_registry; #endif }; @@ -809,11 +800,11 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN void -wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, WASINNArguments *registry); +wasm_runtime_instantiation_args_set_wasi_nn_registry( + struct InstantiationArgs2 *p, WASINNRegistry *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, +wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, const char **model_names, const char **encoding, const char **target, uint32_t n_graphs, const char **graph_paths); @@ -1472,55 +1463,44 @@ wasm_runtime_check_and_update_last_used_shared_heap( #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -WASM_RUNTIME_API_EXTERN bool -wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char **model_names, - const char **encoding, const char **target, - const uint32_t n_graphs, - char *graph_paths[], char *error_buf, - uint32_t error_buf_size); - -WASM_RUNTIME_API_EXTERN void -wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); - WASM_RUNTIME_API_EXTERN void -wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - WASINNGlobalContext *wasi_ctx); 
+wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst, + WASINNRegistry *wasi_ctx); -WASM_RUNTIME_API_EXTERN WASINNGlobalContext * -wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm); +WASM_RUNTIME_API_EXTERN WASINNRegistry * +wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm); WASM_RUNTIME_API_EXTERN void -wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, - WASINNGlobalContext *wasi_nn_ctx); +wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, + WASINNRegistry *wasi_nn_ctx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_registry_ngraphs( + WASINNRegistry *wasi_nn_registry); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_model_names_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_graph_paths_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_encoding_i( + 
WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_target_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 0c36595285..37cceaef1d 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -290,8 +290,7 @@ typedef struct InstantiationArgs { #endif /* INSTANTIATION_ARGS_OPTION_DEFINED */ struct InstantiationArgs2; -struct WASINNGlobalContext; -typedef struct WASINNGlobalContext WASINNGlobalContext; +typedef struct WASINNRegistry WASINNRegistry; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -798,36 +797,36 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); -WASM_RUNTIME_API_EXTERN WASINNGlobalContext * -wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); +WASM_RUNTIME_API_EXTERN WASINNRegistry * +wasm_runtime_get_wasi_nn_registry(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_registry_ngraphs( + WASINNRegistry *wasi_nn_registry); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_model_names_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_graph_paths_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t 
-wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_encoding_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_target_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index cfe3840d9c..a59bc9257b 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,18 +3300,6 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - /* Store graphs' path into ctx. 
Graphs will be loaded until user app calls - * load_by_name */ - WASINNArguments *nn_registry = &args->nn_registry; - if (!wasm_runtime_init_wasi_nn_global_ctx( - (WASMModuleInstanceCommon *)module_inst, nn_registry->model_names, - nn_registry->encoding, nn_registry->target, nn_registry->n_graphs, - nn_registry->graph_paths, error_buf, error_buf_size)) { - goto fail; - } -#endif - #if WASM_ENABLE_DEBUG_INTERP != 0 if (!is_sub_inst) { /* Add module instance into module's instance list */ diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index c4e43ede89..33ef0e090e 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -607,9 +607,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - WASINNGlobalContext *wasi_nn_global_ctx = - wasm_runtime_get_wasi_nn_global_ctx(instance); - if (!wasi_nn_global_ctx) { + WASINNRegistry *wasi_nn_registry = + wasm_runtime_get_wasi_nn_registry(instance); + if (!wasi_nn_registry) { NN_ERR_PRINTF("global context is invalid"); res = not_found; goto fail; @@ -618,27 +618,27 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; uint32_t global_n_graphs = - wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); + wasm_runtime_get_wasi_nn_registry_ngraphs(wasi_nn_registry); for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { - char *model_name = wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - wasi_nn_global_ctx, model_idx); + char *model_name = wasm_runtime_get_wasi_nn_registry_model_names_i( + wasi_nn_registry, model_idx); if (model_name && strcmp(nul_terminated_name, model_name) != 0) { continue; } - is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - wasi_nn_global_ctx, model_idx); + is_loaded = wasm_runtime_get_wasi_nn_registry_loaded_i( + wasi_nn_registry, model_idx); char 
*global_model_path_i = - wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - wasi_nn_global_ctx, model_idx); + wasm_runtime_get_wasi_nn_registry_graph_paths_i( + wasi_nn_registry, model_idx); graph_encoding encoding = - str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - wasi_nn_global_ctx, model_idx)); + str2encoding(wasm_runtime_get_wasi_nn_registry_encoding_i( + wasi_nn_registry, model_idx)); execution_target target = - str2target(wasm_runtime_get_wasi_nn_global_ctx_target_i( - wasi_nn_global_ctx, model_idx)); + str2target(wasm_runtime_get_wasi_nn_registry_target_i( + wasi_nn_registry, model_idx)); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); @@ -655,7 +655,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, if (res != success) goto fail; - wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, + wasm_runtime_set_wasi_nn_registry_loaded_i(wasi_nn_registry, model_idx, 1); res = success; break; diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index ab22c47e8d..bbe475119d 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -236,14 +236,13 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) static void wasi_nn_set_init_args(struct InstantiationArgs2 *args, - struct WASINNArguments *nn_registry, + struct WASINNRegistry *nn_registry, wasi_nn_parse_context_t *ctx) { - wasi_nn_graph_registry_set_args(nn_registry, ctx->model_names, + wasm_runtime_wasi_nn_registry_set_args(nn_registry, ctx->model_names, ctx->encoding, ctx->target, ctx->n_graphs, ctx->graph_paths); - wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args, - nn_registry); + wasm_runtime_instantiation_args_set_wasi_nn_registry(args, nn_registry); for (uint32_t i = 0; i < ctx->n_graphs; i++) { if (ctx->model_names[i]) diff --git a/product-mini/platforms/posix/main.c 
b/product-mini/platforms/posix/main.c index d26565f689..58ec567c2b 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -650,7 +650,7 @@ main(int argc, char *argv[]) #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_parse_context_t wasi_nn_parse_ctx; - struct WASINNArguments *nn_registry; + struct WASINNRegistry *nn_registry; memset(&wasi_nn_parse_ctx, 0, sizeof(wasi_nn_parse_ctx)); #endif @@ -1009,19 +1009,21 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - wasi_nn_graph_registry_create(&nn_registry); - wasi_nn_set_init_args(inst_args, nn_registry, &wasi_nn_parse_ctx); -#endif /* instantiate the module */ wasm_module_inst = wasm_runtime_instantiate_ex2( wasm_module, inst_args, error_buf, sizeof(error_buf)); - wasm_runtime_instantiation_args_destroy(inst_args); if (!wasm_module_inst) { printf("%s\n", error_buf); goto fail3; } +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasm_runtime_wasi_nn_registry_create(&nn_registry); + wasi_nn_set_init_args(inst_args, nn_registry, &wasi_nn_parse_ctx); + wasm_runtime_set_wasi_nn_registry(wasm_module_inst, nn_registry); +#endif + wasm_runtime_instantiation_args_destroy(inst_args); + #if WASM_CONFIGURABLE_BOUNDS_CHECKS != 0 if (disable_bounds_checks) { wasm_runtime_set_bounds_checks(wasm_module_inst, false); @@ -1131,9 +1133,6 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: -#endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - wasi_nn_graph_registry_destroy(nn_registry); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst); From a85079a2aa19126b608939085b49273c87b7fd75 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 5 Feb 2026 17:18:40 +0800 Subject: [PATCH 20/28] Check integer overflow for 
wasm_runtime_wasi_nn_registry_set_args --- core/iwasm/common/wasm_runtime_common.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 2d011d79b2..fcfd2c522b 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1812,6 +1812,11 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, return false; } + if ((sizeof(uint32_t *) * n_graphs) >= UINT32_MAX) { + LOG_ERROR("Invalid size for wasm_runtime_malloc."); + return NULL; + } + registry->n_graphs = n_graphs; registry->target = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); registry->encoding = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); From 7b62ec8252e9d6a5a19dc27dfe264a3d115b5aa2 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 5 Feb 2026 17:27:50 +0800 Subject: [PATCH 21/28] Remove strdup in wasi_nn_parse --- product-mini/platforms/common/libc_wasi.c | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index bbe475119d..3559bd1f63 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -222,14 +222,12 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) goto fail; } - ctx->model_names[ctx->n_graphs] = strdup(tokens[0]); - ctx->encoding[ctx->n_graphs] = strdup(tokens[1]); - ctx->target[ctx->n_graphs] = strdup(tokens[2]); - ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[3]); + ctx->model_names[ctx->n_graphs] = tokens[0]; + ctx->encoding[ctx->n_graphs] = tokens[1]; + ctx->target[ctx->n_graphs] = tokens[2]; + ctx->graph_paths[ctx->n_graphs++] = tokens[3]; fail: - if (token) - free(token); return ret; } @@ -243,16 +241,5 @@ wasi_nn_set_init_args(struct InstantiationArgs2 *args, ctx->encoding, ctx->target, ctx->n_graphs, ctx->graph_paths); 
wasm_runtime_instantiation_args_set_wasi_nn_registry(args, nn_registry); - - for (uint32_t i = 0; i < ctx->n_graphs; i++) { - if (ctx->model_names[i]) - free(ctx->model_names[i]); - if (ctx->graph_paths[i]) - free(ctx->graph_paths[i]); - if (ctx->encoding[i]) - free(ctx->encoding[i]); - if (ctx->target[i]) - free(ctx->target[i]); - } } #endif \ No newline at end of file From 01641bfc70a0e9753aa94170c18e82ff4def5ad5 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 5 Feb 2026 17:37:04 +0800 Subject: [PATCH 22/28] Remove "encoding" for load_by_name --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h | 3 +-- core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp | 2 +- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 33ef0e090e..b2cc4d1119 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -651,7 +651,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, "Model is not yet loaded, will add to global context"); call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, wasi_nn_ctx->backend_ctx, global_model_path_i, - strlen(global_model_path_i), encoding, target, g); + strlen(global_model_path_i), target, g); if (res != success) goto fail; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 344e66550b..6a75c33090 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -18,7 +18,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, 
__attribute__((visibility("default"))) wasi_nn_error load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, - graph_encoding encoding, execution_target target, graph *g); + execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c index 7042affa70..cef1479af6 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c @@ -339,7 +339,7 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) + execution_target target, graph *g) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 947fa558e3..352294d6dd 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -335,7 +335,7 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) + execution_target target, graph *g) { if (!onnx_ctx) { return runtime_error; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c index 4739953605..084a97b0da 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c @@ -307,7 +307,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, 
__attribute__((visibility("default"))) wasi_nn_error load_by_name(void *ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) + execution_target target, graph *g) { OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx; struct OpenVINOGraph *graph; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 7ea76eddb1..8561d2f067 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -22,8 +22,7 @@ typedef struct { typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, execution_target, graph *); typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, - graph_encoding, execution_target, - graph *); + execution_target, graph *); typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *, uint32_t, void *, uint32_t, graph *); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 2b4832dc41..e413a7e4f7 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -165,7 +165,7 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) + execution_target target, graph *g) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; From eb55c2814ea295087a38331dd2f730c9c4d83f43 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 10:16:29 +0800 Subject: [PATCH 23/28] Remove some internal functions used by wasi-nn from wasm_export.h --- core/iwasm/common/wasm_runtime_common.c | 74 ---------------------- core/iwasm/common/wasm_runtime_common.h | 28 -------- 
core/iwasm/include/wasm_export.h | 34 +++------- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 22 +++---- 4 files changed, 16 insertions(+), 142 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index fcfd2c522b..d99467e7f3 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -8179,77 +8179,3 @@ wasm_runtime_check_and_update_last_used_shared_heap( return false; } #endif - -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - -uint32_t -wasm_runtime_get_wasi_nn_registry_ngraphs( - WASINNRegistry *wasi_nn_registry) -{ - if (wasi_nn_registry) - return wasi_nn_registry->n_graphs; - - return -1; -} - -char * -wasm_runtime_get_wasi_nn_registry_model_names_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx) -{ - if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) - return wasi_nn_registry->model_names[idx]; - - return NULL; -} - -char * -wasm_runtime_get_wasi_nn_registry_graph_paths_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx) -{ - if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) - return wasi_nn_registry->graph_paths[idx]; - - return NULL; -} - -uint32_t -wasm_runtime_get_wasi_nn_registry_loaded_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx) -{ - if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) - return wasi_nn_registry->loaded[idx]; - - return -1; -} - -uint32_t -wasm_runtime_set_wasi_nn_registry_loaded_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value) -{ - if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) - wasi_nn_registry->loaded[idx] = value; - - return 0; -} - -char * -wasm_runtime_get_wasi_nn_registry_encoding_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx) -{ - if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) - return wasi_nn_registry->encoding[idx]; - - return NULL; -} - -char * -wasm_runtime_get_wasi_nn_registry_target_i( - WASINNRegistry 
*wasi_nn_registry, uint32_t idx) -{ - if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) - return wasi_nn_registry->target[idx]; - - return NULL; -} - -#endif diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 0f7b1fbce3..80b0ea05fe 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -1473,34 +1473,6 @@ wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, WASINNRegistry *wasi_nn_ctx); - -WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_registry_ngraphs( - WASINNRegistry *wasi_nn_registry); - -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_model_names_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_graph_paths_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_registry_loaded_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_registry_loaded_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value); - -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_encoding_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_target_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 37cceaef1d..8f4a61cdcb 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -797,36 +797,18 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); -WASM_RUNTIME_API_EXTERN WASINNRegistry * 
+WASM_RUNTIME_API_EXTERN struct WASINNRegistry * wasm_runtime_get_wasi_nn_registry(const wasm_module_inst_t module_inst); -WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_registry_ngraphs( - WASINNRegistry *wasi_nn_registry); - -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_model_names_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_graph_paths_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_registry_loaded_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); - -WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_registry_loaded_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value); +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_registry(wasm_module_inst_t module_inst, + struct WASINNRegistry *wasi_ctx); -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_encoding_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); +WASM_RUNTIME_API_EXTERN int +wasm_runtime_wasi_nn_registry_create(struct WASINNRegistry **registryp); -WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_registry_target_i( - WASINNRegistry *wasi_nn_registry, uint32_t idx); +WASM_RUNTIME_API_EXTERN void +wasm_runtime_wasi_nn_registry_destroy(struct WASINNRegistry *registry); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index b2cc4d1119..1e63813cf4 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -19,6 +19,7 @@ #include "bh_platform.h" #include "wasi_nn_types.h" #include "wasm_export.h" +#include "wasm_runtime_common.h" #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 #warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. 
It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) @@ -618,27 +619,21 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; uint32_t global_n_graphs = - wasm_runtime_get_wasi_nn_registry_ngraphs(wasi_nn_registry); + wasi_nn_registry->n_graphs; for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { - char *model_name = wasm_runtime_get_wasi_nn_registry_model_names_i( - wasi_nn_registry, model_idx); + char *model_name = wasi_nn_registry->model_names[model_idx]; if (model_name && strcmp(nul_terminated_name, model_name) != 0) { continue; } - is_loaded = wasm_runtime_get_wasi_nn_registry_loaded_i( - wasi_nn_registry, model_idx); - char *global_model_path_i = - wasm_runtime_get_wasi_nn_registry_graph_paths_i( - wasi_nn_registry, model_idx); + is_loaded = wasi_nn_registry->loaded[model_idx]; + char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; graph_encoding encoding = - str2encoding(wasm_runtime_get_wasi_nn_registry_encoding_i( - wasi_nn_registry, model_idx)); + str2encoding(wasi_nn_registry->encoding[model_idx]); execution_target target = - str2target(wasm_runtime_get_wasi_nn_registry_target_i( - wasi_nn_registry, model_idx)); + str2target(wasi_nn_registry->target[model_idx]); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); @@ -655,8 +650,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, if (res != success) goto fail; - wasm_runtime_set_wasi_nn_registry_loaded_i(wasi_nn_registry, - model_idx, 1); + wasi_nn_registry->loaded[model_idx] = 1; res = success; break; } From ccee1941c29d811f04ae4b29308a71da2aefce80 Mon Sep 17 00:00:00 2001 From: zhanheng1 
Date: Fri, 6 Feb 2026 10:26:01 +0800 Subject: [PATCH 24/28] Removed unnecessary conditional checks. --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 1e63813cf4..79465a8daa 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -623,7 +623,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { char *model_name = wasi_nn_registry->model_names[model_idx]; - if (model_name && strcmp(nul_terminated_name, model_name) != 0) { + if (strcmp(nul_terminated_name, model_name) != 0) { continue; } From 5357fb5f21dc957f43c4f44aef48b96a2a37a5f9 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 11:13:12 +0800 Subject: [PATCH 25/28] Move the error checks to an earlier stage. --- core/iwasm/common/wasm_runtime_common.c | 37 +++++----- core/iwasm/common/wasm_runtime_common.h | 10 +-- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 46 +----------- product-mini/platforms/common/libc_wasi.c | 85 +++++++++++++++++++--- 4 files changed, 100 insertions(+), 78 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index d99467e7f3..ccd569cdc3 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1804,8 +1804,8 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args) bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const char **encoding, - const char **target, uint32_t n_graphs, + const char **model_names, const uint32_t **encoding, + const uint32_t **target, uint32_t n_graphs, const char **graph_paths) { if (!registry || !model_names || !encoding || !target || !graph_paths) { @@ -1832,8 +1832,8 @@ 
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, for (uint32_t i = 0; i < registry->n_graphs; i++) { registry->graph_paths[i] = bh_strdup(graph_paths[i]); registry->model_names[i] = bh_strdup(model_names[i]); - registry->encoding[i] = bh_strdup(encoding[i]); - registry->target[i] = bh_strdup(target[i]); + registry->encoding[i] = encoding[i]; + registry->target[i] = target[i]; } return true; @@ -1860,13 +1860,13 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry) wasm_runtime_free(registry->graph_paths[i]); if (registry->model_names[i]) wasm_runtime_free(registry->model_names[i]); - if (registry->encoding[i]) - wasm_runtime_free(registry->encoding[i]); - if (registry->target[i]) - wasm_runtime_free(registry->target[i]); } - if (registry->loaded) - wasm_runtime_free(registry->loaded); + if (registry->encoding) + wasm_runtime_free(registry->encoding); + if (registry->target) + wasm_runtime_free(registry->target); + if (registry->loaded) + wasm_runtime_free(registry->loaded); wasm_runtime_free(registry); } } @@ -1881,16 +1881,13 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry( wasi_nn_registry->n_graphs = registry->n_graphs; - if (registry->model_names) - wasi_nn_registry->model_names = bh_strdup(registry->model_names); - if (registry->encoding) - wasi_nn_registry->encoding = bh_strdup(registry->encoding); - if (registry->target) - wasi_nn_registry->target = bh_strdup(registry->target); - if (registry->loaded) - wasi_nn_registry->loaded = bh_strdup(registry->loaded); - if (registry->graph_paths) - wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths); + for (uint32_t i = 0; i < registry->n_graphs; i++) { + registry->graph_paths[i] = bh_strdup(registry->graph_paths[i]); + registry->model_names[i] = bh_strdup(registry->model_names[i]); + registry->encoding[i] = registry->encoding[i]; + registry->target[i] = registry->target[i]; + wasi_nn_registry->loaded = registry->loaded; + } } #endif diff --git 
a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 80b0ea05fe..23aa451266 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -548,11 +548,11 @@ typedef struct WASMModuleInstMemConsumption { #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNRegistry { char **model_names; - char **encoding; - char **target; + uint32_t **encoding; + uint32_t **target; uint32_t n_graphs; - uint32_t *loaded; + uint32_t **loaded; char **graph_paths; } WASINNRegistry; #endif @@ -805,8 +805,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry( WASM_RUNTIME_API_EXTERN bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const char **encoding, - const char **target, uint32_t n_graphs, + const char **model_names, const uint32_t **encoding, + const uint32_t **target, uint32_t n_graphs, const char **graph_paths); #endif diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 79465a8daa..d0bfe0f2e4 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -211,46 +211,6 @@ wasi_nn_destroy() * - model file format * - on device ML framework */ -static graph_encoding -str2encoding(char *str_encoding) -{ - if (!str_encoding) { - NN_ERR_PRINTF("Got empty string encoding"); - return -1; - } - - if (!strcmp(str_encoding, "openvino")) - return openvino; - else if (!strcmp(str_encoding, "tensorflowlite")) - return tensorflowlite; - else if (!strcmp(str_encoding, "ggml")) - return ggml; - else if (!strcmp(str_encoding, "onnx")) - return onnx; - else - return unknown_backend; - // return autodetect; -} - -static execution_target -str2target(char *str_target) -{ - if (!str_target) { - NN_ERR_PRINTF("Got empty string target"); - return -1; - } - - if (!strcmp(str_target, "cpu")) - return cpu; - else if (!strcmp(str_target, "gpu")) - 
return gpu; - else if (!strcmp(str_target, "tpu")) - return tpu; - else - return unsupported_target; - // return autodetect; -} - static graph_encoding choose_a_backend() { @@ -630,10 +590,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, is_loaded = wasi_nn_registry->loaded[model_idx]; char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; - graph_encoding encoding = - str2encoding(wasi_nn_registry->encoding[model_idx]); - execution_target target = - str2target(wasi_nn_registry->target[model_idx]); + graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]); + execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 3559bd1f63..b7fa53bbc7 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -21,19 +21,77 @@ typedef struct { uint32 ns_lookup_pool_size; } libc_wasi_parse_context_t; +typedef enum { + LIBC_WASI_PARSE_RESULT_OK = 0, + LIBC_WASI_PARSE_RESULT_NEED_HELP, + LIBC_WASI_PARSE_RESULT_BAD_PARAM +} libc_wasi_parse_result_t; + typedef struct { const char *model_names[10]; - const char *encoding[10]; - const char *target[10]; + const uint32_t *encoding[10]; + const uint32_t *target[10]; const char *graph_paths[10]; uint32 n_graphs; } wasi_nn_parse_context_t; typedef enum { - LIBC_WASI_PARSE_RESULT_OK = 0, - LIBC_WASI_PARSE_RESULT_NEED_HELP, - LIBC_WASI_PARSE_RESULT_BAD_PARAM -} libc_wasi_parse_result_t; + wasi_nn_openvino = 0, + wasi_nn_onnx, + wasi_nn_tensorflow, + wasi_nn_pytorch, + wasi_nn_tensorflowlite, + wasi_nn_ggml, + wasi_nn_autodetect, + wasi_nn_unknown_backend, +} wasi_nn_encoding; + +typedef enum wasi_nn_target { + wasi_nn_cpu = 0, + wasi_nn_gpu, + wasi_nn_tpu, + 
wasi_nn_unsupported_target, +} wasi_nn_target; + +static wasi_nn_encoding +str2encoding(char *str_encoding) +{ + if (!str_encoding) { + printf("Got empty string encoding"); + return -1; + } + + if (!strcmp(str_encoding, "openvino")) + return wasi_nn_openvino; + else if (!strcmp(str_encoding, "tensorflowlite")) + return wasi_nn_tensorflowlite; + else if (!strcmp(str_encoding, "ggml")) + return wasi_nn_ggml; + else if (!strcmp(str_encoding, "onnx")) + return wasi_nn_onnx; + else + return wasi_nn_unknown_backend; + // return autodetect; +} + +static wasi_nn_target +str2target(char *str_target) +{ + if (!str_target) { + printf("Got empty string target"); + return -1; + } + + if (!strcmp(str_target, "cpu")) + return wasi_nn_cpu; + else if (!strcmp(str_target, "gpu")) + return wasi_nn_gpu; + else if (!strcmp(str_target, "tpu")) + return wasi_nn_tpu; + else + return wasi_nn_unsupported_target; + // return autodetect; +} static void libc_wasi_print_help(void) @@ -223,10 +281,19 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) } ctx->model_names[ctx->n_graphs] = tokens[0]; - ctx->encoding[ctx->n_graphs] = tokens[1]; - ctx->target[ctx->n_graphs] = tokens[2]; - ctx->graph_paths[ctx->n_graphs++] = tokens[3]; + ctx->encoding[ctx->n_graphs] = (uint32_t)str2encoding(tokens[1]); + ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]); + ctx->graph_paths[ctx->n_graphs] = tokens[3]; + + if ((!ctx->model_names[ctx->n_graphs]) || + (ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) || + (ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) { + ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; + printf("Invalid arguments for wasi-nn.\n"); + goto fail; + } + ctx->n_graphs++; fail: return ret; From 4747d61912f83200585fa5b029a82312b0d3dfb2 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 11:19:29 +0800 Subject: [PATCH 26/28] Use clang-format-18 to format source file --- core/iwasm/common/wasm_native.c | 2 +- core/iwasm/common/wasm_runtime_common.c | 39 
+++++++++++-------- core/iwasm/common/wasm_runtime_common.h | 12 +++--- core/iwasm/include/wasm_export.h | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 9 +++-- .../libraries/wasi-nn/src/wasi_nn_backend.h | 2 +- core/iwasm/libraries/wasi-nn/test/utils.c | 2 +- core/iwasm/libraries/wasi-nn/test/utils.h | 22 +++++------ product-mini/platforms/common/libc_wasi.c | 18 ++++----- 9 files changed, 59 insertions(+), 49 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 2ba4a5778d..2b49c052ec 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -486,7 +486,7 @@ wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm) void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, - WASINNRegistry *wasi_nn_ctx) + WASINNRegistry *wasi_nn_ctx) { wasm_native_set_context(module_inst_comm, g_wasi_nn_registry_key, wasi_nn_ctx); diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index ccd569cdc3..61b8e77961 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1804,9 +1804,11 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args) bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const uint32_t **encoding, - const uint32_t **target, uint32_t n_graphs, - const char **graph_paths) + const char **model_names, + const uint32_t **encoding, + const uint32_t **target, + uint32_t n_graphs, + const char **graph_paths) { if (!registry || !model_names || !encoding || !target || !graph_paths) { return false; @@ -1818,11 +1820,16 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, } registry->n_graphs = n_graphs; - registry->target = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->encoding = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - 
registry->loaded = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->model_names = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->graph_paths = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->target = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->encoding = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->loaded = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->model_names = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->graph_paths = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs); @@ -1858,15 +1865,15 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry) for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { wasm_runtime_free(registry->graph_paths[i]); - if (registry->model_names[i]) - wasm_runtime_free(registry->model_names[i]); + if (registry->model_names[i]) + wasm_runtime_free(registry->model_names[i]); } - if (registry->encoding) - wasm_runtime_free(registry->encoding); - if (registry->target) - wasm_runtime_free(registry->target); - if (registry->loaded) - wasm_runtime_free(registry->loaded); + if (registry->encoding) + wasm_runtime_free(registry->encoding); + if (registry->target) + wasm_runtime_free(registry->target); + if (registry->loaded) + wasm_runtime_free(registry->loaded); wasm_runtime_free(registry); } } diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 23aa451266..bc6cbbbf15 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -805,9 +805,11 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry( 
WASM_RUNTIME_API_EXTERN bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const uint32_t **encoding, - const uint32_t **target, uint32_t n_graphs, - const char **graph_paths); + const char **model_names, + const uint32_t **encoding, + const uint32_t **target, + uint32_t n_graphs, + const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1465,14 +1467,14 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst, - WASINNRegistry *wasi_ctx); + WASINNRegistry *wasi_ctx); WASM_RUNTIME_API_EXTERN WASINNRegistry * wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, - WASINNRegistry *wasi_nn_ctx); + WASINNRegistry *wasi_nn_ctx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 8f4a61cdcb..885e160c0f 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -802,7 +802,7 @@ wasm_runtime_get_wasi_nn_registry(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(wasm_module_inst_t module_inst, - struct WASINNRegistry *wasi_ctx); + struct WASINNRegistry *wasi_ctx); WASM_RUNTIME_API_EXTERN int wasm_runtime_wasi_nn_registry_create(struct WASINNRegistry **registryp); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index d0bfe0f2e4..8effa8fd39 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -578,8 +578,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; - uint32_t global_n_graphs = - 
wasi_nn_registry->n_graphs; + uint32_t global_n_graphs = wasi_nn_registry->n_graphs; for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { char *model_name = wasi_nn_registry->model_names[model_idx]; @@ -590,8 +589,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, is_loaded = wasi_nn_registry->loaded[model_idx]; char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; - graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]); - execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]); + graph_encoding encoding = + (graph_encoding)(wasi_nn_registry->encoding[model_idx]); + execution_target target = + (execution_target)(wasi_nn_registry->target[model_idx]); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 6a75c33090..6e2e5a4a9d 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -18,7 +18,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, - execution_target target, graph *g); + execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 64247e8d37..5bff60d6e5 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -151,7 +151,7 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size, WASI_NN_NAME(graph) graph; wasi_nn_error_t res = wasm_load_by_name(model_name, &graph); - + if (res == 
WASI_NN_ERROR_NAME(not_found)) { NN_INFO_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index ac3acd3478..ff14b209fe 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -35,27 +35,27 @@ typedef wasi_nn_error wasi_nn_error_t; /* wasi-nn wrappers */ wasi_nn_error_t -wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); +wasm_load(char *model_name, WASI_NN_NAME(graph) * g, + WASI_NN_NAME(execution_target) target); -wasi_nn_error_t -wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); +wasi_nn_error_t wasm_init_execution_context( + WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx); wasi_nn_error_t -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, + uint32_t *dim); -wasi_nn_error_t -wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); +wasi_nn_error_t wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); wasi_nn_error_t -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, - uint32_t *out_size); +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, + float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(float *input, uint32_t *input_size, - uint32_t *output_size, char *model_name, - uint32_t num_output_tensors); +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, + char *model_name, uint32_t num_output_tensors); input_info create_input(int *dims); diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index b7fa53bbc7..137521fbd4 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ 
b/product-mini/platforms/common/libc_wasi.c @@ -285,13 +285,13 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]); ctx->graph_paths[ctx->n_graphs] = tokens[3]; - if ((!ctx->model_names[ctx->n_graphs]) || - (ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) || - (ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) { - ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; - printf("Invalid arguments for wasi-nn.\n"); - goto fail; - } + if ((!ctx->model_names[ctx->n_graphs]) + || (ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) + || (ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) { + ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; + printf("Invalid arguments for wasi-nn.\n"); + goto fail; + } ctx->n_graphs++; fail: @@ -305,8 +305,8 @@ wasi_nn_set_init_args(struct InstantiationArgs2 *args, wasi_nn_parse_context_t *ctx) { wasm_runtime_wasi_nn_registry_set_args(nn_registry, ctx->model_names, - ctx->encoding, ctx->target, ctx->n_graphs, - ctx->graph_paths); + ctx->encoding, ctx->target, + ctx->n_graphs, ctx->graph_paths); wasm_runtime_instantiation_args_set_wasi_nn_registry(args, nn_registry); } #endif \ No newline at end of file From 9e89828a3437227db317960ee31ce34cffa58037 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 14:48:36 +0800 Subject: [PATCH 27/28] Make wasi_nn_load_by_name and wasi_nn_load_by_name_with_config share a common logic. 
--- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 106 +++++++++++++-------- 1 file changed, 65 insertions(+), 41 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 8effa8fd39..4feff102b6 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -535,38 +535,13 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len, return success; } -wasi_nn_error -wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, - graph *g) +static wasi_nn_error +load_by_name_with_optional_config(WASINNContext *wasi_nn_ctx, + wasm_module_inst_t instance, bool use_config, + graph *g, const char *model_name, + const char *config, int32_t config_len) { - WASINNContext *wasi_nn_ctx = NULL; - char *nul_terminated_name = NULL; - wasi_nn_error res; - - wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); - if (!instance) { - return runtime_error; - } - - if (!wasm_runtime_validate_native_addr(instance, g, - (uint64)sizeof(graph))) { - NN_ERR_PRINTF("graph is invalid"); - return invalid_argument; - } - - res = copyin_and_nul_terminate(instance, name, name_len, - &nul_terminated_name); - if (res != success) { - goto fail; - } - - NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", nul_terminated_name); - - wasi_nn_ctx = lock_ctx(instance); - if (wasi_nn_ctx == NULL) { - res = busy; - goto fail; - } + wasi_nn_error res = success; WASINNRegistry *wasi_nn_registry = wasm_runtime_get_wasi_nn_registry(instance); @@ -580,9 +555,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, uint32 model_idx = 0; uint32_t global_n_graphs = wasi_nn_registry->n_graphs; for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { - char *model_name = wasi_nn_registry->model_names[model_idx]; + char *model_name_i = wasi_nn_registry->model_names[model_idx]; - if (strcmp(nul_terminated_name, model_name) != 0) { + 
if (strcmp(model_name, model_name_i) != 0) { continue; } @@ -594,7 +569,6 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]); - // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; @@ -603,9 +577,17 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, && (model_idx < global_n_graphs)) { NN_DBG_PRINTF( "Model is not yet loaded, will add to global context"); - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, global_model_path_i, - strlen(global_model_path_i), target, g); + if (use_config && config && config_len > 0) { + call_wasi_nn_func( + wasi_nn_ctx->backend, load_by_name_with_config, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), config, config_len, g); + } + else { + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), target, g); + } if (res != success) goto fail; @@ -627,9 +609,51 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, else if (model_idx >= global_n_graphs) { NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", - nul_terminated_name); + model_name); res = not_found; } + +fail: + + return res; +} + +wasi_nn_error +wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, + graph *g) +{ + WASINNContext *wasi_nn_ctx = NULL; + char *nul_terminated_name = NULL; + wasi_nn_error res; + + wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); + if (!instance) { + return runtime_error; + } + + if (!wasm_runtime_validate_native_addr(instance, g, + (uint64)sizeof(graph))) { + NN_ERR_PRINTF("graph is invalid"); + return invalid_argument; + } + + res = 
copyin_and_nul_terminate(instance, name, name_len, + &nul_terminated_name); + if (res != success) { + goto fail; + } + + NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", nul_terminated_name); + + wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + + res = load_by_name_with_optional_config(wasi_nn_ctx, instance, false, g, + nul_terminated_name, NULL, 0); + fail: if (nul_terminated_name != NULL) { wasm_runtime_free(nul_terminated_name); @@ -686,9 +710,9 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, goto fail; ; - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, - wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, - nul_terminated_config, config_len, g); + res = load_by_name_with_optional_config(wasi_nn_ctx, instance, true, g, + nul_terminated_name, + nul_terminated_config, config_len); if (res != success) goto fail; From 7822544dbeb58731abf5ae341d567f08cff58ff3 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 15:01:00 +0800 Subject: [PATCH 28/28] Fix compilation errors for nuttx platform. 
--- product-mini/platforms/common/libc_wasi.c | 80 +++++++++++------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 137521fbd4..a7d08b62f5 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -53,46 +53,6 @@ typedef enum wasi_nn_target { wasi_nn_unsupported_target, } wasi_nn_target; -static wasi_nn_encoding -str2encoding(char *str_encoding) -{ - if (!str_encoding) { - printf("Got empty string encoding"); - return -1; - } - - if (!strcmp(str_encoding, "openvino")) - return wasi_nn_openvino; - else if (!strcmp(str_encoding, "tensorflowlite")) - return wasi_nn_tensorflowlite; - else if (!strcmp(str_encoding, "ggml")) - return wasi_nn_ggml; - else if (!strcmp(str_encoding, "onnx")) - return wasi_nn_onnx; - else - return wasi_nn_unknown_backend; - // return autodetect; -} - -static wasi_nn_target -str2target(char *str_target) -{ - if (!str_target) { - printf("Got empty string target"); - return -1; - } - - if (!strcmp(str_target, "cpu")) - return wasi_nn_cpu; - else if (!strcmp(str_target, "gpu")) - return wasi_nn_gpu; - else if (!strcmp(str_target, "tpu")) - return wasi_nn_tpu; - else - return wasi_nn_unsupported_target; - // return autodetect; -} - static void libc_wasi_print_help(void) { @@ -245,6 +205,46 @@ libc_wasi_set_init_args(struct InstantiationArgs2 *args, int argc, char **argv, } #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +static wasi_nn_encoding +str2encoding(char *str_encoding) +{ + if (!str_encoding) { + printf("Got empty string encoding"); + return -1; + } + + if (!strcmp(str_encoding, "openvino")) + return wasi_nn_openvino; + else if (!strcmp(str_encoding, "tensorflowlite")) + return wasi_nn_tensorflowlite; + else if (!strcmp(str_encoding, "ggml")) + return wasi_nn_ggml; + else if (!strcmp(str_encoding, "onnx")) + return wasi_nn_onnx; + else + return 
wasi_nn_unknown_backend; + // return autodetect; +} + +static wasi_nn_target +str2target(char *str_target) +{ + if (!str_target) { + printf("Got empty string target"); + return -1; + } + + if (!strcmp(str_target, "cpu")) + return wasi_nn_cpu; + else if (!strcmp(str_target, "gpu")) + return wasi_nn_gpu; + else if (!strcmp(str_target, "tpu")) + return wasi_nn_tpu; + else + return wasi_nn_unsupported_target; + // return autodetect; +} + libc_wasi_parse_result_t wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) {