|
126 | 126 | None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p |
127 | 127 | ) |
128 | 128 |
|
| 129 | +# typedef struct ggml_threadpool * ggml_threadpool_t; |
| 130 | +ggml_threadpool_t = ctypes.c_void_p |
| 131 | + |
129 | 132 | # llama.h bindings |
130 | 133 |
|
131 | 134 | _lib.llama_max_devices.argtypes = [] |
|
186 | 189 | # typedef int32_t llama_seq_id; |
187 | 190 | llama_seq_id = ctypes.c_int32 |
188 | 191 |
|
| 192 | +# typedef uint32_t llama_state_seq_flags; |
| 193 | +llama_state_seq_flags = ctypes.c_uint32 |
| 194 | + |
| 195 | +# State sequence flags |
| 196 | +LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1 |
| 197 | +LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 1  # same value as SWA_ONLY: upstream renamed the flag, so both names are kept as aliases |
| 198 | + |
189 | 199 |
|
190 | 200 | # enum llama_vocab_type { |
191 | 201 | # LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab |
@@ -1197,11 +1207,26 @@ def llama_numa_init(numa: int, /): ... |
1197 | 1207 | # struct llama_context * ctx, |
1198 | 1208 | # ggml_threadpool_t threadpool, |
1199 | 1209 | # ggml_threadpool_t threadpool_batch); |
1200 | | -# TODO: Add llama_attach_threadpool |
| 1210 | +@ctypes_function( |
| 1211 | + "llama_attach_threadpool", |
| 1212 | + [llama_context_p_ctypes, ggml_threadpool_t, ggml_threadpool_t], |
| 1213 | + None, |
| 1214 | +) |
| 1215 | +def llama_attach_threadpool( |
| 1216 | + ctx: llama_context_p, |
| 1217 | + threadpool: int, |
| 1218 | + threadpool_batch: int, |
| 1219 | + /, |
| 1220 | +): |
| 1221 | +    """Attach threadpools to the context (one for single-token generation, one for batch processing)""" |
| 1222 | + ... |
1201 | 1223 |
|
1202 | 1224 |
|
1203 | 1225 | # LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); |
1204 | | -# TODO: Add llama_detach_threadpool |
| 1226 | +@ctypes_function("llama_detach_threadpool", [llama_context_p_ctypes], None) |
| 1227 | +def llama_detach_threadpool(ctx: llama_context_p, /): |
| 1228 | +    """Detach threadpools from the context""" |
| 1229 | + ... |
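
A minimal usage sketch for the new pair, assuming `ctx` is a live `llama_context_p` and `tp` / `tp_batch` are opaque ggml threadpool handles (`c_void_p` values) created through ggml bindings that this diff does not cover:

    llama_attach_threadpool(ctx, tp, tp_batch)
    try:
        # run llama_decode / llama_encode calls that should use the attached pools
        ...
    finally:
        llama_detach_threadpool(ctx)  # detach before the pools are released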
1205 | 1230 |
|
1206 | 1231 |
|
1207 | 1232 | # DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( |
@@ -1375,6 +1400,43 @@ def llama_supports_rpc() -> bool: ... |
1375 | 1400 | def llama_max_tensor_buft_overrides() -> int: |
1376 | 1401 | """Get maximum number of tensor buffer type overrides""" |
1377 | 1402 | ... |
| 1403 | + |
| 1404 | + |
| 1405 | +# LLAMA_API enum llama_params_fit_status llama_params_fit( |
| 1406 | +# const char * path_model, |
| 1407 | +# struct llama_model_params * mparams, |
| 1408 | +# struct llama_context_params * cparams, |
| 1409 | +# float * tensor_split, |
| 1410 | +# struct llama_model_tensor_buft_override * tensor_buft_overrides, |
| 1411 | +# size_t n_buft_overrides); |
| 1412 | +@ctypes_function( |
| 1413 | + "llama_params_fit", |
| 1414 | + [ |
| 1415 | + ctypes.c_char_p, |
| 1416 | + ctypes.POINTER(llama_model_params), |
| 1417 | + ctypes.POINTER(llama_context_params), |
| 1418 | + ctypes.POINTER(ctypes.c_float), |
| 1419 | + ctypes.c_void_p, # tensor_buft_overrides - not fully bound |
| 1420 | + ctypes.c_size_t, |
| 1421 | + ], |
| 1422 | + ctypes.c_int, |
| 1423 | +) |
| 1424 | +def llama_params_fit( |
| 1425 | + path_model: bytes, |
| 1426 | + mparams: CtypesPointerOrRef[llama_model_params], |
| 1427 | + cparams: CtypesPointerOrRef[llama_context_params], |
| 1428 | + tensor_split: CtypesArray[ctypes.c_float], |
| 1429 | + tensor_buft_overrides: int, |
| 1430 | + n_buft_overrides: Union[ctypes.c_size_t, int], |
| 1431 | + /, |
| 1432 | +) -> int: |
| 1433 | +    """Project whether the model and context parameters will fit in available memory; mparams, cparams, and tensor_split are passed by non-const pointer and may be adjusted in place. |
| 1434 | +
|
| 1435 | + Returns: |
| 1436 | + LLAMA_PARAMS_FIT_STATUS_SUCCESS (0) - found allocations that are projected to fit |
| 1437 | + LLAMA_PARAMS_FIT_STATUS_FAILURE (1) - could not find allocations that are projected to fit |
| 1438 | + LLAMA_PARAMS_FIT_STATUS_ERROR (2) - a hard error occurred |
| 1439 | + """ |
1378 | 1440 | ... |
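
A hedged sketch of calling the probe before committing to a full model load; the model path is hypothetical, and the status constant is written as a literal since this diff does not bind the enum:

    mparams = llama_model_default_params()
    cparams = llama_context_default_params()
    tensor_split = (ctypes.c_float * llama_max_devices())()
    status = llama_params_fit(
        b"models/model.gguf",   # hypothetical path
        ctypes.byref(mparams),
        ctypes.byref(cparams),
        tensor_split,
        None,                   # tensor_buft_overrides: pass NULL (not fully bound)
        0,
    )
    if status == 0:             # LLAMA_PARAMS_FIT_STATUS_SUCCESS
        ...                     # proceed to load with the (possibly adjusted) params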
1379 | 1441 |
|
1380 | 1442 |
|
@@ -2515,6 +2577,83 @@ def llama_state_seq_load_file( |
2515 | 2577 | ) -> int: ... |
2516 | 2578 |
|
2517 | 2579 |
|
| 2580 | +# LLAMA_API size_t llama_state_seq_get_size_ext( |
| 2581 | +# struct llama_context * ctx, |
| 2582 | +# llama_seq_id seq_id, |
| 2583 | +# llama_state_seq_flags flags); |
| 2584 | +@ctypes_function( |
| 2585 | + "llama_state_seq_get_size_ext", |
| 2586 | + [llama_context_p_ctypes, llama_seq_id, llama_state_seq_flags], |
| 2587 | + ctypes.c_size_t, |
| 2588 | +) |
| 2589 | +def llama_state_seq_get_size_ext( |
| 2590 | + ctx: llama_context_p, |
| 2591 | + seq_id: Union[llama_seq_id, int], |
| 2592 | + flags: Union[llama_state_seq_flags, int], |
| 2593 | + /, |
| 2594 | +) -> int: |
| 2595 | + """Get size needed to copy sequence state with flags""" |
| 2596 | + ... |
| 2597 | + |
| 2598 | + |
| 2599 | +# LLAMA_API size_t llama_state_seq_get_data_ext( |
| 2600 | +# struct llama_context * ctx, |
| 2601 | +# uint8_t * dst, |
| 2602 | +# size_t size, |
| 2603 | +# llama_seq_id seq_id, |
| 2604 | +# llama_state_seq_flags flags); |
| 2605 | +@ctypes_function( |
| 2606 | + "llama_state_seq_get_data_ext", |
| 2607 | + [ |
| 2608 | + llama_context_p_ctypes, |
| 2609 | + ctypes.POINTER(ctypes.c_uint8), |
| 2610 | + ctypes.c_size_t, |
| 2611 | + llama_seq_id, |
| 2612 | + llama_state_seq_flags, |
| 2613 | + ], |
| 2614 | + ctypes.c_size_t, |
| 2615 | +) |
| 2616 | +def llama_state_seq_get_data_ext( |
| 2617 | + ctx: llama_context_p, |
| 2618 | + dst: CtypesArray[ctypes.c_uint8], |
| 2619 | + size: Union[ctypes.c_size_t, int], |
| 2620 | + seq_id: Union[llama_seq_id, int], |
| 2621 | + flags: Union[llama_state_seq_flags, int], |
| 2622 | + /, |
| 2623 | +) -> int: |
| 2624 | + """Copy sequence state to buffer with flags""" |
| 2625 | + ... |
| 2626 | + |
| 2627 | + |
| 2628 | +# LLAMA_API size_t llama_state_seq_set_data_ext( |
| 2629 | +# struct llama_context * ctx, |
| 2630 | +# const uint8_t * src, |
| 2631 | +# size_t size, |
| 2632 | +# llama_seq_id dest_seq_id, |
| 2633 | +# llama_state_seq_flags flags); |
| 2634 | +@ctypes_function( |
| 2635 | + "llama_state_seq_set_data_ext", |
| 2636 | + [ |
| 2637 | + llama_context_p_ctypes, |
| 2638 | + ctypes.POINTER(ctypes.c_uint8), |
| 2639 | + ctypes.c_size_t, |
| 2640 | + llama_seq_id, |
| 2641 | + llama_state_seq_flags, |
| 2642 | + ], |
| 2643 | + ctypes.c_size_t, |
| 2644 | +) |
| 2645 | +def llama_state_seq_set_data_ext( |
| 2646 | + ctx: llama_context_p, |
| 2647 | + src: CtypesArray[ctypes.c_uint8], |
| 2648 | + size: Union[ctypes.c_size_t, int], |
| 2649 | + dest_seq_id: Union[llama_seq_id, int], |
| 2650 | + flags: Union[llama_state_seq_flags, int], |
| 2651 | + /, |
| 2652 | +) -> int: |
| 2653 | + """Restore sequence state from buffer with flags""" |
| 2654 | + ... |
| 2655 | + |
| 2656 | + |
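
A hedged round-trip sketch combining the three `_ext` calls with the flags added above; `ctx` and the sequence ids are assumptions:

    flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY
    n_bytes = llama_state_seq_get_size_ext(ctx, 0, flags)     # size needed for seq 0
    buf = (ctypes.c_uint8 * n_bytes)()
    n_copied = llama_state_seq_get_data_ext(ctx, buf, n_bytes, 0, flags)
    assert n_copied == n_bytes
    llama_state_seq_set_data_ext(ctx, buf, n_copied, 1, flags)  # restore into seq 1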
2518 | 2657 | # // |
2519 | 2658 | # // Decoding |
2520 | 2659 | # // |
|