Skip to content

Commit 9f225e4

Browse files
committed
Split quoted text into character-level tokens
Remove debug logs
1 parent fc8d85e commit 9f225e4

File tree

1 file changed

+76
-29
lines changed

1 file changed

+76
-29
lines changed

conditioner.hpp

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,46 +1648,91 @@ struct LLMEmbedder : public Conditioner {
16481648
}
16491649
}
16501650

1651-
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
1652-
std::pair<int, int> attn_range,
1653-
size_t max_length = 0,
1654-
bool padding = false) {
1651+
std::tuple<std::vector<int>, std::vector<float>> tokenize(
1652+
std::string text,
1653+
std::pair<int, int> attn_range,
1654+
size_t max_length = 0,
1655+
bool padding = false,
1656+
bool spell_quotes = false) {
16551657
std::vector<std::pair<std::string, float>> parsed_attention;
16561658
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
1659+
16571660
if (attn_range.second - attn_range.first > 0) {
1658-
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
1659-
parsed_attention.insert(parsed_attention.end(),
1660-
new_parsed_attention.begin(),
1661-
new_parsed_attention.end());
1661+
auto new_parsed_attention = parse_prompt_attention(
1662+
text.substr(attn_range.first, attn_range.second - attn_range.first));
1663+
parsed_attention.insert(
1664+
parsed_attention.end(),
1665+
new_parsed_attention.begin(),
1666+
new_parsed_attention.end());
16621667
}
16631668
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
1664-
{
1665-
std::stringstream ss;
1666-
ss << "[";
1667-
for (const auto& item : parsed_attention) {
1668-
ss << "['" << item.first << "', " << item.second << "], ";
1669-
}
1670-
ss << "]";
1671-
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
1672-
}
1669+
1670+
// {
1671+
// std::stringstream ss;
1672+
// ss << '[';
1673+
// for (const auto& item : parsed_attention) {
1674+
// ss << "['" << item.first << "', " << item.second << "], ";
1675+
// }
1676+
// ss << ']';
1677+
// LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
1678+
// }
16731679

16741680
std::vector<int> tokens;
16751681
std::vector<float> weights;
1682+
16761683
for (const auto& item : parsed_attention) {
16771684
const std::string& curr_text = item.first;
16781685
float curr_weight = item.second;
1679-
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
1680-
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
1681-
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
1682-
}
16831686

1684-
tokenizer->pad_tokens(tokens, weights, max_length, padding);
1687+
if (spell_quotes) {
1688+
std::vector<std::string> parts;
1689+
bool in_quote = false;
1690+
std::string current_part;
16851691

1686-
// for (int i = 0; i < tokens.size(); i++) {
1687-
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
1688-
// }
1689-
// std::cout << std::endl;
1692+
for (char c : curr_text) {
1693+
if (c == '\'') {
1694+
if (!current_part.empty()) {
1695+
parts.push_back(current_part);
1696+
current_part.clear();
1697+
}
1698+
in_quote = !in_quote;
1699+
} else {
1700+
current_part += c;
1701+
if (in_quote && current_part.size() == 1) {
1702+
parts.push_back(current_part);
1703+
current_part.clear();
1704+
}
1705+
}
1706+
}
1707+
if (!current_part.empty()) {
1708+
parts.push_back(current_part);
1709+
}
16901710

1711+
for (const auto& part : parts) {
1712+
if (part.empty())
1713+
continue;
1714+
if (part[0] == '\'' && part.back() == '\'') {
1715+
std::string quoted_content = part.substr(1, part.size() - 2);
1716+
for (char ch : quoted_content) {
1717+
std::string char_str(1, ch);
1718+
std::vector<int> char_tokens = tokenizer->tokenize(char_str, nullptr);
1719+
tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end());
1720+
weights.insert(weights.end(), char_tokens.size(), curr_weight);
1721+
}
1722+
} else {
1723+
std::vector<int> part_tokens = tokenizer->tokenize(part, nullptr);
1724+
tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end());
1725+
weights.insert(weights.end(), part_tokens.size(), curr_weight);
1726+
}
1727+
}
1728+
} else {
1729+
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
1730+
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
1731+
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
1732+
}
1733+
}
1734+
1735+
tokenizer->pad_tokens(tokens, weights, max_length, padding);
16911736
return {tokens, weights};
16921737
}
16931738

@@ -1698,7 +1743,8 @@ struct LLMEmbedder : public Conditioner {
16981743
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
16991744
std::pair<int, int> prompt_attn_range;
17001745
int prompt_template_encode_start_idx = 34;
1701-
int max_length = 0;
1746+
int max_length = 0;
1747+
bool spell_quotes = false;
17021748
std::set<int> out_layers;
17031749
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
17041750
LOG_INFO("QwenImageEditPlusPipeline");
@@ -1810,7 +1856,8 @@ struct LLMEmbedder : public Conditioner {
18101856
} else if (sd_version_is_longcat(version)) {
18111857
prompt_template_encode_start_idx = 36;
18121858
// prompt_template_encode_end_idx = 5;
1813-
max_length = 512;
1859+
max_length = 512;
1860+
spell_quotes = true;
18141861

18151862
prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
18161863

@@ -1831,7 +1878,7 @@ struct LLMEmbedder : public Conditioner {
18311878
prompt += "<|im_end|>\n<|im_start|>assistant\n";
18321879
}
18331880

1834-
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
1881+
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes);
18351882
auto& tokens = std::get<0>(tokens_and_weights);
18361883
auto& weights = std::get<1>(tokens_and_weights);
18371884

0 commit comments

Comments
 (0)