@@ -91,9 +91,9 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
-        flash_attn: bool = False,
         op_offload: Optional[bool] = None,
         swa_full: Optional[bool] = None,
+        flash_attn: Optional[bool] = None,
         # Sampling Params
         no_perf: bool = False,
         last_n_tokens_size: int = 64,
@@ -173,7 +173,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
-            flash_attn: Use flash attention.
+            flash_attn: Use flash attention. None = auto, True = enabled, False = disabled.
             op_offload: offload host tensor operations to device
             swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
             no_perf: Measure performance timings.
@@ -341,11 +341,16 @@ def __init__(
         self._logits_all = logits_all if draft_model is None else True
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-        self.context_params.flash_attn_type = (
-            llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
-            if flash_attn
-            else llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
-        )
+        if flash_attn is None:
+            self.context_params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO
+        elif flash_attn:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
+            )
+        else:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
+            )

         if op_offload is not None:
             self.context_params.op_offload = op_offload
@@ -1167,9 +1172,9 @@ def _create_completion(
         bos_token_id: int = self.token_bos()
         cls_token_id: int = self._model.token_cls()
         sep_token_id: int = self._model.token_sep()
-        prefix_token_id: int = 0  # self._model.token_prefix() # TODO: Fix
-        middle_token_id: int = 0  # self._model.token_middle() # TODO: Fix
-        suffix_token_id: int = 0  # self._model.token_suffix() # TODO: Fix
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         add_space_prefix: bool = (
             self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
         )
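
With the placeholder zeros replaced by the real `token_prefix()`/`token_middle()`/`token_suffix()` getters, the fill-in-the-middle branch of `_create_completion` can emit proper FIM tokens again. A rough usage sketch of that path through the public API, assuming a FIM-capable GGUF model (the model path and the code strings below are placeholders, not taken from the commit):

from llama_cpp import Llama

llm = Llama(model_path="./codemodel.gguf")  # hypothetical FIM-capable model
out = llm.create_completion(
    prompt="def add(a, b):\n    ",    # text before the cursor
    suffix="\n    return result\n",   # passing a suffix is what exercises the prefix/middle/suffix tokens
    max_tokens=32,
)
print(out["choices"][0]["text"])
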
@@ -2143,10 +2148,7 @@ def __getstate__(self):
             logits_all=self._logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
-            flash_attn=(
-                self.context_params.flash_attn_type
-                == llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
-            ),
+            flash_attn=self.context_params.flash_attn_type,
             op_offload=self.context_params.op_offload,
             swa_full=self.context_params.swa_full,
             # Sampling Params
@@ -2366,23 +2368,23 @@ def from_pretrained(
             )

         if additional_files:
-            for additonal_file_name in additional_files:
+            for additional_file_name in additional_files:
                 # find the additional shard file:
                 matching_additional_files = [
                     file
                     for file in file_list
-                    if fnmatch.fnmatch(file, additonal_file_name)
+                    if fnmatch.fnmatch(file, additional_file_name)
                 ]

                 if len(matching_additional_files) == 0:
                     raise ValueError(
-                        f"No file found in {repo_id} that match {additonal_file_name}\n\n"
+                        f"No file found in {repo_id} that match {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(file_list)}"
                     )

                 if len(matching_additional_files) > 1:
                     raise ValueError(
-                        f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
+                        f"Multiple files found in {repo_id} matching {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(files)}"
                     )

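
Taken together, `flash_attn` is now tri-state: `None` (the new default) maps to `LLAMA_FLASH_ATTN_TYPE_AUTO`, while `True`/`False` keep forcing `ENABLED`/`DISABLED` as before. A minimal sketch of the resulting constructor behaviour (the model path is a placeholder):

from llama_cpp import Llama

llm_auto = Llama(model_path="./model.gguf")                    # flash_attn=None -> AUTO, llama.cpp decides
llm_on   = Llama(model_path="./model.gguf", flash_attn=True)   # -> LLAMA_FLASH_ATTN_TYPE_ENABLED
llm_off  = Llama(model_path="./model.gguf", flash_attn=False)  # -> LLAMA_FLASH_ATTN_TYPE_DISABLED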