Skip to content

Commit a51f137

Browse files
committed
Split quoted text into character-level tokens
1 parent bf82af9 commit a51f137

File tree

1 file changed

+71
-24
lines changed

1 file changed

+71
-24
lines changed

conditioner.hpp

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,46 +1678,91 @@ struct LLMEmbedder : public Conditioner {
16781678
}
16791679
}
16801680

1681-
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
1682-
std::pair<int, int> attn_range,
1683-
size_t max_length = 0,
1684-
bool padding = false) {
1681+
std::tuple<std::vector<int>, std::vector<float>> tokenize(
1682+
std::string text,
1683+
std::pair<int, int> attn_range,
1684+
size_t max_length = 0,
1685+
bool padding = false,
1686+
bool spell_quotes = false) {
16851687
std::vector<std::pair<std::string, float>> parsed_attention;
16861688
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
1689+
16871690
if (attn_range.second - attn_range.first > 0) {
1688-
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
1689-
parsed_attention.insert(parsed_attention.end(),
1690-
new_parsed_attention.begin(),
1691-
new_parsed_attention.end());
1691+
auto new_parsed_attention = parse_prompt_attention(
1692+
text.substr(attn_range.first, attn_range.second - attn_range.first));
1693+
parsed_attention.insert(
1694+
parsed_attention.end(),
1695+
new_parsed_attention.begin(),
1696+
new_parsed_attention.end());
16921697
}
16931698
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
1699+
16941700
{
16951701
std::stringstream ss;
1696-
ss << "[";
1702+
ss << '[';
16971703
for (const auto& item : parsed_attention) {
16981704
ss << "['" << item.first << "', " << item.second << "], ";
16991705
}
1700-
ss << "]";
1706+
ss << ']';
17011707
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
17021708
}
17031709

17041710
std::vector<int> tokens;
17051711
std::vector<float> weights;
1712+
17061713
for (const auto& item : parsed_attention) {
17071714
const std::string& curr_text = item.first;
17081715
float curr_weight = item.second;
1709-
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
1710-
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
1711-
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
1712-
}
17131716

1714-
tokenizer->pad_tokens(tokens, weights, max_length, padding);
1717+
if (spell_quotes) {
1718+
std::vector<std::string> parts;
1719+
bool in_quote = false;
1720+
std::string current_part;
17151721

1716-
// for (int i = 0; i < tokens.size(); i++) {
1717-
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
1718-
// }
1719-
// std::cout << std::endl;
1722+
for (char c : curr_text) {
1723+
if (c == '\'') {
1724+
if (!current_part.empty()) {
1725+
parts.push_back(current_part);
1726+
current_part.clear();
1727+
}
1728+
in_quote = !in_quote;
1729+
} else {
1730+
current_part += c;
1731+
if (in_quote && current_part.size() == 1) {
1732+
parts.push_back(current_part);
1733+
current_part.clear();
1734+
}
1735+
}
1736+
}
1737+
if (!current_part.empty()) {
1738+
parts.push_back(current_part);
1739+
}
17201740

1741+
for (const auto& part : parts) {
1742+
if (part.empty())
1743+
continue;
1744+
if (part[0] == '\'' && part.back() == '\'') {
1745+
std::string quoted_content = part.substr(1, part.size() - 2);
1746+
for (char ch : quoted_content) {
1747+
std::string char_str(1, ch);
1748+
std::vector<int> char_tokens = tokenizer->tokenize(char_str, nullptr);
1749+
tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end());
1750+
weights.insert(weights.end(), char_tokens.size(), curr_weight);
1751+
}
1752+
} else {
1753+
std::vector<int> part_tokens = tokenizer->tokenize(part, nullptr);
1754+
tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end());
1755+
weights.insert(weights.end(), part_tokens.size(), curr_weight);
1756+
}
1757+
}
1758+
} else {
1759+
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
1760+
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
1761+
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
1762+
}
1763+
}
1764+
1765+
tokenizer->pad_tokens(tokens, weights, max_length, padding);
17211766
return {tokens, weights};
17221767
}
17231768

@@ -1728,8 +1773,9 @@ struct LLMEmbedder : public Conditioner {
17281773
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
17291774
std::pair<int, int> prompt_attn_range;
17301775
int prompt_template_encode_start_idx = 34;
1731-
int max_length = 0;
1732-
bool pad = false;
1776+
int max_length = 0;
1777+
bool pad = false;
1778+
bool spell_quotes = false;
17331779
std::set<int> out_layers;
17341780
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
17351781
LOG_INFO("QwenImageEditPlusPipeline");
@@ -1830,8 +1876,9 @@ struct LLMEmbedder : public Conditioner {
18301876
} else if (sd_version_is_longcat(version)) {
18311877
prompt_template_encode_start_idx = 36;
18321878
// prompt_template_encode_end_idx = 5;
1833-
max_length = 512;
1834-
pad = true;
1879+
max_length = 512;
1880+
pad = true;
1881+
spell_quotes = true;
18351882

18361883
prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
18371884

@@ -1852,7 +1899,7 @@ struct LLMEmbedder : public Conditioner {
18521899
prompt += "<|im_end|>\n<|im_start|>assistant\n";
18531900
}
18541901

1855-
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, pad);
1902+
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, pad, spell_quotes);
18561903
auto& tokens = std::get<0>(tokens_and_weights);
18571904
auto& weights = std::get<1>(tokens_and_weights);
18581905

0 commit comments

Comments
 (0)