Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4676,14 +4676,27 @@ def set_gguf_parameters(self):
self.gguf_writer.add_uint32(gguf.Keys.LLM.SAMPLING_RATE.format(arch=arch), sampling_rate)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("language_model."):
# Strip prefixes — Local variant nests under model.language_model / model.embedding_list
if name.startswith("model.language_model."):
name = name.replace("model.language_model.", "", 1)
elif name.startswith("language_model."):
name = name.replace("language_model.", "", 1)

# Local variant: embedding_list.0 = text (skip), 1-32 = audio codebooks 0-31
if (match := re.fullmatch(r"model\.embedding_list\.(\d+)\.weight", name)) is not None:
idx = int(match.group(1))
if idx == 0:
return # text embedding — already covered by embed_tokens
yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD_AUDIO]}.{idx - 1}.weight", data_torch)
return

# 8B variant: emb_ext.N = audio codebook N (no offset)
if (match := re.fullmatch(r"emb_ext\.(\d+)\.weight", name)) is not None:
vq_idx = int(match.group(1))
yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD_AUDIO]}.{vq_idx}.weight", data_torch)
return

# Audio LM heads: 0 = text output, 1-32 = audio output 0-31
if (match := re.fullmatch(r"lm_heads\.(\d+)\.weight", name)) is not None:
head_idx = int(match.group(1))
if head_idx == 0:
Expand All @@ -4692,6 +4705,54 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_AUDIO]}.{head_idx - 1}.weight", data_torch)
return

# Layer norms before LM heads (Local variant)
if (match := re.fullmatch(r"layer_norm_before_lm_heads\.(\d+)\.weight", name)) is not None:
idx = int(match.group(1))
yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.AUDIO_LN]}.{idx}.weight", data_torch)
return

# Local transformer (4-layer mini-transformer)
_local_map = {
"self_attn.q_proj": gguf.MODEL_TENSOR.LOCAL_ATTN_Q,
"self_attn.k_proj": gguf.MODEL_TENSOR.LOCAL_ATTN_K,
"self_attn.v_proj": gguf.MODEL_TENSOR.LOCAL_ATTN_V,
"self_attn.o_proj": gguf.MODEL_TENSOR.LOCAL_ATTN_OUT,
"self_attn.q_norm": gguf.MODEL_TENSOR.LOCAL_ATTN_Q_NORM,
"self_attn.k_norm": gguf.MODEL_TENSOR.LOCAL_ATTN_K_NORM,
"input_layernorm": gguf.MODEL_TENSOR.LOCAL_ATTN_NORM,
"post_attention_layernorm": gguf.MODEL_TENSOR.LOCAL_FFN_NORM,
"mlp.gate_proj": gguf.MODEL_TENSOR.LOCAL_FFN_GATE,
"mlp.down_proj": gguf.MODEL_TENSOR.LOCAL_FFN_DOWN,
"mlp.up_proj": gguf.MODEL_TENSOR.LOCAL_FFN_UP,
}
if (match := re.fullmatch(r"local_transformer\.layers\.(\d+)\.(.+?)\.weight", name)) is not None:
layer_id = int(match.group(1))
suffix = match.group(2)
if suffix in _local_map:
gguf_name = gguf.TENSOR_NAMES[_local_map[suffix]].format(bid=layer_id)
yield (f"{gguf_name}.weight", data_torch)
return
if name == "local_transformer.norm.weight":
yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.LOCAL_OUTPUT_NORM]}.weight", data_torch)
return

# Local-to-speech bridge MLPs (33 indexed)
_bridge_map = {"gate_proj": gguf.MODEL_TENSOR.LOCAL_TO_SPEECH_GATE, "down_proj": gguf.MODEL_TENSOR.LOCAL_TO_SPEECH_DOWN, "up_proj": gguf.MODEL_TENSOR.LOCAL_TO_SPEECH_UP}
if (match := re.fullmatch(r"local_to_speech_embedding_mlps\.(\d+)\.(.+?)\.weight", name)) is not None:
idx = int(match.group(1))
proj = match.group(2)
if proj in _bridge_map:
yield (f"{gguf.TENSOR_NAMES[_bridge_map[proj]]}.{idx}.weight", data_torch)
return

