diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index dcb7a4bb..41bc74d9 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -27,3 +27,8 @@ jobs: run: pipx install ruff - name: Lint code with Ruff run: ruff check --output-format=github --target-version=py39 + - name: Install test dependencies + run: pip install -r server/requirements.txt + # Discover and run all files matching test_*.py or *_test.py under server/ + - name: Run tests + run: pytest server/ -v diff --git a/README.md b/README.md index f1cea06b..15018d37 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,11 @@ df = pd.read_sql(query, engine) #### Django REST - The email and password are set in `server/api/management/commands/createsu.py` +- Backend tests can be run using `pytest` by running the below command inside the running backend container: + +``` +docker compose exec backend pytest api/ -v +``` ## Local Kubernetes Deployment diff --git a/server/api/services/embedding_services.py b/server/api/services/embedding_services.py index e35f7965..3fa9bb68 100644 --- a/server/api/services/embedding_services.py +++ b/server/api/services/embedding_services.py @@ -11,18 +11,17 @@ logger = logging.getLogger(__name__) -def get_closest_embeddings( - user, message_data, document_name=None, guid=None, num_results=10 -): + +def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10): """ - Find the closest embeddings to a given message for a specific user. + Build an unevaluated QuerySet for the closest embeddings. Parameters ---------- user : User The user whose uploaded documents will be searched - message_data : str - The input message to find similar embeddings for + embedding_vector : array-like + Pre-computed embedding vector to compare against document_name : str, optional Filter results to a specific document name guid : str, optional @@ -32,59 +31,52 @@ def get_closest_embeddings( Returns ------- - list[dict] - List of dictionaries containing embedding results with keys: - - name: document name - - text: embedded text content - - page_number: page number in source document - - chunk_number: chunk number within the document - - distance: L2 distance from query embedding - - file_id: GUID of the source file + QuerySet + Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results """ - - encoding_start = time.time() - transformerModel = TransformerModel.get_instance().model - embedding_message = transformerModel.encode(message_data) - encoding_time = time.time() - encoding_start - - db_query_start = time.time() - # Django QuerySets are lazily evaluated if user.is_authenticated: # User sees their own files + files uploaded by superusers - closest_embeddings_query = ( - Embeddings.objects.filter( - Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) - ) - .annotate( - distance=L2Distance("embedding_sentence_transformers", embedding_message) - ) - .order_by("distance") + queryset = Embeddings.objects.filter( + Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) ) else: # Unauthenticated users only see superuser-uploaded files - closest_embeddings_query = ( - Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) - .annotate( - distance=L2Distance("embedding_sentence_transformers", embedding_message) - ) - .order_by("distance") - ) + queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) + + queryset = ( + queryset + .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector)) + .order_by("distance") + ) # Filtering to a document GUID takes precedence over a document name if guid: - closest_embeddings_query = closest_embeddings_query.filter( - upload_file__guid=guid - ) + queryset = queryset.filter(upload_file__guid=guid) elif document_name: - closest_embeddings_query = closest_embeddings_query.filter(name=document_name) + queryset = queryset.filter(name=document_name) # Slicing is equivalent to SQL's LIMIT clause - closest_embeddings_query = closest_embeddings_query[:num_results] + return queryset[:num_results] + + +def evaluate_query(queryset): + """ + Evaluate a QuerySet and return a list of result dicts. + + Parameters + ---------- + queryset : iterable + Iterable of Embeddings objects (or any objects with the expected attributes) + Returns + ------- + list[dict] + List of dicts with keys: name, text, page_number, chunk_number, distance, file_id + """ # Iterating evaluates the QuerySet and hits the database # TODO: Research improving the query evaluation performance - results = [ + return [ { "name": obj.name, "text": obj.text, @@ -93,13 +85,36 @@ def get_closest_embeddings( "distance": obj.distance, "file_id": obj.upload_file.guid if obj.upload_file else None, } - for obj in closest_embeddings_query + for obj in queryset ] - db_query_time = time.time() - db_query_start +def log_usage( + results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time +): + """ + Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. + + Parameters + ---------- + results : list[dict] + The search results, each containing a "distance" key + message_data : str + The original search query text + user : User + The user who performed the search + guid : str or None + Document GUID filter used in the search + document_name : str or None + Document name filter used in the search + num_results : int + Number of results requested + encoding_time : float + Time in seconds to encode the query + db_query_time : float + Time in seconds for the database query + """ try: - # Handle user having no uploaded docs or doc filtering returning no matches if results: distances = [r["distance"] for r in results] SemanticSearchUsage.objects.create( @@ -113,11 +128,10 @@ def get_closest_embeddings( num_results_returned=len(results), max_distance=max(distances), median_distance=median(distances), - min_distance=min(distances) + min_distance=min(distances), ) else: logger.warning("Semantic search returned no results") - SemanticSearchUsage.objects.create( query_text=message_data, user=user if (user and user.is_authenticated) else None, @@ -129,9 +143,58 @@ def get_closest_embeddings( num_results_returned=0, max_distance=None, median_distance=None, - min_distance=None + min_distance=None, ) except Exception as e: logger.error(f"Failed to create semantic search usage database record: {e}") + +def get_closest_embeddings( + user, message_data, document_name=None, guid=None, num_results=10 +): + """ + Find the closest embeddings to a given message for a specific user. + + Parameters + ---------- + user : User + The user whose uploaded documents will be searched + message_data : str + The input message to find similar embeddings for + document_name : str, optional + Filter results to a specific document name + guid : str, optional + Filter results to a specific document GUID (takes precedence over document_name) + num_results : int, default 10 + Maximum number of results to return + + Returns + ------- + list[dict] + List of dictionaries containing embedding results with keys: + - name: document name + - text: embedded text content + - page_number: page number in source document + - chunk_number: chunk number within the document + - distance: L2 distance from query embedding + - file_id: GUID of the source file + + Notes + ----- + Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. + """ + encoding_start = time.time() + model = TransformerModel.get_instance().model + embedding_vector = model.encode(message_data) + encoding_time = time.time() - encoding_start + + db_query_start = time.time() + queryset = build_query(user, embedding_vector, document_name, guid, num_results) + results = evaluate_query(queryset) + db_query_time = time.time() - db_query_start + + log_usage( + results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time + ) + return results diff --git a/server/api/services/test_embedding_services.py b/server/api/services/test_embedding_services.py new file mode 100644 index 00000000..677c1e7b --- /dev/null +++ b/server/api/services/test_embedding_services.py @@ -0,0 +1,85 @@ +from unittest.mock import MagicMock, patch + +from api.services.embedding_services import evaluate_query, log_usage + + +def test_evaluate_query_maps_fields(): + obj = MagicMock() + obj.name = "doc.pdf" + obj.text = "some text" + obj.page_num = 3 + obj.chunk_number = 1 + obj.distance = 0.42 + obj.upload_file.guid = "abc-123" + + results = evaluate_query([obj]) + + assert results == [ + { + "name": "doc.pdf", + "text": "some text", + "page_number": 3, + "chunk_number": 1, + "distance": 0.42, + "file_id": "abc-123", + } + ] + + +def test_evaluate_query_none_upload_file(): + obj = MagicMock() + obj.name = "doc.pdf" + obj.text = "some text" + obj.page_num = 1 + obj.chunk_number = 0 + obj.distance = 1.0 + obj.upload_file = None + + results = evaluate_query([obj]) + + assert results[0]["file_id"] is None + + +@patch("api.services.embedding_services.SemanticSearchUsage.objects.create") +def test_log_usage_computes_distance_stats(mock_create): + results = [{"distance": 1.0}, {"distance": 3.0}, {"distance": 2.0}] + user = MagicMock(is_authenticated=True) + + log_usage( + results, + message_data="test query", + user=user, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) + + mock_create.assert_called_once() + kwargs = mock_create.call_args.kwargs + assert kwargs["min_distance"] == 1.0 + assert kwargs["max_distance"] == 3.0 + assert kwargs["median_distance"] == 2.0 + assert kwargs["num_results_returned"] == 3 + + +@patch( + "api.services.embedding_services.SemanticSearchUsage.objects.create", + side_effect=Exception("DB error"), +) +def test_log_usage_swallows_exceptions(mock_create): + results = [{"distance": 1.0}] + user = MagicMock(is_authenticated=True) + + # pytest fails the test if it catches unhandled Exception + log_usage( + results, + message_data="test query", + user=user, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) diff --git a/server/pytest.ini b/server/pytest.ini new file mode 100644 index 00000000..235b9752 --- /dev/null +++ b/server/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +DJANGO_SETTINGS_MODULE = balancer_backend.settings +pythonpath = . diff --git a/server/requirements.txt b/server/requirements.txt index bbaf7bc9..001708e9 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -18,4 +18,6 @@ sentence_transformers PyMuPDF==1.24.0 Pillow pytesseract -anthropic \ No newline at end of file +anthropic +pytest +pytest-django \ No newline at end of file