core/llm/countTokens.ts (7 changes: 5 additions & 2 deletions)
@@ -18,6 +18,7 @@ import {
 import { renderChatMessage } from "../util/messageContent.js";
 import { AsyncEncoder, LlamaAsyncEncoder } from "./asyncEncoder.js";
 import { DEFAULT_PRUNING_LENGTH } from "./constants.js";
+import { getAdjustedTokenCountFromModel } from "./getAdjustedTokenCount.js";
 import llamaTokenizer from "./llamaTokenizer.js";
 interface Encoding {
   encode: Tiktoken["encode"];
@@ -114,8 +115,9 @@ function countTokens(
   modelName = "llama2",
 ): number {
   const encoding = encodingForModel(modelName);
+  let baseTokens = 0;
   if (Array.isArray(content)) {
-    return content.reduce((acc, part) => {
+    baseTokens = content.reduce((acc, part) => {
       return (
         acc +
         (part.type === "text"
@@ -124,8 +126,9 @@
       );
     }, 0);
   } else {
-    return encoding.encode(content ?? "", "all", []).length;
+    baseTokens = encoding.encode(content ?? "", "all", []).length;
   }
+  return getAdjustedTokenCountFromModel(baseTokens, modelName);
 }
 
 // https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/10
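The net effect of the countTokens change above is that the base count from the existing llama/gpt-style tokenizer is accumulated first and the model-specific buffer is applied exactly once at the single exit point. A minimal sketch of that flow (string branch only; the multimodal branch of the real function is omitted here):

```typescript
// Sketch only: encodingForModel and getAdjustedTokenCountFromModel are the
// helpers referenced in the diff above; the real countTokens also handles
// multimodal content arrays.
function countTokensSketch(content: string, modelName = "llama2"): number {
  const encoding = encodingForModel(modelName);
  // 1) Count with the fast gpt/llama tokenizer.
  const baseTokens = encoding.encode(content ?? "", "all", []).length;
  // 2) Apply the per-model safety buffer once, on the way out.
  return getAdjustedTokenCountFromModel(baseTokens, modelName);
}
```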
core/llm/getAdjustedTokenCount.test.ts (48 changes: 48 additions & 0 deletions)
@@ -0,0 +1,48 @@
+import { getAdjustedTokenCountFromModel } from "./getAdjustedTokenCount";
+
+describe("getAdjustedTokenCountFromModel", () => {
+  it("should return base tokens for non-special models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "gpt-4")).toBe(100);
+    expect(getAdjustedTokenCountFromModel(100, "llama2")).toBe(100);
+    expect(getAdjustedTokenCountFromModel(100, "random-model")).toBe(100);
+  });
+
+  it("should apply multiplier for Claude models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "claude-3-opus")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(100, "claude-3.5-sonnet")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(100, "CLAUDE-2")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(50, "claude")).toBe(62); // 50 * 1.23 = 61.5, ceiled to 62
+  });
+
+  it("should apply multiplier for Gemini models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "gemini-pro")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(100, "gemini-1.5-pro")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(100, "GEMINI-flash")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(50, "gemini")).toBe(59); // 50 * 1.18 = 59
+  });
+
+  it("should apply multiplier for Mistral family models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "mistral-large")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(100, "mixtral-8x7b")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(100, "devstral")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(100, "CODESTRAL")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(50, "mistral")).toBe(63); // 50 * 1.26 = 63
+  });
+
+  it("should handle edge cases", () => {
+    expect(getAdjustedTokenCountFromModel(0, "claude")).toBe(0);
+    expect(getAdjustedTokenCountFromModel(1, "gemini")).toBe(2); // 1 * 1.18 = 1.18, ceiled to 2
+    expect(getAdjustedTokenCountFromModel(1000, "mixtral")).toBe(1260);
+  });
+
+  it("should handle empty or undefined model names", () => {
+    expect(getAdjustedTokenCountFromModel(100, "")).toBe(100);
+    expect(getAdjustedTokenCountFromModel(100, undefined as any)).toBe(100);
+  });
+
+  it("should be case-insensitive", () => {
+    expect(getAdjustedTokenCountFromModel(100, "ClAuDe-3-OpUs")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(100, "GeMiNi-PrO")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(100, "MiXtRaL")).toBe(126);
+  });
+});
core/llm/getAdjustedTokenCount.ts (38 changes: 38 additions & 0 deletions)
@@ -0,0 +1,38 @@
+// Importing a bunch of tokenizers can be very resource-intensive (MB-scale per tokenizer),
+// and using token-counting APIs (e.g. for Anthropic) can be complicated and unreliable in many environments.
+// So for now we will just use the super fast gpt-tokenizer and apply safety buffers.
+// Rough estimates from this article are used as safety buffers for common tokenizers
+// that produce HIGHER token counts than gpt: roughly the article's token ratio + 10%.
+// https://medium.com/@disparate-ai/not-all-tokens-are-created-equal-7347d549af4d
+const ANTHROPIC_TOKEN_MULTIPLIER = 1.23;
+const GEMINI_TOKEN_MULTIPLIER = 1.18;
+const MISTRAL_TOKEN_MULTIPLIER = 1.26;
+
+/**
+ * Adjusts token count based on model-specific tokenizer differences.
+ * Since we use llama tokenizer (~= gpt tokenizer) for all models, we apply
+ * multipliers for models known to have higher token counts.
+ *
+ * @param baseTokens - Token count from llama/gpt tokenizer
+ * @param modelName - Name of the model
+ * @returns Adjusted token count with safety buffer
+ */
+export function getAdjustedTokenCountFromModel(
+  baseTokens: number,
+  modelName: string,
+): number {
+  let multiplier = 1;
+  const lowerModelName = modelName?.toLowerCase() ?? "";
+  if (lowerModelName.includes("claude")) {
+    multiplier = ANTHROPIC_TOKEN_MULTIPLIER;
+  } else if (lowerModelName.includes("gemini")) {
+    multiplier = GEMINI_TOKEN_MULTIPLIER;
+  } else if (
+    lowerModelName.includes("stral") ||
+    lowerModelName.includes("mixtral")
+  ) {
+    // Mistral family models: mistral, mixtral, codestral, devstral, etc.
+    multiplier = MISTRAL_TOKEN_MULTIPLIER;
+  }
+  return Math.ceil(baseTokens * multiplier);
+}
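As a quick sanity check of the buffer math in this new helper, here is a small usage sketch; the base count and model names below are illustrative only, and the import specifier matches the one the CLI change below uses:

```typescript
import { getAdjustedTokenCountFromModel } from "core/llm/getAdjustedTokenCount.js";

// Illustrative base counts showing each multiplier and the Math.ceil rounding:
getAdjustedTokenCountFromModel(1000, "gpt-4o");            // 1000 (no family matched)
getAdjustedTokenCountFromModel(1000, "claude-3-5-sonnet"); // Math.ceil(1000 * 1.23) = 1230
getAdjustedTokenCountFromModel(1000, "gemini-1.5-flash");  // Math.ceil(1000 * 1.18) = 1180
getAdjustedTokenCountFromModel(1000, "codestral-latest");  // Math.ceil(1000 * 1.26) = 1260
```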
extensions/cli/src/util/tokenizer.ts (32 changes: 3 additions & 29 deletions)
@@ -1,5 +1,6 @@
 import { ModelConfig } from "@continuedev/config-yaml";
 import type { ChatHistoryItem } from "core/index.js";
+import { getAdjustedTokenCountFromModel } from "core/llm/getAdjustedTokenCount.js";
 import { encode } from "gpt-tokenizer";
 import type { ChatCompletionTool } from "openai/resources/chat/completions.mjs";
 
@@ -32,33 +33,6 @@ export function getModelMaxTokens(model: ModelConfig): number {
     : maxTokens;
 }
 
-// Importing a bunch of tokenizers can be very resource intensive (MB-scale per tokenizer)
-// Using token counting APIs (e.g. for anthropic) can be complicated and unreliable in many environments
-// So for now we will just use super fast gpt-tokenizer and apply safety buffers
-// I'm using rough estimates from this article to apply safety buffers to common tokenizers
-// which will have HIGHER token counts than gpt. Roughly using token ratio from article + 10%
-// https://medium.com/@disparate-ai/not-all-tokens-are-created-equal-7347d549af4d
-const ANTHROPIC_TOKEN_MULTIPLIER = 1.23;
-const GEMINI_TOKEN_MULTIPLIER = 1.18;
-const MISTRAL_TOKEN_MULTIPLIER = 1.26;
-
-function getAdjustedTokenCountFromModel(
-  baseTokens: number,
-  model: ModelConfig,
-) {
-  let multiplier = 1;
-  const modelName = model.model?.toLowerCase() ?? "";
-  if (modelName.includes("claude")) {
-    multiplier = ANTHROPIC_TOKEN_MULTIPLIER;
-  } else if (modelName.includes("gemini")) {
-    multiplier = GEMINI_TOKEN_MULTIPLIER;
-  } else if (modelName.includes("stral")) {
-    // devstral, mixtral, mistral, etc
-    multiplier = MISTRAL_TOKEN_MULTIPLIER;
-  }
-  return Math.ceil(baseTokens * multiplier);
-}
-
 /**
  * Count tokens in message content (string or multimodal array)
  */
@@ -68,7 +42,7 @@ function countContentTokens(
 ): number {
   if (typeof content === "string") {
     const count = encode(content).length;
-    return getAdjustedTokenCountFromModel(count, model);
+    return getAdjustedTokenCountFromModel(count, model.model ?? "");
   }
 
   if (Array.isArray(content)) {
@@ -81,7 +55,7 @@
         tokenCount += 1024; // Rough estimate for image tokens
       }
     }
-    return getAdjustedTokenCountFromModel(tokenCount, model);
+    return getAdjustedTokenCountFromModel(tokenCount, model.model ?? "");
   }
 
   return 0;
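On the CLI side, the only requirement this change places on call sites is that they pass a plain model-name string (model.model) to the shared helper instead of the whole ModelConfig. A hedged sketch of the string-only path (the function name below is hypothetical; the real module-private helper is countContentTokens and also handles multimodal arrays):

```typescript
import { ModelConfig } from "@continuedev/config-yaml";
import { getAdjustedTokenCountFromModel } from "core/llm/getAdjustedTokenCount.js";
import { encode } from "gpt-tokenizer";

// Hypothetical helper for illustration only.
function countStringTokens(text: string, model: ModelConfig): number {
  const base = encode(text).length; // fast gpt-tokenizer count
  // Pass the model-name string, matching the updated call sites above.
  return getAdjustedTokenCountFromModel(base, model.model ?? "");
}
```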