Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ repos:
nc_py_api/|
benchmarks/|
examples/|
tests/
tests/|
tests_unit/
)

- repo: https://github.com/psf/black
Expand All @@ -32,7 +33,8 @@ repos:
nc_py_api/|
benchmarks/|
examples/|
tests/
tests/|
tests_unit/
)

- repo: https://github.com/tox-dev/pyproject-fmt
Expand Down
90 changes: 53 additions & 37 deletions nc_py_api/ex_app/integration_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
status,
)
from fastapi.responses import JSONResponse, PlainTextResponse
from filelock import FileLock
from filelock import Timeout as FileLockTimeout
from starlette.requests import HTTPConnection, Request
from starlette.types import ASGIApp, Receive, Scope, Send

Expand Down Expand Up @@ -158,7 +160,7 @@ def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_s
"""Use for cases when you want to define custom `/init` but still need to easy download models.

:param nc: NextcloudApp instance.
:param models_to_fetch: Dictionary describing which models should be downloaded of the form:
:param models: Dictionary describing which models should be downloaded of the form:
.. code-block:: python
{
"model_url_1": {
Expand Down Expand Up @@ -205,42 +207,56 @@ def __fetch_model_as_file(
current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
) -> str:
result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
with niquests.get(model_path, stream=True) as response:
if not response.ok:
raise ModelFetchError(
f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
)
downloaded_size = 0
linked_etag = ""
for each_history in response.history:
linked_etag = each_history.headers.get("X-Linked-ETag", "")
if linked_etag:
break
if not linked_etag:
linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
total_size = int(response.headers.get("Content-Length"))
try:
existing_size = os.path.getsize(result_path)
except OSError:
existing_size = 0
if linked_etag and total_size == existing_size:
with builtins.open(result_path, "rb") as file:
sha256_hash = hashlib.sha256()
for byte_block in iter(lambda: file.read(4096), b""):
sha256_hash.update(byte_block)
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
nc.set_init_status(min(current_progress + progress_for_task, 99))
return result_path

with builtins.open(result_path, "wb") as file:
last_progress = current_progress
for chunk in response.iter_raw(-1):
downloaded_size += file.write(chunk)
if total_size:
new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
if new_progress != last_progress:
nc.set_init_status(new_progress)
last_progress = new_progress
tmp_path = result_path + ".tmp"
try:
with FileLock(result_path + ".lock", timeout=7200), niquests.get(model_path, stream=True) as response:
if not response.ok:
raise ModelFetchError(
f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
)
downloaded_size = 0
linked_etag = ""
for redirect_resp in response.history:
linked_etag = redirect_resp.headers.get("X-Linked-ETag", "")
if linked_etag:
break
if not linked_etag:
linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
total_size = int(response.headers.get("Content-Length", 0))
try:
existing_size = os.path.getsize(result_path)
except OSError:
existing_size = 0
if linked_etag and total_size and total_size == existing_size:
with builtins.open(result_path, "rb") as file:
sha256_hash = hashlib.sha256()
for byte_block in iter(lambda: file.read(4096), b""):
sha256_hash.update(byte_block)
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
nc.set_init_status(min(current_progress + progress_for_task, 99))
return result_path

try:
with builtins.open(tmp_path, "wb") as file:
last_progress = current_progress
for chunk in response.iter_raw(-1):
downloaded_size += file.write(chunk)
if total_size:
new_progress = min(
current_progress + int(progress_for_task * downloaded_size / total_size), 99
)
if new_progress != last_progress:
nc.set_init_status(new_progress)
last_progress = new_progress
os.replace(tmp_path, result_path)
except BaseException:
if os.path.exists(tmp_path):
os.remove(tmp_path)
raise
except FileLockTimeout as exc:
raise ModelFetchError(
f"Timed out waiting for lock on '{result_path}' after 7200s — another process may be stuck downloading"
) from exc

return result_path

Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dynamic = [
]
dependencies = [
"fastapi>=0.109.2",
"filelock>=3.20.3,<4",
"niquests>=3,<4",
"pydantic>=2.1.1",
"python-dotenv>=1",
Expand Down Expand Up @@ -148,6 +149,12 @@ lint.extend-per-file-ignores."tests/**/*.py" = [
"S",
"UP",
]
lint.extend-per-file-ignores."tests_unit/**/*.py" = [
"D",
"E402",
"S",
"UP",
]
lint.mccabe.max-complexity = 16

[tool.isort]
Expand Down Expand Up @@ -198,6 +205,7 @@ messages_control.disable = [
minversion = "6.0"
testpaths = [
"tests",
"tests_unit",
]
filterwarnings = [
"ignore::DeprecationWarning",
Expand Down
214 changes: 214 additions & 0 deletions tests_unit/test_fetch_model_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
"""Tests for model file download with FileLock and atomic rename."""

import hashlib
import os
from threading import Thread
from unittest import mock

import pytest
from filelock import FileLock
from filelock import Timeout as FileLockTimeout

from nc_py_api._exceptions import ModelFetchError
from nc_py_api.ex_app.integration_fastapi import fetch_models_task


class FakeResponse:
    """Stand-in for a streaming ``niquests`` response object.

    Exposes only the surface that the download code touches:
    ``ok``/``status_code``/``text``, ``headers`` (``Content-Length`` and
    ``ETag``), ``history`` (redirect chain), ``iter_raw`` for streaming,
    and context-manager entry/exit.
    """

    def __init__(self, content: bytes, etag: str = "", status_code: int = 200, ok: bool = True):
        self.content = content
        self.status_code = status_code
        self.ok = ok
        # A failing response carries a body that gets interpolated into the error message.
        if ok:
            self.text = ""
        else:
            self.text = "Not Found"
        # No redirects by default; tests that need them can append responses here.
        self.history = []
        # When no explicit ETag is given, advertise the content's own sha256 —
        # this matches the "file already downloaded" detection in the code under test.
        if not etag:
            etag = f'"{hashlib.sha256(content).hexdigest()}"'
        self.headers = {
            "Content-Length": str(len(content)),
            "ETag": etag,
        }

    def iter_raw(self, _chunk_size):
        # Deliver the entire payload as a single chunk.
        yield self.content

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass


def _mock_nc():
nc = mock.MagicMock()
nc.set_init_status = mock.MagicMock()
return nc


class TestFetchModelAsFile:
    """Tests for __fetch_model_as_file via fetch_models_task."""

    def test_download_creates_file(self, tmp_path):
        # Happy path: the streamed payload must land at save_path verbatim.
        content = b"model-data-abc"
        save_path = str(tmp_path / "model.bin")

        with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(content)):
            fetch_models_task(_mock_nc(), {"https://example.com/model.bin": {"save_path": save_path}}, 0)

        assert os.path.isfile(save_path)
        with open(save_path, "rb") as f:
            assert f.read() == content

    def test_no_tmp_file_remains_after_success(self, tmp_path):
        # The ".tmp" staging file must be renamed away (os.replace) on success.
        save_path = str(tmp_path / "model.bin")

        with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(b"data")):
            fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)

        assert not os.path.exists(save_path + ".tmp")

    def test_lock_file_released_after_download(self, tmp_path):
        # After a completed download the FileLock must not still be held.
        save_path = str(tmp_path / "model.bin")

        with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(b"data")):
            fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)

        lock_path = save_path + ".lock"
        # Lock file may or may not exist (FileLock implementation detail),
        # but it must not be held — acquiring it should succeed immediately.
        lock = FileLock(lock_path, timeout=1)
        lock.acquire()
        lock.release()

    def test_skips_download_when_file_matches_etag(self, tmp_path):
        # Pre-seed save_path with content whose sha256 matches the server ETag:
        # the body stream (iter_raw) must never be consumed in that case.
        content = b"existing-model-data"
        sha = hashlib.sha256(content).hexdigest()
        etag = f'"{sha}"'
        save_path = str(tmp_path / "model.bin")
        with open(save_path, "wb") as f:
            f.write(content)

        call_count = {"iter_raw": 0}
        original_iter_raw = FakeResponse.iter_raw

        # Wrap iter_raw so we can count whether the body was ever streamed.
        def tracking_iter_raw(self, chunk_size):
            call_count["iter_raw"] += 1
            yield from original_iter_raw(self, chunk_size)

        resp = FakeResponse(content, etag=etag)
        resp.iter_raw = lambda cs: tracking_iter_raw(resp, cs)

        with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp):
            fetch_models_task(_mock_nc(), {"https://example.com/model.bin": {"save_path": save_path}}, 0)

        # Zero iter_raw calls proves the early-return (cache hit) path was taken.
        assert call_count["iter_raw"] == 0
        with open(save_path, "rb") as f:
            assert f.read() == content

    def test_tmp_file_cleaned_up_on_download_error(self, tmp_path):
        # A mid-stream network failure must leave neither the ".tmp" staging
        # file nor a partial final file behind.
        save_path = str(tmp_path / "model.bin")

        def failing_iter_raw(_chunk_size):
            yield b"partial"
            raise ConnectionError("network down")

        resp = FakeResponse(b"full-content")
        resp.iter_raw = failing_iter_raw

        with (
            mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp),
            pytest.raises(ModelFetchError),
        ):
            fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)

        assert not os.path.exists(save_path + ".tmp")
        assert not os.path.exists(save_path)

    def test_original_file_untouched_on_download_error(self, tmp_path):
        # Atomic-replace guarantee: a failed re-download must not clobber an
        # existing good file (different ETag forces the download path).
        save_path = str(tmp_path / "model.bin")
        with open(save_path, "wb") as f:
            f.write(b"original-good-data")

        def failing_iter_raw(_chunk_size):
            yield b"partial"
            raise ConnectionError("network down")

        resp = FakeResponse(b"new-content", etag='"different-etag"')
        resp.iter_raw = failing_iter_raw

        with (
            mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp),
            pytest.raises(ModelFetchError),
        ):
            fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)

        with open(save_path, "rb") as f:
            assert f.read() == b"original-good-data"

    def test_http_error_raises_model_fetch_error(self, tmp_path):
        # A non-2xx response (response.ok False) must surface as ModelFetchError.
        save_path = str(tmp_path / "model.bin")
        resp = FakeResponse(b"", status_code=404, ok=False)

        with (
            mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp),
            pytest.raises(ModelFetchError),
        ):
            fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)

    def test_concurrent_downloads_do_not_corrupt(self, tmp_path):
        # Two threads racing on the same save_path: FileLock serializes them,
        # and atomic rename guarantees the file is never an interleaved mix.
        save_path = str(tmp_path / "model.bin")
        errors = []

        def download():
            try:
                fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
            except Exception as e: # noqa pylint: disable=broad-exception-caught
                errors.append(e)

        # Patch once around both threads to avoid mock.patch context manager
        # race: independent per-thread patches can restore the original
        # function while the other thread still needs the mock.
        responses = iter([FakeResponse(b"A" * 10000), FakeResponse(b"B" * 10000)])

        def mock_get(_url, **_kwargs):
            return next(responses)

        with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", side_effect=mock_get):
            t1 = Thread(target=download)
            t2 = Thread(target=download)
            t1.start()
            t2.start()
            t1.join(timeout=60)
            t2.join(timeout=60)

        assert not errors, f"Threads raised errors: {errors}"
        assert os.path.isfile(save_path)
        with open(save_path, "rb") as f:
            data = f.read()
        # File must be entirely one content or the other — never mixed
        assert data in (b"A" * 10000, b"B" * 10000)

    def test_filelock_timeout_raises_model_fetch_error(self, tmp_path):
        # Patching the FileLock constructor itself to raise Timeout simulates
        # another process holding the lock past the 7200s budget.
        save_path = str(tmp_path / "model.bin")
        lock = FileLock(save_path + ".lock")
        nc = _mock_nc()

        with (
            mock.patch("nc_py_api.ex_app.integration_fastapi.FileLock", side_effect=FileLockTimeout(lock)),
            pytest.raises(ModelFetchError),
        ):
            fetch_models_task(nc, {"https://example.com/m.bin": {"save_path": save_path}}, 0)

        # NOTE(review): assumes fetch_models_task reports the failure via
        # set_init_status(<progress>, <message>) before/while re-raising,
        # with the error text as the second positional arg — verify against
        # the fetch_models_task implementation.
        status_msg = nc.set_init_status.call_args_list[-1][0][1]
        assert "Timed out waiting for lock" in status_msg

    def test_progress_updates_sent(self, tmp_path):
        # Completion must always be reported, even for tiny downloads that
        # finish in a single chunk.
        save_path = str(tmp_path / "model.bin")
        nc = _mock_nc()

        with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(b"data")):
            fetch_models_task(nc, {"https://example.com/m.bin": {"save_path": save_path}}, 0)

        # set_init_status should be called at least for completion (100)
        assert nc.set_init_status.called
        # Last call should be 100 (completion)
        assert nc.set_init_status.call_args_list[-1] == mock.call(100)