|
1 | 1 | from starlette.applications import Starlette |
2 | 2 | from starlette.responses import JSONResponse |
3 | 3 | from starlette.routing import Route |
4 | | -from optimum.onnxruntime import ORTModelForFeatureExtraction |
| 4 | +import onnxruntime as ort |
5 | 5 | from transformers import AutoTokenizer |
6 | 6 | import numpy as np |
| 7 | +import os |
7 | 8 |
|
8 | | -# Load ONNX model and tokenizer once at startup |
9 | | -model = ORTModelForFeatureExtraction.from_pretrained("/model") |
10 | | -tokenizer = AutoTokenizer.from_pretrained("/model") |
# Location of the ONNX model + tokenizer files. Overridable via the
# MODEL_DIR environment variable so the same image can mount the model
# elsewhere; defaults to the original hard-coded path.
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")

# Tokenizer is loaded once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# One process-wide ONNX Runtime inference session, created at startup.
# CPU-only provider — presumably this service runs without a GPU; confirm.
model_path = os.path.join(MODEL_DIR, "model.onnx")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
12 | 16 |
|
def mean_pooling(token_embeddings, attention_mask):
    """Mean-pool token embeddings into one sentence vector per sequence.

    Args:
        token_embeddings: float array of shape (batch, seq_len, hidden).
        attention_mask: 0/1 array of shape (batch, seq_len); padded
            positions are 0 and are excluded from the mean.

    Returns:
        Float array of shape (batch, hidden) — the masked mean over seq_len.
    """
    # (batch, seq_len, 1) mask, as float so it scales embeddings directly.
    mask = np.expand_dims(attention_mask, -1).astype(float)
    # Multiplication broadcasts the mask across the hidden axis, so no
    # explicit broadcast_to copy is needed; the per-row token count
    # mask.sum(axis=1) has shape (batch, 1) and broadcasts in the division.
    sum_embeddings = np.sum(token_embeddings * mask, axis=1)
    # Clip counts away from zero so an all-padding row cannot divide by 0.
    sum_mask = np.clip(mask.sum(axis=1), a_min=1e-9, a_max=None)
    return sum_embeddings / sum_mask
24 | 23 |
|
25 | | - |
async def embed(request):
    """Return L2-normalized sentence embeddings for the request's texts.

    Expects a JSON body ``{"texts": str | list[str]}``. Responds with
    ``{"embeddings": [[...], ...]}`` on success, or a 400 JSON error payload
    for malformed input.
    """
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({"error": "Invalid JSON"}, status_code=400)

    if not data or "texts" not in data:
        return JSONResponse({"error": "Missing texts field"}, status_code=400)

    texts = data["texts"]
    if isinstance(texts, str):
        texts = [texts]
    # Reject empty lists and non-string items up front: the tokenizer would
    # otherwise raise and surface as an opaque 500.
    if not isinstance(texts, list) or not texts or not all(isinstance(t, str) for t in texts):
        return JSONResponse(
            {"error": "texts must be a non-empty string or list of strings"},
            status_code=400,
        )

    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="np")

    # Feed only the inputs the ONNX graph actually declares. Tokenizers may
    # emit extra keys (e.g. token_type_ids) that InferenceSession.run rejects
    # as unknown input names.
    graph_inputs = {inp.name for inp in session.get_inputs()}
    ort_inputs = {name: arr for name, arr in encoded.items() if name in graph_inputs}

    # Run the model; the first output holds the token-level hidden states
    # (batch, seq_len, hidden) for this feature-extraction graph.
    ort_outputs = session.run(None, ort_inputs)
    token_embeddings = ort_outputs[0]

    embeddings = mean_pooling(token_embeddings, encoded["attention_mask"])
    # Unit-normalize so cosine similarity reduces to a dot product.
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    return JSONResponse({"embeddings": embeddings.tolist()})
48 | 48 |
|
49 | | - |
async def health(request):
    """Liveness endpoint — always responds 200 with a 'healthy' status."""
    payload = {"status": "healthy"}
    return JSONResponse(payload)
52 | 51 |
|
53 | | - |
54 | 52 | app = Starlette( |
55 | 53 | routes=[ |
56 | 54 | Route("/embed", embed, methods=["POST"]), |
|
0 commit comments