Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def put_results_into_database(self, tdos):

async def add_huggingface_relevancy(self, tdos: list[URLRelevanceHuggingfaceTDO]):
    """Classify each TDO's URL for relevance and store the verdict on the TDO.

    Delegates to the Hugging Face interface's async entry point (which runs
    the model in a subprocess) so the event loop is not blocked, then writes
    one boolean per TDO onto ``tdo.relevant``.

    :param tdos: items exposing ``url_with_html`` (input) and ``relevant`` (output).
    """
    urls_with_html = [tdo.url_with_html for tdo in tdos]
    # The pasted diff showed both the old synchronous call and the new awaited
    # call; only the awaited post-diff call is kept here.
    results = await self.huggingface_interface.get_url_relevancy_async(urls_with_html)
    # zip pairs each TDO with its classification result; the interface is
    # expected to return exactly one bool per input (assumption — verify).
    for tdo, result in zip(tdos, results):
        tdo.relevant = result

Expand Down
48 changes: 24 additions & 24 deletions hugging_face/HuggingFaceInterface.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
from transformers import pipeline
import asyncio

Check warning on line 1 in hugging_face/HuggingFaceInterface.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] hugging_face/HuggingFaceInterface.py#L1 <100>

Missing docstring in public module
Raw output
./hugging_face/HuggingFaceInterface.py:1:1: D100 Missing docstring in public module
import json
import sys
from typing import List

from collector_db.DTOs.URLWithHTML import URLWithHTML
import gc

class HuggingFaceInterface:

@staticmethod
def load_relevancy_model() -> pipeline:
return pipeline("text-classification", model="PDAP/url-relevance")

def get_url_relevancy(
self,
urls_with_html: list[URLWithHTML],
threshold: float = 0.5
) -> list[bool]:
urls = [url_with_html.url for url_with_html in urls_with_html]
relevance_pipe = self.load_relevancy_model()
results: list[dict] = relevance_pipe(urls)

bool_results = []
for result in results:
score = result["score"]
if score >= threshold:
bool_results.append(True)
else:
bool_results.append(False)
del relevance_pipe
gc.collect()
return bool_results
async def get_url_relevancy_async(self, urls_with_html: List[URLWithHTML]) -> List[bool]:
    """Classify URLs for relevance in a worker subprocess.

    The Hugging Face model is loaded inside a separate Python process
    (``hugging_face/relevancy_worker.py``) so its memory is fully released
    when classification finishes and the event loop is never blocked.

    Fixes applied: ``self`` was missing from the signature even though both
    call sites invoke this as a bound method (which would raise TypeError),
    and a failing worker was previously reported as an unhelpful JSON-parse
    error that discarded stderr.

    :param urls_with_html: items whose ``.url`` attribute is sent to the worker.
    :return: one bool per input URL, per the worker's hard-coded threshold.
    :raises RuntimeError: if the worker exits non-zero or emits no parsable JSON line.
    """
    urls = [u.url for u in urls_with_html]
    input_data = json.dumps(urls)

    proc = await asyncio.create_subprocess_exec(
        sys.executable, "hugging_face/relevancy_worker.py",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    stdout, stderr = await proc.communicate(input=input_data.encode("utf-8"))

    # Surface worker failures with their stderr instead of a bare parse error.
    if proc.returncode != 0:
        raise RuntimeError(
            f"relevancy worker exited with code {proc.returncode}: "
            f"{stderr.decode('utf-8', errors='replace').strip()}"
        )

    raw_output = stdout.decode("utf-8").strip()

    # The worker (or libraries it loads) may print noise before the result;
    # take the first stdout line that parses as JSON.
    for line in raw_output.splitlines():
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            continue

    raise RuntimeError(f"Could not parse JSON from subprocess: {raw_output}")

15 changes: 15 additions & 0 deletions hugging_face/relevancy_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import sys

Check warning on line 1 in hugging_face/relevancy_worker.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] hugging_face/relevancy_worker.py#L1 <100>

Missing docstring in public module
Raw output
./hugging_face/relevancy_worker.py:1:1: D100 Missing docstring in public module
import json
from transformers import pipeline

def main():
    """Read a JSON list of URLs from stdin; print a JSON list of booleans.

    Runs as a subprocess of ``HuggingFaceInterface.get_url_relevancy_async``
    so the transformers model's memory is released when this process exits.
    Each output boolean is True when the URL's relevance score is >= 0.5.

    (Adds the docstring flagged by flake8 D103.)
    """
    urls = json.loads(sys.stdin.read())

    pipe = pipeline("text-classification", model="PDAP/url-relevance")
    results = pipe(urls)
    # NOTE(review): threshold 0.5 is hard-coded here and must stay in sync
    # with any threshold the parent process assumes — confirm.
    bools = [r["score"] >= 0.5 for r in results]

    # stdout is the IPC channel back to the parent process.
    print(json.dumps(bools))


# Two blank lines above per PEP 8 (fixes the flake8 E305 failure).
if __name__ == "__main__":
    main()
8 changes: 5 additions & 3 deletions tests/manual/huggingface/test_hugging_face_interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest

Check warning on line 1 in tests/manual/huggingface/test_hugging_face_interface.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/manual/huggingface/test_hugging_face_interface.py#L1 <100>

Missing docstring in public module
Raw output
./tests/manual/huggingface/test_hugging_face_interface.py:1:1: D100 Missing docstring in public module

from collector_db.DTOs.URLWithHTML import URLWithHTML
from hugging_face.HuggingFaceInterface import HuggingFaceInterface


def test_get_url_relevancy():
@pytest.mark.asyncio
async def test_get_url_relevancy():

Check warning on line 7 in tests/manual/huggingface/test_hugging_face_interface.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/manual/huggingface/test_hugging_face_interface.py#L7 <103>

Missing docstring in public function
Raw output
./tests/manual/huggingface/test_hugging_face_interface.py:7:1: D103 Missing docstring in public function
hfi = HuggingFaceInterface()
def package_url(url: str) -> URLWithHTML:
return URLWithHTML(url=url, url_id=1, html_infos=[])

result = hfi.get_url_relevancy([
result = await hfi.get_url_relevancy_async([
package_url("https://coloradosprings.gov/police-department/article/news/i-25-traffic-safety-deployment-after-stop"),
package_url("https://example.com"),
package_url("https://police.com")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def num_to_bool(num: int) -> bool:
else:
return False

def mock_get_url_relevancy(
async def mock_get_url_relevancy(
urls_with_html: list[URLWithHTML],
threshold: float = 0.8
) -> list[bool]:
Expand All @@ -33,7 +33,7 @@ def mock_get_url_relevancy(
return results

mock_hf_interface = MagicMock(spec=HuggingFaceInterface)
mock_hf_interface.get_url_relevancy = mock_get_url_relevancy
mock_hf_interface.get_url_relevancy_async = mock_get_url_relevancy

task_operator = URLRelevanceHuggingfaceTaskOperator(
adb_client=AsyncDatabaseClient(),
Expand All @@ -50,7 +50,7 @@ def mock_get_url_relevancy(
await db_data_creator.html_data(url_ids)

run_info: TaskOperatorRunInfo = await task_operator.run_task(1)
assert run_info.outcome == TaskOperatorOutcome.SUCCESS
assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message


results = await db_data_creator.adb_client.get_all(AutoRelevantSuggestion)
Expand Down