5 changes: 5 additions & 0 deletions .github/workflows/python-app.yml
@@ -27,3 +27,8 @@ jobs:
      run: pipx install ruff
    - name: Lint code with Ruff
      run: ruff check --output-format=github --target-version=py39
    - name: Install test dependencies
      run: pip install -r server/requirements.txt
    # Discover and run all files matching test_*.py or *_test.py under server/
    - name: Run tests
      run: pytest server/ -v
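
For context, this step relies on pytest's default discovery: any file matching `test_*.py` or `*_test.py` under `server/` is collected, and top-level functions prefixed with `test_` are run. A minimal sketch of a module the CI step would pick up (the path and assertion are illustrative only, not part of this change):

```python
# server/api/example_test.py -- hypothetical path, shown only to illustrate discovery
def test_sanity():
    # Any function named test_* in a matching file is collected and run by pytest
    assert 1 + 1 == 2
```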
5 changes: 5 additions & 0 deletions README.md
@@ -69,6 +69,11 @@ df = pd.read_sql(query, engine)

#### Django REST
- The email and password are set in `server/api/management/commands/createsu.py`
- Backend tests can be run with `pytest` by executing the following command inside the running backend container:

```
docker compose exec backend pytest api/ -v
```

## Local Kubernetes Deployment

161 changes: 112 additions & 49 deletions server/api/services/embedding_services.py
@@ -11,18 +11,17 @@

logger = logging.getLogger(__name__)

def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):

def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
    """
    Find the closest embeddings to a given message for a specific user.
    Build an unevaluated QuerySet for the closest embeddings.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    embedding_vector : array-like
        Pre-computed embedding vector to compare against
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
@@ -32,59 +31,52 @@ def get_closest_embeddings(

    Returns
    -------
    list[dict]
        List of dictionaries containing embedding results with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file
    QuerySet
        Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
    """

    encoding_start = time.time()
    transformerModel = TransformerModel.get_instance().model
    embedding_message = transformerModel.encode(message_data)
    encoding_time = time.time() - encoding_start

    db_query_start = time.time()

    # Django QuerySets are lazily evaluated
    if user.is_authenticated:
        # User sees their own files + files uploaded by superusers
        closest_embeddings_query = (
            Embeddings.objects.filter(
                Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
            )
            .annotate(
                distance=L2Distance("embedding_sentence_transformers", embedding_message)
            )
            .order_by("distance")
        queryset = Embeddings.objects.filter(
            Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
        )
    else:
        # Unauthenticated users only see superuser-uploaded files
        closest_embeddings_query = (
            Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
            .annotate(
                distance=L2Distance("embedding_sentence_transformers", embedding_message)
            )
            .order_by("distance")
        )
        queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)

    queryset = (
        queryset
        .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
        .order_by("distance")
    )

    # Filtering to a document GUID takes precedence over a document name
    if guid:
        closest_embeddings_query = closest_embeddings_query.filter(
            upload_file__guid=guid
        )
        queryset = queryset.filter(upload_file__guid=guid)
    elif document_name:
        closest_embeddings_query = closest_embeddings_query.filter(name=document_name)
        queryset = queryset.filter(name=document_name)

    # Slicing is equivalent to SQL's LIMIT clause
    closest_embeddings_query = closest_embeddings_query[:num_results]
    return queryset[:num_results]


def evaluate_query(queryset):
    """
    Evaluate a QuerySet and return a list of result dicts.

    Parameters
    ----------
    queryset : iterable
        Iterable of Embeddings objects (or any objects with the expected attributes)

    Returns
    -------
    list[dict]
        List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
    """
    # Iterating evaluates the QuerySet and hits the database
    # TODO: Research improving the query evaluation performance
    results = [
    return [
        {
            "name": obj.name,
            "text": obj.text,
@@ -93,13 +85,36 @@ def get_closest_embeddings(
            "distance": obj.distance,
            "file_id": obj.upload_file.guid if obj.upload_file else None,
        }
        for obj in closest_embeddings_query
        for obj in queryset
    ]

    db_query_time = time.time() - db_query_start

def log_usage(
    results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
):
    """
    Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.

    Parameters
    ----------
    results : list[dict]
        The search results, each containing a "distance" key
    message_data : str
        The original search query text
    user : User
        The user who performed the search
    guid : str or None
        Document GUID filter used in the search
    document_name : str or None
        Document name filter used in the search
    num_results : int
        Number of results requested
    encoding_time : float
        Time in seconds to encode the query
    db_query_time : float
        Time in seconds for the database query
    """
    try:
        # Handle user having no uploaded docs or doc filtering returning no matches
        if results:
            distances = [r["distance"] for r in results]
            SemanticSearchUsage.objects.create(
@@ -113,11 +128,10 @@ def get_closest_embeddings(
                num_results_returned=len(results),
                max_distance=max(distances),
                median_distance=median(distances),
                min_distance=min(distances)
                min_distance=min(distances),
            )
        else:
            logger.warning("Semantic search returned no results")

            SemanticSearchUsage.objects.create(
                query_text=message_data,
                user=user if (user and user.is_authenticated) else None,
@@ -129,9 +143,58 @@
                num_results_returned=0,
                max_distance=None,
                median_distance=None,
                min_distance=None
                min_distance=None,
            )
    except Exception as e:
        logger.error(f"Failed to create semantic search usage database record: {e}")


def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):
    """
    Find the closest embeddings to a given message for a specific user.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    list[dict]
        List of dictionaries containing embedding results with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file

    Notes
    -----
    Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
    """
    encoding_start = time.time()
    model = TransformerModel.get_instance().model
    embedding_vector = model.encode(message_data)
    encoding_time = time.time() - encoding_start

    db_query_start = time.time()
    queryset = build_query(user, embedding_vector, document_name, guid, num_results)
    results = evaluate_query(queryset)
    db_query_time = time.time() - db_query_start

    log_usage(
        results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
    )

    return results
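
To see how the refactored pieces compose, here is a hedged usage sketch (for example from `python manage.py shell`); the user lookup and query string are illustrative, and only the functions and result keys visible in this diff are assumed:

```python
from django.contrib.auth import get_user_model

from api.services.embedding_services import get_closest_embeddings

# Illustrative only: any existing user works; an unauthenticated user would fall
# back to superuser-uploaded files per build_query.
user = get_user_model().objects.first()

results = get_closest_embeddings(user, "example search text", num_results=5)
for r in results:
    print(r["name"], r["page_number"], round(r["distance"], 3))
```

Because `build_query` returns an unevaluated QuerySet, a caller could also chain extra filters onto it before passing it to `evaluate_query`.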
85 changes: 85 additions & 0 deletions server/api/services/test_embedding_services.py
@@ -0,0 +1,85 @@
from unittest.mock import MagicMock, patch

from api.services.embedding_services import evaluate_query, log_usage


def test_evaluate_query_maps_fields():
    obj = MagicMock()
    obj.name = "doc.pdf"
    obj.text = "some text"
    obj.page_num = 3
    obj.chunk_number = 1
    obj.distance = 0.42
    obj.upload_file.guid = "abc-123"

    results = evaluate_query([obj])

    assert results == [
        {
            "name": "doc.pdf",
            "text": "some text",
            "page_number": 3,
            "chunk_number": 1,
            "distance": 0.42,
            "file_id": "abc-123",
        }
    ]


def test_evaluate_query_none_upload_file():
    obj = MagicMock()
    obj.name = "doc.pdf"
    obj.text = "some text"
    obj.page_num = 1
    obj.chunk_number = 0
    obj.distance = 1.0
    obj.upload_file = None

    results = evaluate_query([obj])

    assert results[0]["file_id"] is None


@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
def test_log_usage_computes_distance_stats(mock_create):
    results = [{"distance": 1.0}, {"distance": 3.0}, {"distance": 2.0}]
    user = MagicMock(is_authenticated=True)

    log_usage(
        results,
        message_data="test query",
        user=user,
        guid=None,
        document_name=None,
        num_results=10,
        encoding_time=0.1,
        db_query_time=0.2,
    )

    mock_create.assert_called_once()
    kwargs = mock_create.call_args.kwargs
    assert kwargs["min_distance"] == 1.0
    assert kwargs["max_distance"] == 3.0
    assert kwargs["median_distance"] == 2.0
    assert kwargs["num_results_returned"] == 3


@patch(
    "api.services.embedding_services.SemanticSearchUsage.objects.create",
    side_effect=Exception("DB error"),
)
def test_log_usage_swallows_exceptions(mock_create):
    results = [{"distance": 1.0}]
    user = MagicMock(is_authenticated=True)

    # The test fails automatically if log_usage lets the DB exception propagate
    log_usage(
        results,
        message_data="test query",
        user=user,
        guid=None,
        document_name=None,
        num_results=10,
        encoding_time=0.1,
        db_query_time=0.2,
    )
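
One branch not exercised above is `log_usage` with an empty result list, which per this diff still records a row with `num_results_returned=0` and `None` distance statistics. A sketch of a test in the same mock-based style (the test name is illustrative):

```python
from unittest.mock import MagicMock, patch

from api.services.embedding_services import log_usage


@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
def test_log_usage_empty_results_records_zero(mock_create):
    # An empty result list should still produce a usage record with zero results
    log_usage(
        [],
        message_data="test query",
        user=MagicMock(is_authenticated=True),
        guid=None,
        document_name=None,
        num_results=10,
        encoding_time=0.1,
        db_query_time=0.2,
    )

    mock_create.assert_called_once()
    kwargs = mock_create.call_args.kwargs
    assert kwargs["num_results_returned"] == 0
    assert kwargs["max_distance"] is None
```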
3 changes: 3 additions & 0 deletions server/pytest.ini
@@ -0,0 +1,3 @@
[pytest]
DJANGO_SETTINGS_MODULE = balancer_backend.settings
pythonpath = .
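
As a quick check that this configuration is picked up, a hedged smoke-test sketch (hypothetical file, e.g. `server/test_settings_smoke.py`); it assumes only that `balancer_backend.settings` is a standard Django settings module:

```python
from django.conf import settings


def test_django_settings_loaded():
    # pytest-django sets up Django from the DJANGO_SETTINGS_MODULE declared in pytest.ini
    assert settings.configured
    assert isinstance(settings.INSTALLED_APPS, (list, tuple))
```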
4 changes: 3 additions & 1 deletion server/requirements.txt
@@ -18,4 +18,6 @@ sentence_transformers
PyMuPDF==1.24.0
Pillow
pytesseract
anthropic
pytest
pytest-django