Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,50 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
specialTokens);
}

internal SentencePieceBaseModel(
bool addBos, bool addEos,
string bosToken, int bosId,
string eosToken, int eosId,
string unkToken, int unkId,
bool addDummyPrefix, bool escapeWhiteSpaces,
bool treatWhitespaceAsSuffix, bool byteFallback,
ReadOnlySpan<byte> precompiledCharsmap, bool removeExtraWhitespaces,
IReadOnlyDictionary<string, int>? specialTokens)
{
AddBeginningOfSentence = addBos;
AddEndOfSentence = addEos;
BeginningOfSentenceToken = bosToken;
BeginningOfSentenceId = Math.Max(0, bosId);
EndOfSentenceToken = eosToken;
EndOfSentenceId = Math.Max(0, eosId);
UnknownToken = unkToken;
UnknownId = Math.Max(0, unkId);
AddDummyPrefix = addDummyPrefix;
Comment on lines +72 to +80
EscapeWhiteSpaces = escapeWhiteSpaces;
TreatWhitespaceAsSuffix = treatWhitespaceAsSuffix;
ByteFallback = byteFallback;
SpecialTokens = specialTokens;

if (specialTokens is not null && specialTokens.Count > 0)
{
InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
SpecialTokensReverse = new Dictionary<int, string>();

foreach (var item in specialTokens)
{
InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
SpecialTokensReverse.Add(item.Value, item.Key);
}

SpecialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
}

Normalizer = new SentencePieceNormalizer(
precompiledCharsmap, removeExtraWhitespaces,
addDummyPrefix, escapeWhiteSpaces,
treatWhitespaceAsSuffix, specialTokens);
}

internal Regex? SpecialTokensRegex { get; }

internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
Expand Down
253 changes: 253 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text.Json;

namespace Microsoft.ML.Tokenizers
{
Expand All @@ -30,6 +31,11 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos,
};
}

private SentencePieceTokenizer(SentencePieceBaseModel model)
{
_model = model;
}

/// <summary>
/// The special tokens.
/// </summary>
Expand Down Expand Up @@ -457,5 +463,252 @@ public static SentencePieceTokenizer Create(

return new SentencePieceTokenizer(modelProto, addBeginningOfSentence, addEndOfSentence, specialTokens);
}

/// <summary>
/// Creates a Unigram <see cref="SentencePieceTokenizer"/> from an in-memory vocabulary of (piece, score) pairs.
/// </summary>
/// <param name="vocab">
/// The vocabulary as an ordered sequence of (piece, score) pairs. The position of each pair
/// in the sequence determines its token ID.
/// </param>
/// <param name="unkId">The index (token ID) of the unknown token in <paramref name="vocab"/>.</param>
/// <param name="addBeginningOfSentence">Whether to emit the beginning-of-sentence token during encoding.</param>
/// <param name="addEndOfSentence">Whether to emit the end-of-sentence token during encoding.</param>
/// <param name="precompiledCharsMap">
/// Optional precompiled character normalization map (as found in the SentencePiece <c>normalizer_spec.precompiled_charsmap</c>
/// field or in the Hugging Face <c>tokenizer.json</c> <c>normalizer.precompiled_charsmap</c> property).
/// Pass <see langword="default"/> to skip precompiled normalization.
/// </param>
/// <param name="addDummyPrefix">Whether to prepend the dummy whitespace prefix character (U+2581) at the start of the input.</param>
/// <param name="escapeWhiteSpaces">Whether to replace spaces with the dummy whitespace character (U+2581) during normalization.</param>
/// <param name="treatWhitespaceAsSuffix">Whether to emit the U+2581 character at the end of the last token rather than the beginning of the first token.</param>
/// <param name="specialTokens">Additional special tokens to recognize, supplied as a mapping of token string to token ID.</param>
/// <returns>A new <see cref="SentencePieceTokenizer"/> instance.</returns>
/// <remarks>
/// The beginning-of-sentence and end-of-sentence token IDs are auto-detected by looking for pieces
/// named <c>&lt;s&gt;</c> and <c>&lt;/s&gt;</c> in <paramref name="vocab"/>. If not found, positions 1 and 2
/// are used as fallbacks (the SentencePiece convention). Similarly, a <c>&lt;pad&gt;</c> piece is
/// detected automatically if present.
Comment on lines +488 to +491
/// <para>
/// When creating the tokenizer, ensure that the vocabulary is sourced from a trusted provider.
/// </para>
/// </remarks>
public static SentencePieceTokenizer Create(
IEnumerable<(string Piece, float Score)> vocab,
int unkId,
bool addBeginningOfSentence = true,
bool addEndOfSentence = false,
ReadOnlySpan<byte> precompiledCharsMap = default,
bool addDummyPrefix = true,
bool escapeWhiteSpaces = true,
bool treatWhitespaceAsSuffix = false,
IReadOnlyDictionary<string, int>? specialTokens = null)
{
if (vocab is null)
{
throw new ArgumentNullException(nameof(vocab));
}

IReadOnlyList<(string Piece, float Score)> pieces = vocab as IReadOnlyList<(string Piece, float Score)>
?? new List<(string Piece, float Score)>(vocab);

SentencePieceUnigramModel model = new SentencePieceUnigramModel(
pieces, unkId, addBeginningOfSentence, addEndOfSentence,
precompiledCharsMap, addDummyPrefix, escapeWhiteSpaces,
treatWhitespaceAsSuffix, removeExtraWhitespaces: true, specialTokens);

return new SentencePieceTokenizer(model);
}

