From 06406bdaded60121a9b0df113f515f0a6fae2c96 Mon Sep 17 00:00:00 2001 From: maxachis Date: Fri, 18 Apr 2025 10:34:53 -0400 Subject: [PATCH 1/2] refactor(app): refactor to reduce memory strain from huggingface task --- .../URLRelevanceHuggingfaceTaskOperator.py | 2 +- hugging_face/HuggingFaceInterface.py | 48 +++++++++---------- hugging_face/relevancy_worker.py | 15 ++++++ .../test_hugging_face_interface.py | 8 ++-- 4 files changed, 45 insertions(+), 28 deletions(-) create mode 100644 hugging_face/relevancy_worker.py diff --git a/core/classes/task_operators/URLRelevanceHuggingfaceTaskOperator.py b/core/classes/task_operators/URLRelevanceHuggingfaceTaskOperator.py index 4871a9f0..49aa7aa0 100644 --- a/core/classes/task_operators/URLRelevanceHuggingfaceTaskOperator.py +++ b/core/classes/task_operators/URLRelevanceHuggingfaceTaskOperator.py @@ -46,7 +46,7 @@ async def put_results_into_database(self, tdos): async def add_huggingface_relevancy(self, tdos: list[URLRelevanceHuggingfaceTDO]): urls_with_html = [tdo.url_with_html for tdo in tdos] - results = self.huggingface_interface.get_url_relevancy(urls_with_html) + results = await self.huggingface_interface.get_url_relevancy_async(urls_with_html) for tdo, result in zip(tdos, results): tdo.relevant = result diff --git a/hugging_face/HuggingFaceInterface.py b/hugging_face/HuggingFaceInterface.py index 87d88caf..9ad11d0b 100644 --- a/hugging_face/HuggingFaceInterface.py +++ b/hugging_face/HuggingFaceInterface.py @@ -1,34 +1,34 @@ -from transformers import pipeline +import asyncio +import json +import sys +from typing import List from collector_db.DTOs.URLWithHTML import URLWithHTML -import gc class HuggingFaceInterface: @staticmethod - def load_relevancy_model() -> pipeline: - return pipeline("text-classification", model="PDAP/url-relevance") - - def get_url_relevancy( - self, - urls_with_html: list[URLWithHTML], - threshold: float = 0.5 - ) -> list[bool]: - urls = [url_with_html.url for url_with_html in urls_with_html] - relevance_pipe = self.load_relevancy_model() - results: list[dict] = relevance_pipe(urls) - - bool_results = [] - for result in results: - score = result["score"] - if score >= threshold: - bool_results.append(True) - else: - bool_results.append(False) - del relevance_pipe - gc.collect() - return bool_results + async def get_url_relevancy_async(urls_with_html: List[URLWithHTML]) -> List[bool]: + urls = [u.url for u in urls_with_html] + input_data = json.dumps(urls) + proc = await asyncio.create_subprocess_exec( + sys.executable, "hugging_face/relevancy_worker.py", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate(input=input_data.encode("utf-8")) + raw_output = stdout.decode("utf-8").strip() + + # Try to extract the actual JSON line + for line in raw_output.splitlines(): + try: + return json.loads(line) + except json.JSONDecodeError: + continue + + raise RuntimeError(f"Could not parse JSON from subprocess: {raw_output}") diff --git a/hugging_face/relevancy_worker.py b/hugging_face/relevancy_worker.py new file mode 100644 index 00000000..5d07d10f --- /dev/null +++ b/hugging_face/relevancy_worker.py @@ -0,0 +1,15 @@ +import sys +import json +from transformers import pipeline + +def main(): + urls = json.loads(sys.stdin.read()) + + pipe = pipeline("text-classification", model="PDAP/url-relevance") + results = pipe(urls) + bools = [r["score"] >= 0.5 for r in results] + + print(json.dumps(bools)) + +if __name__ == "__main__": + main() diff --git a/tests/manual/huggingface/test_hugging_face_interface.py b/tests/manual/huggingface/test_hugging_face_interface.py index b1b86350..08ce8ccd 100644 --- a/tests/manual/huggingface/test_hugging_face_interface.py +++ b/tests/manual/huggingface/test_hugging_face_interface.py @@ -1,13 +1,15 @@ +import pytest + from collector_db.DTOs.URLWithHTML import URLWithHTML from hugging_face.HuggingFaceInterface import HuggingFaceInterface - -def test_get_url_relevancy(): +@pytest.mark.asyncio +async def test_get_url_relevancy(): hfi = HuggingFaceInterface() def package_url(url: str) -> URLWithHTML: return URLWithHTML(url=url, url_id=1, html_infos=[]) - result = hfi.get_url_relevancy([ + result = await hfi.get_url_relevancy_async([ package_url("https://coloradosprings.gov/police-department/article/news/i-25-traffic-safety-deployment-after-stop"), package_url("https://example.com"), package_url("https://police.com") From 0c902809910408c7835e4d602668bbade7331b41 Mon Sep 17 00:00:00 2001 From: maxachis Date: Fri, 18 Apr 2025 10:41:38 -0400 Subject: [PATCH 2/2] fix(tests): fix broken test --- .../tasks/test_url_relevancy_huggingface_task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py b/tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py index abe15965..95fb5fc7 100644 --- a/tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py +++ b/tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py @@ -21,7 +21,7 @@ def num_to_bool(num: int) -> bool: else: return False - def mock_get_url_relevancy( + async def mock_get_url_relevancy( urls_with_html: list[URLWithHTML], threshold: float = 0.8 ) -> list[bool]: @@ -33,7 +33,7 @@ def mock_get_url_relevancy( return results mock_hf_interface = MagicMock(spec=HuggingFaceInterface) - mock_hf_interface.get_url_relevancy = mock_get_url_relevancy + mock_hf_interface.get_url_relevancy_async = mock_get_url_relevancy task_operator = URLRelevanceHuggingfaceTaskOperator( adb_client=AsyncDatabaseClient(), @@ -50,7 +50,7 @@ def mock_get_url_relevancy( await db_data_creator.html_data(url_ids) run_info: TaskOperatorRunInfo = await task_operator.run_task(1) - assert run_info.outcome == TaskOperatorOutcome.SUCCESS + assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message results = await db_data_creator.adb_client.get_all(AutoRelevantSuggestion)