# Speech-to-local bridge MLP (single)
_s2l_map = {"gate_proj": gguf.MODEL_TENSOR.SPEECH_TO_LOCAL_GATE, "down_proj": gguf.MODEL_TENSOR.SPEECH_TO_LOCAL_DOWN, "up_proj": gguf.MODEL_TENSOR.SPEECH_TO_LOCAL_UP}
if (match := re.fullmatch(r"speech_embedding_to_local_mlp\.(.+?)\.weight", name)) is not None:
proj = match.group(1)
if proj in _s2l_map:
yield (f"{gguf.TENSOR_NAMES[_s2l_map[proj]]}.weight", data_torch)
return

yield from super().modify_tensors(data_torch, name, bid)


Expand Down
57 changes: 57 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,25 @@ class MODEL_TENSOR(IntEnum):
POS_EMBD = auto()
OUTPUT = auto()
OUTPUT_AUDIO = auto() # moss-tts-delay, indexed as output_audio.{id}
AUDIO_LN = auto()
LOCAL_ATTN_NORM = auto()
LOCAL_ATTN_Q = auto()
LOCAL_ATTN_Q_NORM = auto()
LOCAL_ATTN_K = auto()
LOCAL_ATTN_K_NORM = auto()
LOCAL_ATTN_V = auto()
LOCAL_ATTN_OUT = auto()
LOCAL_FFN_NORM = auto()
LOCAL_FFN_GATE = auto()
LOCAL_FFN_DOWN = auto()
LOCAL_FFN_UP = auto()
LOCAL_OUTPUT_NORM = auto()
LOCAL_TO_SPEECH_GATE = auto()
LOCAL_TO_SPEECH_DOWN = auto()
LOCAL_TO_SPEECH_UP = auto()
SPEECH_TO_LOCAL_GATE = auto()
SPEECH_TO_LOCAL_DOWN = auto()
SPEECH_TO_LOCAL_UP = auto()
DENSE_2_OUT = auto() # embeddinggemma 2_Dense
DENSE_3_OUT = auto() # embeddinggemma 3_Dense
OUTPUT_NORM = auto()
Expand Down Expand Up @@ -964,6 +983,25 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.OUTPUT_AUDIO: "output_audio",
MODEL_TENSOR.AUDIO_LN: "audio_ln",
MODEL_TENSOR.LOCAL_ATTN_NORM: "local.blk.{bid}.attn_norm",
MODEL_TENSOR.LOCAL_ATTN_Q: "local.blk.{bid}.attn_q",
MODEL_TENSOR.LOCAL_ATTN_Q_NORM: "local.blk.{bid}.attn_q_norm",
MODEL_TENSOR.LOCAL_ATTN_K: "local.blk.{bid}.attn_k",
MODEL_TENSOR.LOCAL_ATTN_K_NORM: "local.blk.{bid}.attn_k_norm",
MODEL_TENSOR.LOCAL_ATTN_V: "local.blk.{bid}.attn_v",
MODEL_TENSOR.LOCAL_ATTN_OUT: "local.blk.{bid}.attn_output",
MODEL_TENSOR.LOCAL_FFN_NORM: "local.blk.{bid}.ffn_norm",
MODEL_TENSOR.LOCAL_FFN_GATE: "local.blk.{bid}.ffn_gate",
MODEL_TENSOR.LOCAL_FFN_DOWN: "local.blk.{bid}.ffn_down",
MODEL_TENSOR.LOCAL_FFN_UP: "local.blk.{bid}.ffn_up",
MODEL_TENSOR.LOCAL_OUTPUT_NORM: "local.output_norm",
MODEL_TENSOR.LOCAL_TO_SPEECH_GATE: "local_to_speech.ffn_gate",
MODEL_TENSOR.LOCAL_TO_SPEECH_DOWN: "local_to_speech.ffn_down",
MODEL_TENSOR.LOCAL_TO_SPEECH_UP: "local_to_speech.ffn_up",
MODEL_TENSOR.SPEECH_TO_LOCAL_GATE: "speech_to_local.ffn_gate",
MODEL_TENSOR.SPEECH_TO_LOCAL_DOWN: "speech_to_local.ffn_down",
MODEL_TENSOR.SPEECH_TO_LOCAL_UP: "speech_to_local.ffn_up",
MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 3_Dense
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
Expand Down Expand Up @@ -1812,6 +1850,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_AUDIO,
MODEL_TENSOR.AUDIO_LN,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
Expand All @@ -1824,6 +1863,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LOCAL_ATTN_NORM,
MODEL_TENSOR.LOCAL_ATTN_Q,
MODEL_TENSOR.LOCAL_ATTN_Q_NORM,
MODEL_TENSOR.LOCAL_ATTN_K,
MODEL_TENSOR.LOCAL_ATTN_K_NORM,
MODEL_TENSOR.LOCAL_ATTN_V,
MODEL_TENSOR.LOCAL_ATTN_OUT,
MODEL_TENSOR.LOCAL_FFN_NORM,
MODEL_TENSOR.LOCAL_FFN_GATE,
MODEL_TENSOR.LOCAL_FFN_DOWN,
MODEL_TENSOR.LOCAL_FFN_UP,
MODEL_TENSOR.LOCAL_OUTPUT_NORM,
MODEL_TENSOR.LOCAL_TO_SPEECH_GATE,
MODEL_TENSOR.LOCAL_TO_SPEECH_DOWN,
MODEL_TENSOR.LOCAL_TO_SPEECH_UP,
MODEL_TENSOR.SPEECH_TO_LOCAL_GATE,
MODEL_TENSOR.SPEECH_TO_LOCAL_DOWN,
MODEL_TENSOR.SPEECH_TO_LOCAL_UP,
],
MODEL_ARCH.QWEN3MOE: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down
61 changes: 61 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,25 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_OUTPUT_NORM_LFM2, "token_embd_norm" }, // fix for wrong tensor name
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_AUDIO, "output_audio.%d" },
{ LLM_TENSOR_AUDIO_LN, "audio_ln.%d" },
{ LLM_TENSOR_LOCAL_ATTN_NORM, "local.blk.%d.attn_norm" },
{ LLM_TENSOR_LOCAL_ATTN_Q, "local.blk.%d.attn_q" },
{ LLM_TENSOR_LOCAL_ATTN_Q_NORM, "local.blk.%d.attn_q_norm" },
{ LLM_TENSOR_LOCAL_ATTN_K, "local.blk.%d.attn_k" },
{ LLM_TENSOR_LOCAL_ATTN_K_NORM, "local.blk.%d.attn_k_norm" },
{ LLM_TENSOR_LOCAL_ATTN_V, "local.blk.%d.attn_v" },
{ LLM_TENSOR_LOCAL_ATTN_OUT, "local.blk.%d.attn_output" },
{ LLM_TENSOR_LOCAL_FFN_NORM, "local.blk.%d.ffn_norm" },
{ LLM_TENSOR_LOCAL_FFN_GATE, "local.blk.%d.ffn_gate" },
{ LLM_TENSOR_LOCAL_FFN_DOWN, "local.blk.%d.ffn_down" },
{ LLM_TENSOR_LOCAL_FFN_UP, "local.blk.%d.ffn_up" },
{ LLM_TENSOR_LOCAL_OUTPUT_NORM, "local.output_norm" },
{ LLM_TENSOR_LOCAL_TO_SPEECH_GATE, "local_to_speech.ffn_gate.%d" },
{ LLM_TENSOR_LOCAL_TO_SPEECH_DOWN, "local_to_speech.ffn_down.%d" },
{ LLM_TENSOR_LOCAL_TO_SPEECH_UP, "local_to_speech.ffn_up.%d" },
{ LLM_TENSOR_SPEECH_TO_LOCAL_GATE, "speech_to_local.ffn_gate" },
{ LLM_TENSOR_SPEECH_TO_LOCAL_DOWN, "speech_to_local.ffn_down" },
{ LLM_TENSOR_SPEECH_TO_LOCAL_UP, "speech_to_local.ffn_up" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
Expand Down Expand Up @@ -990,6 +1009,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_AUDIO,
LLM_TENSOR_AUDIO_LN,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
Expand All @@ -1001,6 +1021,24 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_LOCAL_ATTN_NORM,
LLM_TENSOR_LOCAL_ATTN_Q,
LLM_TENSOR_LOCAL_ATTN_Q_NORM,
LLM_TENSOR_LOCAL_ATTN_K,
LLM_TENSOR_LOCAL_ATTN_K_NORM,
LLM_TENSOR_LOCAL_ATTN_V,
LLM_TENSOR_LOCAL_ATTN_OUT,
LLM_TENSOR_LOCAL_FFN_NORM,
LLM_TENSOR_LOCAL_FFN_GATE,
LLM_TENSOR_LOCAL_FFN_DOWN,
LLM_TENSOR_LOCAL_FFN_UP,
LLM_TENSOR_LOCAL_OUTPUT_NORM,
LLM_TENSOR_LOCAL_TO_SPEECH_GATE,
LLM_TENSOR_LOCAL_TO_SPEECH_DOWN,
LLM_TENSOR_LOCAL_TO_SPEECH_UP,
LLM_TENSOR_SPEECH_TO_LOCAL_GATE,
LLM_TENSOR_SPEECH_TO_LOCAL_DOWN,
LLM_TENSOR_SPEECH_TO_LOCAL_UP,
};
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
Expand Down Expand Up @@ -2597,6 +2635,25 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_OUTPUT_AUDIO, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_AUDIO_LN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_LOCAL_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LOCAL_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LOCAL_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LOCAL_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LOCAL_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_LOCAL_TO_SPEECH_GATE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_TO_SPEECH_DOWN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LOCAL_TO_SPEECH_UP, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SPEECH_TO_LOCAL_GATE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SPEECH_TO_LOCAL_DOWN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SPEECH_TO_LOCAL_UP, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
Expand Down Expand Up @@ -2827,6 +2884,10 @@ std::string LLM_TN_IMPL::str() const {
switch (tensor) {
case LLM_TENSOR_TOKEN_EMBD_AUDIO:
case LLM_TENSOR_OUTPUT_AUDIO:
case LLM_TENSOR_AUDIO_LN:
case LLM_TENSOR_LOCAL_TO_SPEECH_GATE:
case LLM_TENSOR_LOCAL_TO_SPEECH_DOWN:
case LLM_TENSOR_LOCAL_TO_SPEECH_UP:
name = ::format(LLM_TENSOR_NAMES.at(tensor), xid);
break;
default:
Expand Down
19 changes: 19 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,25 @@ enum llm_tensor {
LLM_TENSOR_DENSE_3_OUT,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_AUDIO,
LLM_TENSOR_AUDIO_LN,
LLM_TENSOR_LOCAL_ATTN_NORM,
LLM_TENSOR_LOCAL_ATTN_Q,
LLM_TENSOR_LOCAL_ATTN_Q_NORM,
LLM_TENSOR_LOCAL_ATTN_K,
LLM_TENSOR_LOCAL_ATTN_K_NORM,
LLM_TENSOR_LOCAL_ATTN_V,
LLM_TENSOR_LOCAL_ATTN_OUT,
LLM_TENSOR_LOCAL_FFN_NORM,
LLM_TENSOR_LOCAL_FFN_GATE,
LLM_TENSOR_LOCAL_FFN_DOWN,
LLM_TENSOR_LOCAL_FFN_UP,
LLM_TENSOR_LOCAL_OUTPUT_NORM,
LLM_TENSOR_LOCAL_TO_SPEECH_GATE,
LLM_TENSOR_LOCAL_TO_SPEECH_DOWN,
LLM_TENSOR_LOCAL_TO_SPEECH_UP,
LLM_TENSOR_SPEECH_TO_LOCAL_GATE,
LLM_TENSOR_SPEECH_TO_LOCAL_DOWN,
LLM_TENSOR_SPEECH_TO_LOCAL_UP,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
LLM_TENSOR_ROPE_FREQS,
Expand Down
Loading