diff --git a/src/vecenv.h b/src/vecenv.h index 6464de89f..cc0fc8c3f 100644 --- a/src/vecenv.h +++ b/src/vecenv.h @@ -193,6 +193,8 @@ void my_init(Env* env, Dict* kwargs); void my_log(Log* log, Dict* out); +typedef struct StaticOMPArg StaticOMPArg; + struct StaticThreading { atomic_int* buffer_states; atomic_int shutdown; @@ -200,16 +202,17 @@ struct StaticThreading { int num_buffers; pthread_t* threads; float* accum; // [num_buffers * NUM_EVAL_PROF] per-buffer timing in ms + StaticOMPArg* thread_args; }; -typedef struct StaticOMPArg { +struct StaticOMPArg { StaticVec* vec; int buf; int horizon; void* ctx; net_callback_fn net_callback; thread_init_fn thread_init; -} StaticOMPArg; +}; // OMP thread manager static void* static_omp_threadmanager(void* arg) { @@ -468,7 +471,8 @@ void create_static_threads(StaticVec* vec, int num_threads, int horizon, // Streams are now created by pufferlib.cu (PyTorch-managed streams) // Do NOT create streams here - they've already been set up - StaticOMPArg* args = (StaticOMPArg*)calloc(vec->buffers, sizeof(StaticOMPArg)); + vec->threading->thread_args = (StaticOMPArg*)calloc(vec->buffers, sizeof(StaticOMPArg)); + StaticOMPArg* args = vec->threading->thread_args; for (int i = 0; i < vec->buffers; i++) { args[i].vec = vec; args[i].buf = i; @@ -501,6 +505,7 @@ void static_vec_close(StaticVec* vec) { free(vec->threading->buffer_states); free(vec->threading->threads); free(vec->threading->accum); + free(vec->threading->thread_args); free(vec->threading); } free(vec->buffer_env_starts);