
Commit aae2011

Merge remote-tracking branch 'avion23/update-llama-cpp-2026-01' into granite-docling
2 parents: c91f276 + b844550

File tree: 5 files changed, +272 -103 lines


llama_cpp/llama.py

Lines changed: 20 additions & 18 deletions
```diff
@@ -91,9 +91,9 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
-        flash_attn: bool = False,
         op_offload: Optional[bool] = None,
         swa_full: Optional[bool] = None,
+        flash_attn: Optional[bool] = None,
         # Sampling Params
         no_perf: bool = False,
         last_n_tokens_size: int = 64,
@@ -173,7 +173,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
-            flash_attn: Use flash attention.
+            flash_attn: Use flash attention. None = auto, True = enabled, False = disabled.
             op_offload: offload host tensor operations to device
             swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
             no_perf: Measure performance timings.
```
```diff
@@ -341,11 +341,16 @@ def __init__(
         self._logits_all = logits_all if draft_model is None else True
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-        self.context_params.flash_attn_type = (
-            llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
-            if flash_attn
-            else llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
-        )
+        if flash_attn is None:
+            self.context_params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO
+        elif flash_attn:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
+            )
+        else:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
+            )
 
         if op_offload is not None:
             self.context_params.op_offload = op_offload
```
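Taken together, this hunk maps the tri-state `flash_attn` keyword onto llama.cpp's flash-attention type. Below is a minimal standalone sketch of that mapping; the constant names are copied from the diff, while the numeric values (AUTO = -1, DISABLED = 0, ENABLED = 1) are an assumption about the underlying llama.cpp enum:

```python
from typing import Optional

# Assumed values mirroring llama.cpp's llama_flash_attn_type enum.
LLAMA_FLASH_ATTN_TYPE_AUTO = -1
LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
LLAMA_FLASH_ATTN_TYPE_ENABLED = 1

def resolve_flash_attn_type(flash_attn: Optional[bool]) -> int:
    """Map the tri-state flash_attn kwarg to a flash-attention type."""
    if flash_attn is None:
        # New default: defer the decision to llama.cpp (per backend/model).
        return LLAMA_FLASH_ATTN_TYPE_AUTO
    return LLAMA_FLASH_ATTN_TYPE_ENABLED if flash_attn else LLAMA_FLASH_ATTN_TYPE_DISABLED

assert resolve_flash_attn_type(None) == LLAMA_FLASH_ATTN_TYPE_AUTO
assert resolve_flash_attn_type(True) == LLAMA_FLASH_ATTN_TYPE_ENABLED
assert resolve_flash_attn_type(False) == LLAMA_FLASH_ATTN_TYPE_DISABLED
```

Moving the default from `False` to `None` means callers who never set `flash_attn` now get llama.cpp's auto-detection instead of a hard off.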
```diff
@@ -1167,9 +1172,9 @@ def _create_completion(
         bos_token_id: int = self.token_bos()
         cls_token_id: int = self._model.token_cls()
         sep_token_id: int = self._model.token_sep()
-        prefix_token_id: int = 0  # self._model.token_prefix() # TODO: Fix
-        middle_token_id: int = 0  # self._model.token_middle() # TODO: Fix
-        suffix_token_id: int = 0  # self._model.token_suffix() # TODO: Fix
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         add_space_prefix: bool = (
             self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
         )
```
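The restored calls return the model's fill-in-the-middle (FIM) token IDs, which the suffix-based completion path consumes. As a rough illustration, here is a sketch of how such IDs are commonly assembled into an infill prompt, assuming an already-loaded `Llama` instance named `llm` and the conventional `<PRE> prefix <SUF> suffix <MID>` ordering (models vary):

```python
# Sketch only: assemble an infill (FIM) prompt from the restored token IDs.
prefix = "def add(a, b):\n    "
suffix = "\n    return result\n"

fim_tokens = (
    [llm._model.token_prefix()]  # <PRE>
    + llm.tokenize(prefix.encode("utf-8"), add_bos=False, special=False)
    + [llm._model.token_suffix()]  # <SUF>
    + llm.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
    + [llm._model.token_middle()]  # <MID>: generation fills in from here
)
```

With the previous hard-coded zeros, an infill prompt would have injected token ID 0 in place of each FIM marker.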
```diff
@@ -2143,10 +2148,7 @@ def __getstate__(self):
             logits_all=self._logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
-            flash_attn=(
-                self.context_params.flash_attn_type
-                == llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
-            ),
+            flash_attn=self.context_params.flash_attn_type,
             op_offload=self.context_params.op_offload,
             swa_full=self.context_params.swa_full,
             # Sampling Params
```
```diff
@@ -2366,23 +2368,23 @@ def from_pretrained(
         )
 
         if additional_files:
-            for additonal_file_name in additional_files:
+            for additional_file_name in additional_files:
                 # find the additional shard file:
                 matching_additional_files = [
                     file
                     for file in file_list
-                    if fnmatch.fnmatch(file, additonal_file_name)
+                    if fnmatch.fnmatch(file, additional_file_name)
                 ]
 
                 if len(matching_additional_files) == 0:
                     raise ValueError(
-                        f"No file found in {repo_id} that match {additonal_file_name}\n\n"
+                        f"No file found in {repo_id} that match {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(file_list)}"
                     )
 
                 if len(matching_additional_files) > 1:
                     raise ValueError(
-                        f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
+                        f"Multiple files found in {repo_id} matching {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(files)}"
                     )
 
```
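For reference, `additional_files` entries are matched against the repo's file list with fnmatch-style globs, so multi-part GGUF models can be pulled in alongside the primary file. A usage sketch (repo and file names are placeholders, not a real model):

```python
from llama_cpp import Llama

# Placeholder repo/file names for illustration only.
llm = Llama.from_pretrained(
    repo_id="example-org/example-model-GGUF",
    filename="model-00001-of-00002.gguf",
    # Each pattern must match exactly one file in the repo,
    # otherwise from_pretrained raises ValueError (see the hunk above).
    additional_files=["model-00002-of-*.gguf"],
)
```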