/// <summary>
/// Creates a Unigram <see cref="SentencePieceTokenizer"/> by parsing a Hugging Face <c>tokenizer.json</c>
/// that contains a Unigram model (<c>model.type == "Unigram"</c>).
/// </summary>
/// <param name="tokenizerJsonStream">A stream containing the UTF-8-encoded <c>tokenizer.json</c> content.</param>
/// <param name="addBeginningOfSentence">Whether to emit the beginning-of-sentence token during encoding.</param>
/// <param name="addEndOfSentence">Whether to emit the end-of-sentence token during encoding.</param>
/// <param name="specialTokens">Additional special tokens to recognize, supplied as a mapping of token string to token ID.</param>
/// <returns>A new <see cref="SentencePieceTokenizer"/> instance.</returns>
/// <remarks>
/// The following fields are read from the JSON:
/// <list type="bullet">
/// <item><description><c>model.vocab</c> — array of <c>[piece, score]</c> pairs (required).</description></item>
/// <item><description><c>model.unk_id</c> — index of the unknown token (required).</description></item>
/// <item><description><c>normalizer.precompiled_charsmap</c> (base64) — normalization map; also searched inside a <c>Sequence</c> normalizer.</description></item>
/// <item><description><c>pre_tokenizer</c> of type <c>Metaspace</c> — <c>add_prefix_space</c> and <c>replacement</c>; also searched inside a <c>Sequence</c> pre-tokenizer.</description></item>
/// </list>
/// <para>
/// When creating the tokenizer, ensure that the JSON stream is sourced from a trusted provider.
/// </para>
/// </remarks>
public static SentencePieceTokenizer CreateFromTokenizerJson(
Stream tokenizerJsonStream,
bool addBeginningOfSentence = true,
bool addEndOfSentence = false,
IReadOnlyDictionary<string, int>? specialTokens = null)
{
if (tokenizerJsonStream is null)
{
throw new ArgumentNullException(nameof(tokenizerJsonStream));
}

using JsonDocument doc = JsonDocument.Parse(tokenizerJsonStream);
JsonElement root = doc.RootElement;

// Validate model type
if (!root.TryGetProperty("model", out JsonElement modelElement))
{
throw new InvalidDataException("The tokenizer.json does not contain a 'model' property.");
}

if (modelElement.TryGetProperty("type", out JsonElement modelTypeElement) &&
!string.Equals(modelTypeElement.GetString(), "Unigram", StringComparison.OrdinalIgnoreCase))
{
throw new InvalidDataException($"Expected model type 'Unigram' but found '{modelTypeElement.GetString()}'.");
}
Comment on lines +558 to +568

if (!modelElement.TryGetProperty("unk_id", out JsonElement unkIdElement))
{
throw new InvalidDataException("The tokenizer.json model does not contain an 'unk_id' property.");
}

int unkId = unkIdElement.GetInt32();

if (!modelElement.TryGetProperty("vocab", out JsonElement vocabElement) ||
vocabElement.ValueKind != JsonValueKind.Array)
{
throw new InvalidDataException("The tokenizer.json model does not contain a valid 'vocab' array.");
}

List<(string Piece, float Score)> vocab = new List<(string Piece, float Score)>(vocabElement.GetArrayLength());
foreach (JsonElement entry in vocabElement.EnumerateArray())
{
if (entry.ValueKind != JsonValueKind.Array || entry.GetArrayLength() < 2)
{
throw new InvalidDataException("Each entry in 'model.vocab' must be a [piece, score] array.");
}

string? piece = entry[0].GetString();
if (piece is null)
{
throw new InvalidDataException("A piece string in 'model.vocab' is null.");
}

vocab.Add((piece, entry[1].GetSingle()));
}

// Extract normalizer settings
byte[]? precompiledCharsMap = null;
bool addDummyPrefix = true;
bool removeExtraWhitespaces = true;
if (root.TryGetProperty("normalizer", out JsonElement normalizerElement) &&
normalizerElement.ValueKind == JsonValueKind.Object)
{
precompiledCharsMap = ExtractPrecompiledCharsMap(normalizerElement);
Comment thread
ericstj marked this conversation as resolved.
}

// Extract pre_tokenizer settings
bool escapeWhiteSpaces = true;
bool treatWhitespaceAsSuffix = false;
if (root.TryGetProperty("pre_tokenizer", out JsonElement preTokenizerElement))
{
ExtractMetaspaceSettings(preTokenizerElement, ref addDummyPrefix, ref escapeWhiteSpaces, ref treatWhitespaceAsSuffix);
}
Comment on lines +610 to +616

SentencePieceUnigramModel model = new SentencePieceUnigramModel(
vocab, unkId, addBeginningOfSentence, addEndOfSentence,
precompiledCharsMap is not null ? precompiledCharsMap.AsSpan() : default,
addDummyPrefix, escapeWhiteSpaces, treatWhitespaceAsSuffix, removeExtraWhitespaces, specialTokens);

return new SentencePieceTokenizer(model);
}

