Commit d14a24f

Author: Ralf Waldukat
Add 6 new API functions from llama.cpp 2026-01-01
Implemented bindings for:
- llama_attach_threadpool / llama_detach_threadpool
- llama_params_fit (check if model will fit in memory)
- llama_state_seq_get_size_ext
- llama_state_seq_get_data_ext
- llama_state_seq_set_data_ext

Added type definitions:
- ggml_threadpool_t
- llama_state_seq_flags
- LLAMA_STATE_SEQ_FLAGS_SWA_ONLY
- LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY

All 218 C API functions now have Python bindings.
1 parent 1f0241e commit d14a24f
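
The "218 C API functions" figure can be roughly sanity-checked by counting the @ctypes_function declarations in llama_cpp/llama_cpp.py. A minimal sketch (hypothetical helper, not part of this commit; it only counts llama_* names registered through that decorator):

    import re

    with open("llama_cpp/llama_cpp.py") as f:
        src = f.read()

    # Every low-level binding in this file is declared as @ctypes_function("llama_...", ...)
    names = re.findall(r'@ctypes_function\(\s*"(llama_\w+)"', src)
    print(f"{len(names)} llama_* functions bound")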

File tree

1 file changed (+141 −2 lines)


llama_cpp/llama_cpp.py

Lines changed: 141 additions & 2 deletions
@@ -126,6 +126,9 @@
     None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p
 )

+# typedef struct ggml_threadpool * ggml_threadpool_t;
+ggml_threadpool_t = ctypes.c_void_p
+
 # llama.h bindings

 _lib.llama_max_devices.argtypes = []

@@ -186,6 +189,13 @@
 # typedef int32_t llama_seq_id;
 llama_seq_id = ctypes.c_int32

+# typedef uint32_t llama_state_seq_flags;
+llama_state_seq_flags = ctypes.c_uint32
+
+# State sequence flags
+LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1
+LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 1
+

 # enum llama_vocab_type {
 #     LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab

@@ -1197,11 +1207,26 @@ def llama_numa_init(numa: int, /): ...
 #     struct llama_context * ctx,
 #     ggml_threadpool_t threadpool,
 #     ggml_threadpool_t threadpool_batch);
-# TODO: Add llama_attach_threadpool
+@ctypes_function(
+    "llama_attach_threadpool",
+    [llama_context_p_ctypes, ggml_threadpool_t, ggml_threadpool_t],
+    None,
+)
+def llama_attach_threadpool(
+    ctx: llama_context_p,
+    threadpool: int,
+    threadpool_batch: int,
+    /,
+):
+    """Attach threadpools to context"""
+    ...


 # LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
-# TODO: Add llama_detach_threadpool
+@ctypes_function("llama_detach_threadpool", [llama_context_p_ctypes], None)
+def llama_detach_threadpool(ctx: llama_context_p, /):
+    """Detach threadpool from context"""
+    ...


 # DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
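
A minimal usage sketch for the new threadpool bindings, assuming `ctx` is an existing llama_context and `tp` is an opaque ggml_threadpool pointer obtained from a separately bound ggml API (not part of this commit; ggml_threadpool_t is only an alias for ctypes.c_void_p here):

    from llama_cpp.llama_cpp import llama_attach_threadpool, llama_detach_threadpool

    llama_attach_threadpool(ctx, tp, None)  # None: no separate batch threadpool supplied
    # ... run decoding / generation on ctx ...
    llama_detach_threadpool(ctx)
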
@@ -1375,6 +1400,43 @@ def llama_supports_rpc() -> bool: ...
 def llama_max_tensor_buft_overrides() -> int:
     """Get maximum number of tensor buffer type overrides"""
     ...
+
+
+# LLAMA_API enum llama_params_fit_status llama_params_fit(
+#     const char * path_model,
+#     struct llama_model_params * mparams,
+#     struct llama_context_params * cparams,
+#     float * tensor_split,
+#     struct llama_model_tensor_buft_override * tensor_buft_overrides,
+#     size_t n_buft_overrides);
+@ctypes_function(
+    "llama_params_fit",
+    [
+        ctypes.c_char_p,
+        ctypes.POINTER(llama_model_params),
+        ctypes.POINTER(llama_context_params),
+        ctypes.POINTER(ctypes.c_float),
+        ctypes.c_void_p,  # tensor_buft_overrides - not fully bound
+        ctypes.c_size_t,
+    ],
+    ctypes.c_int,
+)
+def llama_params_fit(
+    path_model: bytes,
+    mparams: CtypesPointerOrRef[llama_model_params],
+    cparams: CtypesPointerOrRef[llama_context_params],
+    tensor_split: CtypesArray[ctypes.c_float],
+    tensor_buft_overrides: int,
+    n_buft_overrides: Union[ctypes.c_size_t, int],
+    /,
+) -> int:
+    """Check if model parameters will fit in memory
+
+    Returns:
+        LLAMA_PARAMS_FIT_STATUS_SUCCESS (0) - found allocations that are projected to fit
+        LLAMA_PARAMS_FIT_STATUS_FAILURE (1) - could not find allocations that are projected to fit
+        LLAMA_PARAMS_FIT_STATUS_ERROR (2) - a hard error occurred
+    """
     ...

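
A hedged usage sketch for llama_params_fit, assuming the default-params helpers and llama_max_devices from the same module; the model path is a placeholder and the status is compared against the raw value documented in the docstring above:

    import ctypes
    from llama_cpp.llama_cpp import (
        llama_params_fit,
        llama_model_default_params,
        llama_context_default_params,
        llama_max_devices,
    )

    path = b"/path/to/model.gguf"  # placeholder
    mparams = llama_model_default_params()
    cparams = llama_context_default_params()
    tensor_split = (ctypes.c_float * llama_max_devices())()  # one entry per device

    status = llama_params_fit(
        path,
        ctypes.byref(mparams),
        ctypes.byref(cparams),
        tensor_split,
        None,  # tensor_buft_overrides is only bound as an opaque pointer
        0,
    )
    if status == 0:  # LLAMA_PARAMS_FIT_STATUS_SUCCESS
        print("model and context are projected to fit in memory")
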
@@ -2515,6 +2577,83 @@ def llama_state_seq_load_file(
 ) -> int: ...


+# LLAMA_API size_t llama_state_seq_get_size_ext(
+#     struct llama_context * ctx,
+#     llama_seq_id seq_id,
+#     llama_state_seq_flags flags);
+@ctypes_function(
+    "llama_state_seq_get_size_ext",
+    [llama_context_p_ctypes, llama_seq_id, llama_state_seq_flags],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_size_ext(
+    ctx: llama_context_p,
+    seq_id: Union[llama_seq_id, int],
+    flags: Union[llama_state_seq_flags, int],
+    /,
+) -> int:
+    """Get size needed to copy sequence state with flags"""
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_get_data_ext(
+#     struct llama_context * ctx,
+#     uint8_t * dst,
+#     size_t size,
+#     llama_seq_id seq_id,
+#     llama_state_seq_flags flags);
+@ctypes_function(
+    "llama_state_seq_get_data_ext",
+    [
+        llama_context_p_ctypes,
+        ctypes.POINTER(ctypes.c_uint8),
+        ctypes.c_size_t,
+        llama_seq_id,
+        llama_state_seq_flags,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_data_ext(
+    ctx: llama_context_p,
+    dst: CtypesArray[ctypes.c_uint8],
+    size: Union[ctypes.c_size_t, int],
+    seq_id: Union[llama_seq_id, int],
+    flags: Union[llama_state_seq_flags, int],
+    /,
+) -> int:
+    """Copy sequence state to buffer with flags"""
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_set_data_ext(
+#     struct llama_context * ctx,
+#     const uint8_t * src,
+#     size_t size,
+#     llama_seq_id dest_seq_id,
+#     llama_state_seq_flags flags);
+@ctypes_function(
+    "llama_state_seq_set_data_ext",
+    [
+        llama_context_p_ctypes,
+        ctypes.POINTER(ctypes.c_uint8),
+        ctypes.c_size_t,
+        llama_seq_id,
+        llama_state_seq_flags,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_set_data_ext(
+    ctx: llama_context_p,
+    src: CtypesArray[ctypes.c_uint8],
+    size: Union[ctypes.c_size_t, int],
+    dest_seq_id: Union[llama_seq_id, int],
+    flags: Union[llama_state_seq_flags, int],
+    /,
+) -> int:
+    """Restore sequence state from buffer with flags"""
+    ...
+
+
 # //
 # // Decoding
 # //
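
A brief sketch of how the new _ext state functions fit together, assuming `ctx` is a live llama_context whose sequence 0 holds state worth copying; passing 0 for flags presumably matches the existing non-_ext variants, while LLAMA_STATE_SEQ_FLAGS_SWA_ONLY presumably restricts the copy to the SWA cache:

    import ctypes
    from llama_cpp.llama_cpp import (
        llama_state_seq_get_size_ext,
        llama_state_seq_get_data_ext,
        llama_state_seq_set_data_ext,
        LLAMA_STATE_SEQ_FLAGS_SWA_ONLY,
    )

    flags = LLAMA_STATE_SEQ_FLAGS_SWA_ONLY  # or 0 for the full sequence state
    n = llama_state_seq_get_size_ext(ctx, 0, flags)  # bytes needed for seq 0
    buf = (ctypes.c_uint8 * n)()
    written = llama_state_seq_get_data_ext(ctx, buf, n, 0, flags)
    # ... later, restore the saved state into sequence 1 of a compatible context ...
    read = llama_state_seq_set_data_ext(ctx, buf, written, 1, flags)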
