|
6 | 6 | * Usage: bun scripts/generate-query-embeddings.ts |
7 | 7 | */ |
8 | 8 |
|
9 | | -import { writeFileSync } from 'node:fs' |
10 | | -import { join } from 'node:path' |
11 | | -import { gzipSync } from 'node:zlib' |
| 9 | +import { writeFileSync } from "node:fs"; |
| 10 | +import { join } from "node:path"; |
| 11 | +import { gzipSync } from "node:zlib"; |
12 | 12 |
|
13 | 13 | // Load queries |
14 | | -import activityTypes from '../src/extraction/embeddings/queries/activity-types.json' |
15 | | -import agreementQueries from '../src/extraction/embeddings/queries/agreement.json' |
16 | | -import suggestionQueries from '../src/extraction/embeddings/queries/suggestions.json' |
| 14 | +import activityTypes from "../src/extraction/embeddings/queries/activity-types.json"; |
| 15 | +import agreementQueries from "../src/extraction/embeddings/queries/agreement.json"; |
| 16 | +import suggestionQueries from "../src/extraction/embeddings/queries/suggestions.json"; |
17 | 17 |
|
18 | | -const OPENAI_API_KEY = process.env.OPENAI_API_KEY |
| 18 | +const OPENAI_API_KEY = process.env.OPENAI_API_KEY; |
19 | 19 | if (!OPENAI_API_KEY) { |
20 | | - console.error('Error: OPENAI_API_KEY environment variable required') |
21 | | - console.error('Set it in .env or export it') |
22 | | - process.exit(1) |
| 20 | + console.error("Error: OPENAI_API_KEY environment variable required"); |
| 21 | + console.error("Set it in .env or export it"); |
| 22 | + process.exit(1); |
23 | 23 | } |
24 | 24 |
|
25 | | -const MODEL = 'text-embedding-3-large' |
| 25 | +const MODEL = "text-embedding-3-large"; |
26 | 26 |
|
27 | 27 | interface OpenAIEmbeddingResponse { |
28 | | - data: Array<{ embedding: number[]; index: number }> |
29 | | - model: string |
30 | | - usage: { prompt_tokens: number; total_tokens: number } |
| 28 | + data: Array<{ embedding: number[]; index: number }>; |
| 29 | + model: string; |
| 30 | + usage: { prompt_tokens: number; total_tokens: number }; |
31 | 31 | } |
32 | 32 |
|
33 | 33 | async function embedBatch(texts: string[]): Promise<number[][]> { |
34 | | - const response = await fetch('https://api.openai.com/v1/embeddings', { |
35 | | - method: 'POST', |
36 | | - headers: { |
37 | | - 'Content-Type': 'application/json', |
38 | | - Authorization: `Bearer ${OPENAI_API_KEY}` |
39 | | - }, |
40 | | - body: JSON.stringify({ model: MODEL, input: texts }) |
41 | | - }) |
42 | | - |
43 | | - if (!response.ok) { |
44 | | - const error = await response.text() |
45 | | - throw new Error(`OpenAI API error: ${response.status} ${error}`) |
46 | | - } |
47 | | - |
48 | | - const data = (await response.json()) as OpenAIEmbeddingResponse |
49 | | - |
50 | | - // Sort by index and return embeddings |
51 | | - const embeddings: number[][] = new Array(texts.length) |
52 | | - for (const item of data.data) { |
53 | | - embeddings[item.index] = item.embedding |
54 | | - } |
55 | | - |
56 | | - return embeddings |
| 34 | + const response = await fetch("https://api.openai.com/v1/embeddings", { |
| 35 | + method: "POST", |
| 36 | + headers: { |
| 37 | + "Content-Type": "application/json", |
| 38 | + Authorization: `Bearer ${OPENAI_API_KEY}`, |
| 39 | + }, |
| 40 | + body: JSON.stringify({ model: MODEL, input: texts }), |
| 41 | + }); |
| 42 | + |
| 43 | + if (!response.ok) { |
| 44 | + const error = await response.text(); |
| 45 | + throw new Error(`OpenAI API error: ${response.status} ${error}`); |
| 46 | + } |
| 47 | + |
| 48 | + const data = (await response.json()) as OpenAIEmbeddingResponse; |
| 49 | + |
| 50 | + // Sort by index and return embeddings |
| 51 | + const embeddings: number[][] = new Array(texts.length); |
| 52 | + for (const item of data.data) { |
| 53 | + embeddings[item.index] = item.embedding; |
| 54 | + } |
| 55 | + |
| 56 | + return embeddings; |
57 | 57 | } |
58 | 58 |
|
59 | 59 | async function main() { |
60 | | - console.log('Generating query embeddings...\n') |
61 | | - |
62 | | - // Flatten all queries |
63 | | - const allActivityTypes = Object.values(activityTypes).flat() |
64 | | - const allQueries = [...suggestionQueries, ...agreementQueries, ...allActivityTypes] |
65 | | - |
66 | | - console.log(`Suggestion queries: ${suggestionQueries.length}`) |
67 | | - console.log(`Agreement queries: ${agreementQueries.length}`) |
68 | | - console.log(`Activity types: ${allActivityTypes.length}`) |
69 | | - console.log(`Total queries: ${allQueries.length}\n`) |
70 | | - |
71 | | - // Embed in batches of 100 |
72 | | - const BATCH_SIZE = 100 |
73 | | - const allEmbeddings: number[][] = [] |
74 | | - |
75 | | - for (let i = 0; i < allQueries.length; i += BATCH_SIZE) { |
76 | | - const batch = allQueries.slice(i, i + BATCH_SIZE) |
77 | | - console.log(`Embedding batch ${Math.floor(i / BATCH_SIZE) + 1}/${Math.ceil(allQueries.length / BATCH_SIZE)}...`) |
78 | | - |
79 | | - const embeddings = await embedBatch(batch) |
80 | | - allEmbeddings.push(...embeddings) |
81 | | - } |
82 | | - |
83 | | - // Build output structure |
84 | | - const output = { |
85 | | - model: MODEL, |
86 | | - generatedAt: new Date().toISOString(), |
87 | | - queryCount: allQueries.length, |
88 | | - dimensions: allEmbeddings[0]?.length ?? 0, |
89 | | - queries: allQueries.map((query, i) => ({ |
90 | | - text: query, |
91 | | - embedding: allEmbeddings[i] |
92 | | - })) |
93 | | - } |
94 | | - |
95 | | - // Write compressed file |
96 | | - const outputPath = join(import.meta.dir, '../src/extraction/embeddings/queries/query-embeddings.json.gz') |
97 | | - const jsonData = JSON.stringify(output) |
98 | | - const compressed = gzipSync(jsonData) |
99 | | - writeFileSync(outputPath, compressed) |
100 | | - |
101 | | - const sizeMB = (compressed.length / 1024 / 1024).toFixed(1) |
102 | | - console.log(`\nWritten ${allQueries.length} embeddings to:`) |
103 | | - console.log(outputPath) |
104 | | - console.log(`\nDimensions: ${output.dimensions}`) |
105 | | - console.log(`Compressed size: ${sizeMB}MB`) |
| 60 | + console.log("Generating query embeddings...\n"); |
| 61 | + |
| 62 | + // Flatten all queries |
| 63 | + const allActivityTypes = Object.values(activityTypes).flat(); |
| 64 | + const allQueries = [ |
| 65 | + ...suggestionQueries, |
| 66 | + ...agreementQueries, |
| 67 | + ...allActivityTypes, |
| 68 | + ]; |
| 69 | + |
| 70 | + console.log(`Suggestion queries: ${suggestionQueries.length}`); |
| 71 | + console.log(`Agreement queries: ${agreementQueries.length}`); |
| 72 | + console.log(`Activity types: ${allActivityTypes.length}`); |
| 73 | + console.log(`Total queries: ${allQueries.length}\n`); |
| 74 | + |
| 75 | + // Embed in batches of 100 |
| 76 | + const BATCH_SIZE = 100; |
| 77 | + const allEmbeddings: number[][] = []; |
| 78 | + |
| 79 | + for (let i = 0; i < allQueries.length; i += BATCH_SIZE) { |
| 80 | + const batch = allQueries.slice(i, i + BATCH_SIZE); |
| 81 | + console.log( |
| 82 | + `Embedding batch ${Math.floor(i / BATCH_SIZE) + 1}/${Math.ceil(allQueries.length / BATCH_SIZE)}...`, |
| 83 | + ); |
| 84 | + |
| 85 | + const embeddings = await embedBatch(batch); |
| 86 | + allEmbeddings.push(...embeddings); |
| 87 | + } |
| 88 | + |
| 89 | + // Build output structure |
| 90 | + const output = { |
| 91 | + model: MODEL, |
| 92 | + generatedAt: new Date().toISOString(), |
| 93 | + queryCount: allQueries.length, |
| 94 | + dimensions: allEmbeddings[0]?.length ?? 0, |
| 95 | + queries: allQueries.map((query, i) => ({ |
| 96 | + text: query, |
| 97 | + embedding: allEmbeddings[i], |
| 98 | + })), |
| 99 | + }; |
| 100 | + |
| 101 | + // Write compressed file |
| 102 | + const outputPath = join( |
| 103 | + import.meta.dir, |
| 104 | + "../src/extraction/embeddings/queries/query-embeddings.json.gz", |
| 105 | + ); |
| 106 | + const jsonData = JSON.stringify(output); |
| 107 | + const compressed = gzipSync(jsonData); |
| 108 | + writeFileSync(outputPath, compressed); |
| 109 | + |
| 110 | + const sizeMB = (compressed.length / 1024 / 1024).toFixed(1); |
| 111 | + console.log(`\nWritten ${allQueries.length} embeddings to:`); |
| 112 | + console.log(outputPath); |
| 113 | + console.log(`\nDimensions: ${output.dimensions}`); |
| 114 | + console.log(`Compressed size: ${sizeMB}MB`); |
106 | 115 | } |
107 | 116 |
|
108 | | -main().catch(console.error) |
| 117 | +main().catch(console.error); |
0 commit comments