From 2dc495c158b21385c6bef3cfebaa9298f4614582 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 12 Dec 2025 11:13:29 +0100 Subject: [PATCH 1/5] Improve test runtimes by decreasing initial wait times for jobs and projections. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre --- src/graphdatascience/arrow_client/v2/job_client.py | 8 +++----- .../query_runner/protocol/project_protocols.py | 6 +++--- .../query_runner/protocol/write_protocols.py | 6 +++--- src/graphdatascience/retry_utils/retry_utils.py | 14 +++++++++++++- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/graphdatascience/arrow_client/v2/job_client.py b/src/graphdatascience/arrow_client/v2/job_client.py index 13ffd738e..c19ae5142 100644 --- a/src/graphdatascience/arrow_client/v2/job_client.py +++ b/src/graphdatascience/arrow_client/v2/job_client.py @@ -3,13 +3,14 @@ from pandas import ArrowDtype, DataFrame from pyarrow._flight import Ticket -from tenacity import Retrying, retry_if_result, wait_exponential +from tenacity import Retrying, retry_if_result from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar from graphdatascience.query_runner.termination_flag import TerminationFlag +from graphdatascience.retry_utils.retry_utils import job_wait_strategy JOB_STATUS_ENDPOINT = "v2/jobs.status" RESULTS_SUMMARY_ENDPOINT = "v2/results.summary" @@ -50,12 +51,9 @@ def check_expected_status(status: JobStatus) -> bool: if termination_flag is None: termination_flag = TerminationFlag.create() - for attempt in Retrying( - retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5), reraise=True - ): + for attempt 
in Retrying(retry=retry_if_result(lambda _: True), wait=job_wait_strategy(), reraise=True): with attempt: termination_flag.assert_running() - job_status = self.get_job_status(client, job_id) if check_expected_status(job_status) or job_status.aborted(): diff --git a/src/graphdatascience/query_runner/protocol/project_protocols.py b/src/graphdatascience/query_runner/protocol/project_protocols.py index 84a91e2a9..af5c2f81c 100644 --- a/src/graphdatascience/query_runner/protocol/project_protocols.py +++ b/src/graphdatascience/query_runner/protocol/project_protocols.py @@ -3,13 +3,13 @@ from typing import Any from pandas import DataFrame -from tenacity import retry, retry_if_result, wait_incrementing +from tenacity import retry, retry_if_result from graphdatascience.call_parameters import CallParameters from graphdatascience.query_runner.protocol.status import Status from graphdatascience.query_runner.query_runner import QueryRunner from graphdatascience.query_runner.termination_flag import TerminationFlag -from graphdatascience.retry_utils.retry_utils import before_log +from graphdatascience.retry_utils.retry_utils import before_log, job_wait_strategy from graphdatascience.session.dbms.protocol_version import ProtocolVersion @@ -159,7 +159,7 @@ def is_not_done(result: DataFrame) -> bool: reraise=True, before=before_log(f"Projection (graph: `{params['graph_name']}`)", logger, DEBUG), retry=retry_if_result(is_not_done), - wait=wait_incrementing(start=0.2, increment=0.2, max=2), + wait=job_wait_strategy(), ) def project_fn() -> DataFrame: termination_flag.assert_running() diff --git a/src/graphdatascience/query_runner/protocol/write_protocols.py b/src/graphdatascience/query_runner/protocol/write_protocols.py index a1fd6e534..61c1d406f 100644 --- a/src/graphdatascience/query_runner/protocol/write_protocols.py +++ b/src/graphdatascience/query_runner/protocol/write_protocols.py @@ -3,7 +3,7 @@ from typing import Any from pandas import DataFrame -from tenacity import retry, 
retry_if_result, wait_incrementing +from tenacity import retry, retry_if_result from graphdatascience.call_parameters import CallParameters from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar @@ -11,7 +11,7 @@ from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.query_runner.query_runner import QueryRunner from graphdatascience.query_runner.termination_flag import TerminationFlag -from graphdatascience.retry_utils.retry_utils import before_log +from graphdatascience.retry_utils.retry_utils import before_log, job_wait_strategy from graphdatascience.session.dbms.protocol_version import ProtocolVersion @@ -157,7 +157,7 @@ def is_not_completed(result: DataFrame) -> bool: @retry( reraise=True, retry=retry_if_result(is_not_completed), - wait=wait_incrementing(start=0.2, increment=0.2, max=2), + wait=job_wait_strategy(), before=before_log( f"Write-Back (graph: `{parameters['graphName']}`, jobId: `{parameters['jobId']}`)", logger, diff --git a/src/graphdatascience/retry_utils/retry_utils.py b/src/graphdatascience/retry_utils/retry_utils.py index e05d407cf..f47084a3d 100644 --- a/src/graphdatascience/retry_utils/retry_utils.py +++ b/src/graphdatascience/retry_utils/retry_utils.py @@ -1,7 +1,8 @@ import logging import typing -from tenacity import RetryCallState +import tenacity.wait +from tenacity import RetryCallState, wait_chain, wait_fixed def before_log( @@ -18,3 +19,14 @@ def log_it(retry_state: RetryCallState) -> None: ) return log_it + + +def job_wait_strategy() -> tenacity.wait.wait_base: + # Wait for 0.02 s in the very beginning (to speed up tests) + # Wait for 0.1 s in the first 10 seconds + # Then increase exponentially to a max of 5 seconds + return wait_chain( + *[wait_fixed(0.02)] + + [wait_fixed(0.1) for j in range(100)] + + [wait_fixed(1), wait_fixed(2), wait_fixed(4), wait_fixed(5)] + ) From 884ad7ad9db534f51df0487202e05ebed81723b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= 
Date: Fri, 12 Dec 2025 12:07:24 +0100 Subject: [PATCH 2/5] Fix unit test --- tests/integrationV2/procedure_surface/conftest.py | 2 +- tests/unit/arrow_client/V2/test_job_client.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integrationV2/procedure_surface/conftest.py b/tests/integrationV2/procedure_surface/conftest.py index b8abee42e..e0418e589 100644 --- a/tests/integrationV2/procedure_surface/conftest.py +++ b/tests/integrationV2/procedure_surface/conftest.py @@ -81,7 +81,7 @@ def start_database(logs_dir: Path, network: Network) -> Generator[DbmsConnection if neo4j_image is None: raise ValueError("NEO4J_DATABASE_IMAGE environment variable is not set") db_logs_dir = logs_dir / "arrow_surface" / "db_logs" - db_logs_dir.mkdir(parents=True) + db_logs_dir.mkdir(parents=True, exist_ok=True) db_logs_dir.chmod(0o777) db_container = ( DockerContainer(image=neo4j_image) diff --git a/tests/unit/arrow_client/V2/test_job_client.py b/tests/unit/arrow_client/V2/test_job_client.py index b5c9ace35..57affa19e 100644 --- a/tests/unit/arrow_client/V2/test_job_client.py +++ b/tests/unit/arrow_client/V2/test_job_client.py @@ -220,8 +220,8 @@ def test_wait_for_job_progress_bar_qualitative(mocker: MockerFixture) -> None: progress_output = pbarOutputStream.getvalue().split("\r") assert "Algo [elapsed: 00:00 ]" in progress_output - assert "Algo [elapsed: 00:01 , status: RUNNING, task: Halfway there]" in progress_output - assert any("Algo [elapsed: 00:03 , status: FINISHED]" in line for line in progress_output) + assert "Algo [elapsed: 00:00 , status: RUNNING, task: Halfway there]" in progress_output + assert any("Algo [elapsed: 00:00 , status: FINISHED]" in line for line in progress_output) def test_get_summary(mocker: MockerFixture) -> None: From b72c02f878d93ab9871330149c4ceebdcd3f6384 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Fri, 12 Dec 2025 13:28:40 +0100 Subject: [PATCH 3/5] Speedup unittests avoid time.sleep and use no 
wait in unit_test --- .../arrow_client/V1/test_gds_arrow_client.py | 6 +++- tests/unit/arrow_client/conftest.py | 17 +++++++++++ .../test_authenticated_flight_client.py | 24 ++++----------- .../progress/test_query_progress_logger.py | 17 +++++++++-- tests/unit/test_gds_arrow_client.py | 30 ++++++++++++++++--- 5 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 tests/unit/arrow_client/conftest.py diff --git a/tests/unit/arrow_client/V1/test_gds_arrow_client.py b/tests/unit/arrow_client/V1/test_gds_arrow_client.py index ddab3819c..8d074bf93 100644 --- a/tests/unit/arrow_client/V1/test_gds_arrow_client.py +++ b/tests/unit/arrow_client/V1/test_gds_arrow_client.py @@ -19,6 +19,7 @@ from graphdatascience.procedure_surface.arrow.error_handler import handle_flight_error from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.retry_utils.retry_config import RetryConfigV2 ActionParam: TypeAlias = str | tuple[str, Any] | Action @@ -132,10 +133,13 @@ def gds_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, N @pytest.fixture() -def flaky_gds_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]: +def flaky_gds_client( + flaky_flight_server: FlakyFlightServer, retry_config_v2: RetryConfigV2 +) -> Generator[GdsArrowClient, None, None]: with AuthenticatedArrowClient.create( ArrowInfo(f"localhost:{flaky_flight_server.port}", True, True, ["v1"]), UsernamePasswordAuthentication("user", "password"), + retry_config=retry_config_v2, ) as arrow_client: yield GdsArrowClient(arrow_client) diff --git a/tests/unit/arrow_client/conftest.py b/tests/unit/arrow_client/conftest.py new file mode 100644 index 000000000..084335ccc --- /dev/null +++ b/tests/unit/arrow_client/conftest.py @@ -0,0 +1,17 @@ +import pytest +from pyarrow._flight import FlightInternalError, FlightTimedOutError, 
FlightUnavailableError + +from graphdatascience.retry_utils.retry_config import RetryConfigV2, StopConfig + + +@pytest.fixture +def retry_config_v2() -> RetryConfigV2: + return RetryConfigV2( + retryable_exceptions=[ + FlightTimedOutError, + FlightUnavailableError, + FlightInternalError, + ], + stop_config=StopConfig(after_delay=10, after_attempt=5), + wait_config=None, # No wait for tests. Makes them faster + ) diff --git a/tests/unit/arrow_client/test_authenticated_flight_client.py b/tests/unit/arrow_client/test_authenticated_flight_client.py index 6d51be6e4..6628fc8f8 100644 --- a/tests/unit/arrow_client/test_authenticated_flight_client.py +++ b/tests/unit/arrow_client/test_authenticated_flight_client.py @@ -1,10 +1,9 @@ import pytest -from pyarrow._flight import FlightInternalError, FlightTimedOutError, FlightUnavailableError from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo -from graphdatascience.retry_utils.retry_config import ExponentialWaitConfig, RetryConfigV2, StopConfig +from graphdatascience.retry_utils.retry_config import ExponentialWaitConfig, RetryConfigV2 @pytest.fixture @@ -12,19 +11,6 @@ def arrow_info() -> ArrowInfo: return ArrowInfo(listenAddress="localhost:8491", enabled=True, running=True, versions=["1.0.0"]) -@pytest.fixture -def retry_config() -> RetryConfigV2: - return RetryConfigV2( - retryable_exceptions=[ - FlightTimedOutError, - FlightUnavailableError, - FlightInternalError, - ], - stop_config=StopConfig(after_delay=10, after_attempt=5), - wait_config=ExponentialWaitConfig(multiplier=1, min=1, max=10), - ) - - @pytest.fixture def mock_auth() -> ArrowAuthentication: class MockAuthentication(ArrowAuthentication): @@ -42,14 +28,14 @@ def test_create_authenticated_arrow_client(arrow_info: ArrowInfo, mock_auth: Arr assert 
client.connection_info() == ConnectionInfo("localhost", 8491, encrypted=True) -def test_connection_info(arrow_info: ArrowInfo, retry_config: RetryConfigV2) -> None: - client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config) +def test_connection_info(arrow_info: ArrowInfo, retry_config_v2: RetryConfigV2) -> None: + client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config_v2) connection_info = client.connection_info() assert connection_info == ConnectionInfo("localhost", 8491, encrypted=False) -def test_pickle_roundtrip(arrow_info: ArrowInfo, retry_config: RetryConfigV2) -> None: - client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config) +def test_pickle_roundtrip(arrow_info: ArrowInfo, retry_config_v2: RetryConfigV2) -> None: + client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config_v2) import pickle pickled_client = pickle.dumps(client) diff --git a/tests/unit/query_runner/progress/test_query_progress_logger.py b/tests/unit/query_runner/progress/test_query_progress_logger.py index 7961ae80d..924459849 100644 --- a/tests/unit/query_runner/progress/test_query_progress_logger.py +++ b/tests/unit/query_runner/progress/test_query_progress_logger.py @@ -1,4 +1,5 @@ import re +import threading import time from concurrent.futures import ThreadPoolExecutor from io import StringIO @@ -14,19 +15,27 @@ def test_call_through_functions() -> None: + progress_fetched_event = threading.Event() + progress_called = [] + def fake_run_cypher(query: str, database: str | None = None) -> DataFrame: + progress_called.append(time.time()) + assert "CALL gds.listProgress('foo')" in query assert database == "database" + progress_fetched_event.set() + return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}]) def fake_query() -> DataFrame: - time.sleep(1) + progress_fetched_event.wait(5) return DataFrame([{"result": 42}]) qpl = 
QueryProgressLogger(fake_run_cypher, lambda: ServerVersion(3, 0, 0)) df = qpl.run_with_progress_logging(fake_query, "foo", "database") + assert len(progress_called) > 0 assert df["result"][0] == 42 @@ -45,14 +54,18 @@ def fake_query() -> DataFrame: def test_uses_beta_endpoint() -> None: + progress_fetched_event = threading.Event() + def fake_run_cypher(query: str, database: str | None = None) -> DataFrame: assert "CALL gds.beta.listProgress('foo')" in query assert database == "database" + progress_fetched_event.set() + return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}]) def fake_query() -> DataFrame: - time.sleep(1) + progress_fetched_event.wait(5) return DataFrame([{"result": 42}]) qpl = QueryProgressLogger(fake_run_cypher, lambda: ServerVersion(2, 4, 0)) diff --git a/tests/unit/test_gds_arrow_client.py b/tests/unit/test_gds_arrow_client.py index 2807f2728..56f7a565d 100644 --- a/tests/unit/test_gds_arrow_client.py +++ b/tests/unit/test_gds_arrow_client.py @@ -7,16 +7,19 @@ from pyarrow._flight import GeneratorStream from pyarrow.flight import ( Action, + FlightInternalError, FlightServerBase, FlightServerError, FlightTimedOutError, FlightUnavailableError, Ticket, ) +from tenacity import retry_any, retry_if_exception_type, stop_after_attempt, wait_none from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.arrow_info import ArrowInfo from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient +from graphdatascience.retry_utils.retry_config import RetryConfig ActionParam: TypeAlias = str | tuple[str, Any] | Action @@ -121,14 +124,33 @@ def flaky_flight_server() -> Generator[None, FlakyFlightServer, None]: @pytest.fixture() -def flight_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, None]: - with GdsArrowClient.create(ArrowInfo(f"localhost:{flight_server.port}", True, True, ["v1"])) as client: +def 
retry_config() -> RetryConfig: + return RetryConfig( + retry=retry_any( + retry_if_exception_type(FlightTimedOutError), + retry_if_exception_type(FlightUnavailableError), + retry_if_exception_type(FlightInternalError), + ), + stop=stop_after_attempt(5), + wait=wait_none(), # make test go fast + ) + + +@pytest.fixture() +def flight_client(flight_server: FlightServer, retry_config: RetryConfig) -> Generator[GdsArrowClient, None, None]: + with GdsArrowClient.create( + ArrowInfo(f"localhost:{flight_server.port}", True, True, ["v1"]), retry_config=retry_config + ) as client: yield client @pytest.fixture() -def flaky_flight_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]: - with GdsArrowClient.create(ArrowInfo(f"localhost:{flaky_flight_server.port}", True, True, ["v1"])) as client: +def flaky_flight_client( + flaky_flight_server: FlakyFlightServer, retry_config: RetryConfig +) -> Generator[GdsArrowClient, None, None]: + with GdsArrowClient.create( + ArrowInfo(f"localhost:{flaky_flight_server.port}", True, True, ["v1"]), retry_config=retry_config + ) as client: yield client From 51a7d9d7eeef44b15337775c20b2ff52a7f1bd48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Fri, 12 Dec 2025 13:44:44 +0100 Subject: [PATCH 4/5] Update session startup timeout it just barely didn't make it in CI --- justfile | 4 ++++ tests/integrationV2/conftest.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/justfile b/justfile index 1b0b705c6..43fb63f9e 100644 --- a/justfile +++ b/justfile @@ -27,3 +27,7 @@ it filter="" enterprise="true" extra_options="": # such as `just it-v2 wcc` it-v2 filter="" extra_options="": pytest tests/integrationV2 --include-integration-v2 --basetemp=tmp/ {{extra_options}} {{ if filter != "" { "-k '" + filter + "'" } else { "" } }} + + +update-session: + docker pull europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest diff --git a/tests/integrationV2/conftest.py
b/tests/integrationV2/conftest.py index 096a50896..55253c869 100644 --- a/tests/integrationV2/conftest.py +++ b/tests/integrationV2/conftest.py @@ -99,7 +99,7 @@ def start_session( session_container = session_container.with_network(network).with_network_aliases("gds-session") with session_container as session_container: try: - wait_for_logs(session_container, "Running GDS tasks: 0", timeout=20) + wait_for_logs(session_container, "Running GDS tasks: 0", timeout=30) yield GdsSessionConnectionInfo( host=session_container.get_container_host_ip(), arrow_port=session_container.get_exposed_port(8491), From 89f86ac9c20d5e68c80fac02aaaf38d14d8ac1b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Mon, 15 Dec 2025 11:54:55 +0100 Subject: [PATCH 5/5] Fix cloud-test conftest setup: a recent change in auth tuple construction made this fail --- justfile | 20 +++++++++++++ scripts/test_envs/gds_session/compose.yml | 35 +++++++++++++++++++++++ scripts/test_envs/gds_session/password | 1 + tests/integration/conftest.py | 2 +- 4 files changed, 57 insertions(+), 1 deletion(-) create mode 100755 scripts/test_envs/gds_session/compose.yml create mode 100644 scripts/test_envs/gds_session/password diff --git a/justfile b/justfile index 43fb63f9e..6b2b0d83d 100644 --- a/justfile +++ b/justfile @@ -29,5 +29,25 @@ it-v2 filter="" extra_options="": pytest tests/integrationV2 --include-integration-v2 --basetemp=tmp/ {{extra_options}} {{ if filter != "" { "-k '" + filter + "'" } else { "" } }} +# runs the v1 integration tests against a local GDS session and Neo4j started via docker compose +session-v1-it-tests: + #!/usr/bin/env bash + set -e + ENV_DIR="scripts/test_envs/gds_session" + trap "cd $ENV_DIR && docker compose down" EXIT + cd $ENV_DIR && docker compose up -d + cd - + NEO4J_URI=bolt://localhost:7688 \ + NEO4J_USER=neo4j \ + NEO4J_PASSWORD=password \ + NEO4J_DB=neo4j \ + NEO4J_AURA_DB_URI=bolt://localhost:7687 \ + pytest tests --include-cloud-architecture + + + + + update-session: docker pull
europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest + diff --git a/scripts/test_envs/gds_session/compose.yml b/scripts/test_envs/gds_session/compose.yml new file mode 100755 index 000000000..42d522ea2 --- /dev/null +++ b/scripts/test_envs/gds_session/compose.yml @@ -0,0 +1,35 @@ +services: + gds-session: + image: europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest # build locally if you don't want to use a released one + volumes: + - ./password:/passwords/password + environment: + - SESSION_ID=42 + - MODEL_STORAGE_BASE_LOCATION=/models + - ALLOW_LIST=DEFAULT + - DNS_NAME=gds-session + - PAGE_CACHE_SIZE=100M + ports: + - "7688:7687" + - "8080:8080" + - "8491:8491" + + + neo4j: + image: neo4j:enterprise + volumes: + - ${HOME}/.gds_license:/licenses/.gds_license + environment: + - NEO4J_AUTH=none # for testing + + - NEO4J_ACCEPT_LICENSE_AGREEMENT=yes + + - NEO4J_server_metrics_prometheus_enabled=true + - NEO4J_server_metrics_prometheus_endpoint=0.0.0.0:2004 + - NEO4J_server_metrics_filter=* + - NEO4J_server_metrics_enabled=true + ports: + - "7474:7474" + - "7687:7687" + - "2004:2004" + restart: always diff --git a/scripts/test_envs/gds_session/password b/scripts/test_envs/gds_session/password new file mode 100644 index 000000000..7aa311adf --- /dev/null +++ b/scripts/test_envs/gds_session/password @@ -0,0 +1 @@ +password \ No newline at end of file diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 143d457fb..ebccfadf1 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -20,7 +20,7 @@ if neo4j_user := os.environ.get("NEO4J_USER", os.environ.get("NEO4J_USERNAME", "neo4j")): AUTH = ( neo4j_user, - os.environ.get("NEO4J_PASSWORD", "neo4j"), + os.environ.get("NEO4J_PASSWORD", "password"), ) DB = os.environ.get("NEO4J_DB", "neo4j")