Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,27 @@ it filter="" enterprise="true" extra_options="":
# such as `just it-v2 wcc`
it-v2 filter="" extra_options="":
pytest tests/integrationV2 --include-integration-v2 --basetemp=tmp/ {{extra_options}} {{ if filter != "" { "-k '" + filter + "'" } else { "" } }}


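# runs the tests with `--include-cloud-architecture` against the docker compose environment in scripts/test_envs/gds_session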
session-v1-it-tests:
#!/usr/bin/env bash
set -e
ENV_DIR="scripts/test_envs/gds_session"
trap "cd $ENV_DIR && docker compose down" EXIT
cd $ENV_DIR && docker compose up -d
cd -
NEO4J_URI=bolt://localhost:7688 \
NEO4J_USER=neo4j \
NEO4J_PASSWORD=password \
NEO4J_DB=neo4j \
NEO4J_AURA_DB_URI=bolt://localhost:7687 \
pytest tests --include-cloud-architecture





update-session:
docker pull europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest

35 changes: 35 additions & 0 deletions scripts/test_envs/gds_session/compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
services:
gds-session:
image: europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest # build locally if you don't want to use a released one
volumes:
- ./password:/passwords/password
environment:
- SESSION_ID=42
- MODEL_STORAGE_BASE_LOCATION=/models
- ALLOW_LIST=DEFAULT
- DNS_NAME=gds-session
- PAGE_CACHE_SIZE=100M
ports:
- "7688:7687"
- "8080:8080"
- "8491:8491"


neo4j:
image: neo4j:enterprise
volumes:
- ${HOME}/.gds_license:/licenses/.gds_license
environment:
- NEO4J_AUTH=none # for testing

- NEO4J_ACCEPT_LICENSE_AGREEMENT=yes

- NEO4J_server_metrics_prometheus_enabled=true
- NEO4J_server_metrics_prometheus_endpoint=0.0.0.0:2004
- NEO4J_server_metrics_filter=*
- NEO4J_server_metrics_enabled=true
ports:
- "7474:7474"
- "7687:7687"
- "2004:2004"
restart: always
1 change: 1 addition & 0 deletions scripts/test_envs/gds_session/password
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
password
8 changes: 3 additions & 5 deletions src/graphdatascience/arrow_client/v2/job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

from pandas import ArrowDtype, DataFrame
from pyarrow._flight import Ticket
from tenacity import Retrying, retry_if_result, wait_exponential
from tenacity import Retrying, retry_if_result

from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
from graphdatascience.query_runner.termination_flag import TerminationFlag
from graphdatascience.retry_utils.retry_utils import job_wait_strategy

JOB_STATUS_ENDPOINT = "v2/jobs.status"
RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
Expand Down Expand Up @@ -50,12 +51,9 @@ def check_expected_status(status: JobStatus) -> bool:
if termination_flag is None:
termination_flag = TerminationFlag.create()

for attempt in Retrying(
retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5), reraise=True
):
for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=job_wait_strategy(), reraise=True):
with attempt:
termination_flag.assert_running()

job_status = self.get_job_status(client, job_id)

if check_expected_status(job_status) or job_status.aborted():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Any

from pandas import DataFrame
from tenacity import retry, retry_if_result, wait_incrementing
from tenacity import retry, retry_if_result

from graphdatascience.call_parameters import CallParameters
from graphdatascience.query_runner.protocol.status import Status
from graphdatascience.query_runner.query_runner import QueryRunner
from graphdatascience.query_runner.termination_flag import TerminationFlag
from graphdatascience.retry_utils.retry_utils import before_log
from graphdatascience.retry_utils.retry_utils import before_log, job_wait_strategy
from graphdatascience.session.dbms.protocol_version import ProtocolVersion


Expand Down Expand Up @@ -159,7 +159,7 @@ def is_not_done(result: DataFrame) -> bool:
reraise=True,
before=before_log(f"Projection (graph: `{params['graph_name']}`)", logger, DEBUG),
retry=retry_if_result(is_not_done),
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
wait=job_wait_strategy(),
)
def project_fn() -> DataFrame:
termination_flag.assert_running()
Expand Down
6 changes: 3 additions & 3 deletions src/graphdatascience/query_runner/protocol/write_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from typing import Any

from pandas import DataFrame
from tenacity import retry, retry_if_result, wait_incrementing
from tenacity import retry, retry_if_result

from graphdatascience.call_parameters import CallParameters
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
from graphdatascience.query_runner.protocol.status import Status
from graphdatascience.query_runner.query_mode import QueryMode
from graphdatascience.query_runner.query_runner import QueryRunner
from graphdatascience.query_runner.termination_flag import TerminationFlag
from graphdatascience.retry_utils.retry_utils import before_log
from graphdatascience.retry_utils.retry_utils import before_log, job_wait_strategy
from graphdatascience.session.dbms.protocol_version import ProtocolVersion


Expand Down Expand Up @@ -157,7 +157,7 @@ def is_not_completed(result: DataFrame) -> bool:
@retry(
reraise=True,
retry=retry_if_result(is_not_completed),
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
wait=job_wait_strategy(),
before=before_log(
f"Write-Back (graph: `{parameters['graphName']}`, jobId: `{parameters['jobId']}`)",
logger,
Expand Down
14 changes: 13 additions & 1 deletion src/graphdatascience/retry_utils/retry_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import typing

from tenacity import RetryCallState
import tenacity.wait
from tenacity import RetryCallState, wait_chain, wait_fixed


def before_log(
Expand All @@ -18,3 +19,14 @@ def log_it(retry_state: RetryCallState) -> None:
)

return log_it


def job_wait_strategy() -> tenacity.wait.wait_base:
    """Build the wait schedule used between polls of a running job.

    The schedule is a chain of fixed waits, consumed one per retry:

    * a single 0.02 s wait at the very beginning (speeds up tests),
    * 100 waits of 0.1 s each, i.e. roughly the first 10 seconds,
    * then 1 s, 2 s, 4 s and 5 s.

    Per tenacity's ``wait_chain``, the last entry (5 s) keeps being used
    once the chain is exhausted, so 5 s is the maximum wait.

    Returns:
        A tenacity wait strategy suitable for the ``wait=`` argument of
        ``retry`` / ``Retrying``.
    """
    initial_wait = [wait_fixed(0.02)]
    short_waits = [wait_fixed(0.1) for _ in range(100)]
    backoff_waits = [wait_fixed(seconds) for seconds in (1, 2, 4, 5)]
    return wait_chain(*initial_wait, *short_waits, *backoff_waits)
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if neo4j_user := os.environ.get("NEO4J_USER", os.environ.get("NEO4J_USERNAME", "neo4j")):
AUTH = (
neo4j_user,
os.environ.get("NEO4J_PASSWORD", "neo4j"),
os.environ.get("NEO4J_PASSWORD", "password"),
)

DB = os.environ.get("NEO4J_DB", "neo4j")
Expand Down
2 changes: 1 addition & 1 deletion tests/integrationV2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def start_session(
session_container = session_container.with_network(network).with_network_aliases("gds-session")
with session_container as session_container:
try:
wait_for_logs(session_container, "Running GDS tasks: 0", timeout=20)
wait_for_logs(session_container, "Running GDS tasks: 0", timeout=30)
yield GdsSessionConnectionInfo(
host=session_container.get_container_host_ip(),
arrow_port=session_container.get_exposed_port(8491),
Expand Down
2 changes: 1 addition & 1 deletion tests/integrationV2/procedure_surface/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def start_database(logs_dir: Path, network: Network) -> Generator[DbmsConnection
if neo4j_image is None:
raise ValueError("NEO4J_DATABASE_IMAGE environment variable is not set")
db_logs_dir = logs_dir / "arrow_surface" / "db_logs"
db_logs_dir.mkdir(parents=True)
db_logs_dir.mkdir(parents=True, exist_ok=True)
db_logs_dir.chmod(0o777)
db_container = (
DockerContainer(image=neo4j_image)
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/arrow_client/V1/test_gds_arrow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from graphdatascience.procedure_surface.arrow.error_handler import handle_flight_error
from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
from graphdatascience.query_runner.arrow_info import ArrowInfo
from graphdatascience.retry_utils.retry_config import RetryConfigV2

ActionParam: TypeAlias = str | tuple[str, Any] | Action

Expand Down Expand Up @@ -132,10 +133,13 @@ def gds_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, N


@pytest.fixture()
def flaky_gds_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]:
def flaky_gds_client(
flaky_flight_server: FlakyFlightServer, retry_config_v2: RetryConfigV2
) -> Generator[GdsArrowClient, None, None]:
with AuthenticatedArrowClient.create(
ArrowInfo(f"localhost:{flaky_flight_server.port}", True, True, ["v1"]),
UsernamePasswordAuthentication("user", "password"),
retry_config=retry_config_v2,
) as arrow_client:
yield GdsArrowClient(arrow_client)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/arrow_client/V2/test_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def test_wait_for_job_progress_bar_qualitative(mocker: MockerFixture) -> None:

progress_output = pbarOutputStream.getvalue().split("\r")
assert "Algo [elapsed: 00:00 ]" in progress_output
assert "Algo [elapsed: 00:01 , status: RUNNING, task: Halfway there]" in progress_output
assert any("Algo [elapsed: 00:03 , status: FINISHED]" in line for line in progress_output)
assert "Algo [elapsed: 00:00 , status: RUNNING, task: Halfway there]" in progress_output
assert any("Algo [elapsed: 00:00 , status: FINISHED]" in line for line in progress_output)


def test_get_summary(mocker: MockerFixture) -> None:
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/arrow_client/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from pyarrow._flight import FlightInternalError, FlightTimedOutError, FlightUnavailableError

from graphdatascience.retry_utils.retry_config import RetryConfigV2, StopConfig


@pytest.fixture
def retry_config_v2() -> RetryConfigV2:
    """Provide the retry configuration shared by the arrow client unit tests.

    Retries on Flight timed-out, unavailable and internal errors, stops
    according to ``StopConfig(after_delay=10, after_attempt=5)``, and
    configures no wait between attempts.
    """
    flight_errors = [FlightTimedOutError, FlightUnavailableError, FlightInternalError]
    stop_after = StopConfig(after_delay=10, after_attempt=5)
    # No wait for tests. Makes them faster
    return RetryConfigV2(retryable_exceptions=flight_errors, stop_config=stop_after, wait_config=None)
24 changes: 5 additions & 19 deletions tests/unit/arrow_client/test_authenticated_flight_client.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,16 @@
import pytest
from pyarrow._flight import FlightInternalError, FlightTimedOutError, FlightUnavailableError

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
from graphdatascience.arrow_client.arrow_info import ArrowInfo
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
from graphdatascience.retry_utils.retry_config import ExponentialWaitConfig, RetryConfigV2, StopConfig
from graphdatascience.retry_utils.retry_config import ExponentialWaitConfig, RetryConfigV2


@pytest.fixture
def arrow_info() -> ArrowInfo:
    """Provide an ``ArrowInfo`` for an enabled, running server on localhost:8491."""
    info = ArrowInfo(
        listenAddress="localhost:8491",
        enabled=True,
        running=True,
        versions=["1.0.0"],
    )
    return info


@pytest.fixture
def retry_config() -> RetryConfigV2:
return RetryConfigV2(
retryable_exceptions=[
FlightTimedOutError,
FlightUnavailableError,
FlightInternalError,
],
stop_config=StopConfig(after_delay=10, after_attempt=5),
wait_config=ExponentialWaitConfig(multiplier=1, min=1, max=10),
)


@pytest.fixture
def mock_auth() -> ArrowAuthentication:
class MockAuthentication(ArrowAuthentication):
Expand All @@ -42,14 +28,14 @@ def test_create_authenticated_arrow_client(arrow_info: ArrowInfo, mock_auth: Arr
assert client.connection_info() == ConnectionInfo("localhost", 8491, encrypted=True)


def test_connection_info(arrow_info: ArrowInfo, retry_config: RetryConfigV2) -> None:
client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config)
def test_connection_info(arrow_info: ArrowInfo, retry_config_v2: RetryConfigV2) -> None:
client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config_v2)
connection_info = client.connection_info()
assert connection_info == ConnectionInfo("localhost", 8491, encrypted=False)


def test_pickle_roundtrip(arrow_info: ArrowInfo, retry_config: RetryConfigV2) -> None:
client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config)
def test_pickle_roundtrip(arrow_info: ArrowInfo, retry_config_v2: RetryConfigV2) -> None:
client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config_v2)
import pickle

pickled_client = pickle.dumps(client)
Expand Down
17 changes: 15 additions & 2 deletions tests/unit/query_runner/progress/test_query_progress_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
Expand All @@ -14,19 +15,27 @@


def test_call_through_functions() -> None:
progress_fetched_event = threading.Event()
progress_called = []

def fake_run_cypher(query: str, database: str | None = None) -> DataFrame:
progress_called.append(time.time())

assert "CALL gds.listProgress('foo')" in query
assert database == "database"

progress_fetched_event.set()

return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])

def fake_query() -> DataFrame:
time.sleep(1)
progress_fetched_event.wait(5)
return DataFrame([{"result": 42}])

qpl = QueryProgressLogger(fake_run_cypher, lambda: ServerVersion(3, 0, 0))
df = qpl.run_with_progress_logging(fake_query, "foo", "database")

assert len(progress_called) > 0
assert df["result"][0] == 42


Expand All @@ -45,14 +54,18 @@ def fake_query() -> DataFrame:


def test_uses_beta_endpoint() -> None:
progress_fetched_event = threading.Event()

def fake_run_cypher(query: str, database: str | None = None) -> DataFrame:
assert "CALL gds.beta.listProgress('foo')" in query
assert database == "database"

progress_fetched_event.set()

return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])

def fake_query() -> DataFrame:
time.sleep(1)
progress_fetched_event.wait(5)
return DataFrame([{"result": 42}])

qpl = QueryProgressLogger(fake_run_cypher, lambda: ServerVersion(2, 4, 0))
Expand Down
Loading