private static byte[]? ExtractPrecompiledCharsMap(JsonElement normalizer)
Comment thread
ericstj marked this conversation as resolved.
{
if (!normalizer.TryGetProperty("type", out JsonElement typeEl))
{
return null;
}

string? type = typeEl.GetString();
if (string.Equals(type, "Precompiled", StringComparison.OrdinalIgnoreCase))
{
if (normalizer.TryGetProperty("precompiled_charsmap", out JsonElement mapEl))
{
string? base64 = mapEl.GetString();
if (base64 is not null)
{
return Convert.FromBase64String(base64);
}
}
return null;
}
else if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) &&
normalizer.TryGetProperty("normalizers", out JsonElement normalizersEl) &&
normalizersEl.ValueKind == JsonValueKind.Array)
{
byte[]? result = null;
foreach (JsonElement inner in normalizersEl.EnumerateArray())
{
if (inner.ValueKind != JsonValueKind.Object)
{
continue;
}

byte[]? innerResult = ExtractPrecompiledCharsMap(inner);
if (innerResult is not null)
{
result = innerResult;
}
}
return result;
}
else
{
throw new NotSupportedException($"Normalizer type '{type}' is not supported. Only 'Precompiled' and 'Sequence' normalizers are supported.");
}
}

private static void ExtractMetaspaceSettings(JsonElement preTokenizer, ref bool addDummyPrefix, ref bool escapeWhiteSpaces, ref bool treatWhitespaceAsSuffix)
{
if (!preTokenizer.TryGetProperty("type", out JsonElement typeEl))
{
return;
}

string? type = typeEl.GetString();
if (string.Equals(type, "Metaspace", StringComparison.OrdinalIgnoreCase))
{
if (preTokenizer.TryGetProperty("add_prefix_space", out JsonElement addPrefixEl))
{
addDummyPrefix = addPrefixEl.GetBoolean();
}

if (preTokenizer.TryGetProperty("replacement", out JsonElement replacementEl))
{
string? replacement = replacementEl.GetString();
escapeWhiteSpaces = replacement == "\u2581"; // U+2581 LOWER ONE EIGHTH BLOCK (▁)
}

if (preTokenizer.TryGetProperty("prepend_scheme", out JsonElement prependSchemeEl))
{
string? scheme = prependSchemeEl.GetString();
// "never" suppresses the dummy prefix; "always"/"first" keep the default (true)
if (string.Equals(scheme, "never", StringComparison.OrdinalIgnoreCase))
{
addDummyPrefix = false;
}
}
}
else if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) &&
preTokenizer.TryGetProperty("pretokenizers", out JsonElement preTokenizersEl) &&
preTokenizersEl.ValueKind == JsonValueKind.Array)
{
foreach (JsonElement inner in preTokenizersEl.EnumerateArray())
{
ExtractMetaspaceSettings(inner, ref addDummyPrefix, ref escapeWhiteSpaces, ref treatWhitespaceAsSuffix);
}
}
}
}
}
Loading
Loading