Skip to content

Commit 02d9d24

Browse files
committed
Merge branch 'minor-embeddings-changes' into 'develop'
Published embeddings image See merge request baserow/baserow!3802
2 parents b0fb7c2 + a7c28da commit 02d9d24

File tree

4 files changed

+68
-44
lines changed

4 files changed

+68
-44
lines changed

embeddings/Dockerfile

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
1-
FROM python:3.13-slim AS builder
1+
FROM python:3.11-slim-bookworm AS builder
22

33
RUN apt-get update && apt-get install -y --no-install-recommends \
4-
git \
5-
&& rm -rf /var/lib/apt/lists/*
4+
git && rm -rf /var/lib/apt/lists/*
5+
6+
WORKDIR /build
7+
COPY _download_model.py .
68

79
RUN pip install --no-cache-dir \
8-
optimum[onnxruntime]==1.27.0 \
9-
transformers==4.53.0
10+
huggingface_hub==0.31.0 \
11+
transformers==4.53.0 && \
12+
python _download_model.py && \
13+
rm _download_model.py
1014

11-
COPY _download_model.py /tmp/download_model.py
12-
RUN python /tmp/download_model.py && rm /tmp/download_model.py
15+
FROM python:3.11-slim-bookworm
1316

14-
FROM python:3.13-slim
17+
COPY --from=builder /model /model
1518

1619
RUN pip install --no-cache-dir \
17-
optimum[onnxruntime]==1.27.0 \
20+
onnxruntime==1.20.1 \
1821
transformers==4.53.0 \
1922
starlette==0.48.0 \
20-
uvicorn==0.37.0 \
21-
&& rm -rf /root/.cache
22-
23-
COPY --from=builder /model /model
24-
25-
COPY app.py /app/app.py
23+
uvicorn==0.37.0 && \
24+
rm -rf /root/.cache
2625

2726
WORKDIR /app
27+
COPY app.py .
2828

2929
EXPOSE 80
30-
31-
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
30+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]

embeddings/README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,29 @@ The service uses a multi-stage Docker build:
3636
- **Batch support**: Process multiple texts in a single request
3737
- **Health checks**: Built-in health endpoint for monitoring
3838

39+
## Docker
40+
41+
### Run locally
42+
43+
```
44+
docker run -p 8080:80 baserow/embeddings:1.0.0
45+
```
46+
47+
### Build for publish
48+
49+
```
50+
docker buildx build \
51+
--platform linux/amd64,linux/arm64 \
52+
-t baserow/embeddings:1.0.0 \
53+
-t baserow/embeddings:latest \
54+
--push .
55+
```
56+
3957
## API
4058

4159
### Endpoints
4260

43-
#### `POST /embed`
61+
#### `POST http://localhost:8080/embed`
4462

4563
Generate embeddings for one or more texts.
4664

embeddings/_download_model.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
#!/usr/bin/env python3
22
"""Download and convert the embedding model to ONNX format."""
33

4-
from optimum.onnxruntime import ORTModelForFeatureExtraction
4+
from huggingface_hub import snapshot_download
55
from transformers import AutoTokenizer
6+
import os
7+
import shutil
68

7-
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
9+
MODEL_ID = "optimum/all-MiniLM-L6-v2"
810
OUTPUT_DIR = "/model"
911

10-
print(f"Downloading {MODEL_NAME} (ONNX format)...")
11-
model = ORTModelForFeatureExtraction.from_pretrained(MODEL_NAME)
12-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12+
print(f"Downloading {MODEL_ID} from Hugging Face...")
13+
local_dir = snapshot_download(repo_id=MODEL_ID, allow_patterns=["*.onnx", "*.json", "*.txt"])
1314

14-
print(f"Saving to {OUTPUT_DIR}...")
15-
model.save_pretrained(OUTPUT_DIR)
15+
print(f"Saving tokenizer...")
16+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
17+
os.makedirs(OUTPUT_DIR, exist_ok=True)
1618
tokenizer.save_pretrained(OUTPUT_DIR)
1719

18-
print("Done!")
20+
print("Copying ONNX model files...")
21+
for item in os.listdir(local_dir):
22+
src = os.path.join(local_dir, item)
23+
dst = os.path.join(OUTPUT_DIR, item)
24+
if os.path.isfile(src):
25+
shutil.copy(src, dst)
26+
27+
print("Done! ONNX model is ready at", OUTPUT_DIR)

embeddings/app.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,54 @@
11
from starlette.applications import Starlette
22
from starlette.responses import JSONResponse
33
from starlette.routing import Route
4-
from optimum.onnxruntime import ORTModelForFeatureExtraction
4+
import onnxruntime as ort
55
from transformers import AutoTokenizer
66
import numpy as np
7+
import os
78

8-
# Load ONNX model and tokenizer once at startup
9-
model = ORTModelForFeatureExtraction.from_pretrained("/model")
10-
tokenizer = AutoTokenizer.from_pretrained("/model")
9+
# Load ONNX model directly
10+
MODEL_DIR = "/model"
11+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
1112

13+
# Create inference session
14+
model_path = os.path.join(MODEL_DIR, "model.onnx")
15+
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
1216

1317
def mean_pooling(token_embeddings, attention_mask):
14-
"""Mean pooling to get sentence embeddings"""
1518
input_mask_expanded = np.expand_dims(attention_mask, -1)
16-
input_mask_expanded = np.broadcast_to(
17-
input_mask_expanded, token_embeddings.shape
18-
).astype(float)
19-
19+
input_mask_expanded = np.broadcast_to(input_mask_expanded, token_embeddings.shape).astype(float)
2020
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
2121
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
22-
2322
return sum_embeddings / sum_mask
2423

25-
2624
async def embed(request):
2725
try:
2826
data = await request.json()
29-
except:
27+
except Exception:
3028
return JSONResponse({"error": "Invalid JSON"}, status_code=400)
3129

3230
if not data or "texts" not in data:
3331
return JSONResponse({"error": "Missing texts field"}, status_code=400)
3432

3533
texts = data["texts"]
36-
3734
if isinstance(texts, str):
3835
texts = [texts]
3936

4037
encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="np")
41-
outputs = model(**encoded)
42-
embeddings = mean_pooling(outputs.last_hidden_state, encoded["attention_mask"])
38+
ort_inputs = {k: v for k, v in encoded.items()}
4339

44-
# Normalize embeddings
40+
# Run model
41+
ort_outputs = session.run(None, ort_inputs)
42+
token_embeddings = ort_outputs[0]
43+
44+
embeddings = mean_pooling(token_embeddings, encoded["attention_mask"])
4545
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
4646

4747
return JSONResponse({"embeddings": embeddings.tolist()})
4848

49-
5049
async def health(request):
5150
return JSONResponse({"status": "healthy"})
5251

53-
5452
app = Starlette(
5553
routes=[
5654
Route("/embed", embed, methods=["POST"]),

0 commit comments

Comments
 (0)