From 87e3382217edf155ef51c4b137b10af1be92a202 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Mon, 23 Mar 2026 13:50:16 -0700 Subject: [PATCH 01/11] Add conda build --- build.sh | 98 ++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 85 insertions(+), 13 deletions(-) diff --git a/build.sh b/build.sh index 5e173613..f304b2ab 100755 --- a/build.sh +++ b/build.sh @@ -1,26 +1,35 @@ #!/usr/bin/env bash set -euo pipefail +RECIPE_DIR="$(cd "$(dirname "$0")/pytigergraph-recipe/recipe" && pwd)" +PYPI_PACKAGE="pytigergraph" + usage() { cat <&2; usage >&2; exit 1 ;; esac shift done +# ── PyPI ──────────────────────────────────────────────────────────────────── + if $DO_BUILD; then echo "---- Removing old dist ----" rm -rf dist - echo "---- Building new package ----" + echo "---- Building PyPI package ----" python3 -m build fi @@ -55,3 +69,61 @@ if $DO_UPLOAD; then echo "---- Uploading to PyPI ----" python3 -m twine upload dist/* fi + +# ── Conda ─────────────────────────────────────────────────────────────────── + +if $DO_CONDA_BUILD; then + if ! command -v conda-build &>/dev/null; then + echo "Error: conda-build not found. Run: conda install conda-build" >&2 + exit 1 + fi + + # Verify the required version is already published on PyPI before proceeding. + RECIPE_VERSION=$(grep "^ version:" "$RECIPE_DIR/meta.yaml" | awk '{print $2}' | tr -d '"') + echo "---- Checking PyPI for $PYPI_PACKAGE==$RECIPE_VERSION ----" + PYPI_VERSIONS=$(curl -sf "https://pypi.org/pypi/$PYPI_PACKAGE/json" | python3 -c "import sys,json; print('\n'.join(json.load(sys.stdin)['releases'].keys()))" 2>/dev/null || true) + if ! echo "$PYPI_VERSIONS" | grep -qx "$RECIPE_VERSION"; then + echo " $PYPI_PACKAGE==$RECIPE_VERSION not found on PyPI. Running --all to build and publish first..." 
+ rm -rf dist + python3 -m build + python3 -m twine upload dist/* + echo " ✓ Published $PYPI_PACKAGE==$RECIPE_VERSION to PyPI" + else + echo " ✓ Found $PYPI_PACKAGE==$RECIPE_VERSION on PyPI" + fi + + # Compute sha256 of the tarball declared in the recipe and verify it matches. + TARBALL_URL=$(grep "url:" "$RECIPE_DIR/meta.yaml" | awk '{print $2}') + echo "---- Computing sha256 for $TARBALL_URL ----" + COMPUTED_SHA=$(curl -sL "$TARBALL_URL" | sha256sum | awk '{print $1}') + RECIPE_SHA=$(grep "sha256:" "$RECIPE_DIR/meta.yaml" | awk '{print $2}' || true) + if [[ -n "$RECIPE_SHA" && "$COMPUTED_SHA" != "$RECIPE_SHA" ]]; then + echo "Error: sha256 mismatch!" >&2 + echo " recipe : $RECIPE_SHA" >&2 + echo " actual : $COMPUTED_SHA" >&2 + exit 1 + fi + if [[ -z "$RECIPE_SHA" ]]; then + echo "Warning: no sha256 in recipe. For conda-forge submission add:" + echo " sha256: $COMPUTED_SHA" + fi + + echo "---- Building conda package ----" + conda build "$RECIPE_DIR" +fi + +if $DO_CONDA_UPLOAD; then + if ! command -v anaconda &>/dev/null; then + echo "Error: anaconda-client not found. Run: conda install anaconda-client" >&2 + exit 1 + fi + + CONDA_PKG=$(conda build "$RECIPE_DIR" --output) + if [[ ! -f "$CONDA_PKG" ]]; then + echo "Error: conda package not found at $CONDA_PKG. Run --conda-build first." 
>&2 + exit 1 + fi + + echo "---- Uploading conda package to anaconda.org ----" + anaconda upload --user tigergraph "$CONDA_PKG" +fi From f491423f7dcf73b0db00759ac05dcda382f36cb8 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Fri, 27 Mar 2026 13:21:10 -0700 Subject: [PATCH 02/11] Release 2.0.2 --- CHANGELOG.md | 27 + build.sh | 60 +- pyTigerGraph/__init__.py | 2 +- pyTigerGraph/common/base.py | 62 ++- pyTigerGraph/common/gsql.py | 58 ++ pyTigerGraph/pyTigerGraphEdge.py | 2 +- pyTigerGraph/pyTigerGraphGSQL.py | 25 +- pyTigerGraph/pyTigerGraphLoading.py | 37 +- pyTigerGraph/pyTigerGraphSchema.py | 335 +++++++++-- pyTigerGraph/pyTigerGraphVertex.py | 10 +- pyTigerGraph/pytgasync/pyTigerGraphEdge.py | 6 +- pyTigerGraph/pytgasync/pyTigerGraphGSQL.py | 27 +- pyTigerGraph/pytgasync/pyTigerGraphLoading.py | 37 +- pyTigerGraph/pytgasync/pyTigerGraphSchema.py | 337 +++++++++-- pyTigerGraph/pytgasync/pyTigerGraphUtils.py | 2 +- pyTigerGraph/pytgasync/pyTigerGraphVertex.py | 10 +- pyproject.toml | 2 +- pytigergraph-recipe/recipe/meta.yaml | 20 +- tests/test_common_base.py | 116 ++++ tests/test_common_query_helpers.py | 522 ++++++++++++++++++ 20 files changed, 1570 insertions(+), 127 deletions(-) create mode 100644 tests/test_common_base.py create mode 100644 tests/test_common_query_helpers.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 51570ad2..eada59c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [2.0.2] - 2026-03-27 + +### New Features + +- **Schema Change Job APIs** — `createSchemaChangeJob()`, `getSchemaChangeJobs()`, `runSchemaChangeJob()`, `dropSchemaChangeJobs()` for managing schema change jobs via REST. 
+- **`force` parameter for `runSchemaChange()`** — allows forcing schema changes even when they would cause data loss. Also accepts `dict` (JSON format) for TigerGraph >= 4.0 and supports global schema changes. +- **Graph scope control** — `useGraph(graphName)` and `useGlobal()` methods on the connection object, mirroring GSQL's `USE GRAPH` / `USE GLOBAL`. `useGlobal()` doubles as a context manager for temporary global scoping (`with conn.useGlobal(): ...`). +- **GSQL reserved keyword helpers** — `getReservedKeywords()` and `isReservedKeyword(name)` static methods to query the canonical set of GSQL reserved keywords. +- **Conda build support** — `build.sh` now supports `--conda-build`, `--conda-upload`, `--conda-all`, and `--conda-forge-test` for building and validating conda packages. + +### Fixed + +- **`_refresh_auth_headers()` init ordering** — auth header cache is now built before `_verify_jwt_token_support()` and the tgCloud ping, preventing `AttributeError` when using JWT tokens or TigerGraph Cloud hosts. +- **`gsql()` graph scope** — the `POST /gsql/v1/statements` path now falls back to `self.graphname` and prepends `USE GRAPH` automatically when the connection has a graph set. +- **`dropVertices()`** now correctly falls back to `self.graphname` when the `graph` parameter is `None`. +- **`dropAllDataSources()`** now correctly uses `self.graphname` fallback for the 4.x REST API path. +- **`getVectorIndexStatus()`** no longer produces a malformed URL when called without a graph name; now supports global scope (returns status for all graphs). +- **`previewSampleData()`** now raises `TigerGraphException` when no graph name is available, instead of sending an empty graph name to the server. +- **Docstring fixes** — corrected `timeout` parameter descriptions across vertex and edge query methods. + +### Tests + +- Added `test_common_base.py` — unit tests for auth header init ordering and credential refresh. 
+- Added `test_common_query_helpers.py` — unit tests for POST query parameter encoding (`_encode_str_for_post`, `_prep_query_parameters_json`) and round-trip verification. + +--- + ## [2.0.1] - 2026-03-23 ### Breaking Changes diff --git a/build.sh b/build.sh index f304b2ab..f9065979 100755 --- a/build.sh +++ b/build.sh @@ -3,6 +3,8 @@ set -euo pipefail RECIPE_DIR="$(cd "$(dirname "$0")/pytigergraph-recipe/recipe" && pwd)" PYPI_PACKAGE="pytigergraph" +CONDA_FORGE_PKG="pytigergraph" +STAGED_RECIPES_DIR="${STAGED_RECIPES_DIR:-$(cd "$(dirname "$0")/../staged-recipes" 2>/dev/null && pwd || echo "")}" usage() { cat <&2; usage >&2; exit 1 ;; @@ -109,7 +117,7 @@ if $DO_CONDA_BUILD; then fi echo "---- Building conda package ----" - conda build "$RECIPE_DIR" + conda build -c conda-forge "$RECIPE_DIR" fi if $DO_CONDA_UPLOAD; then @@ -118,7 +126,7 @@ if $DO_CONDA_UPLOAD; then exit 1 fi - CONDA_PKG=$(conda build "$RECIPE_DIR" --output) + CONDA_PKG=$(conda build -c conda-forge "$RECIPE_DIR" --output) if [[ ! -f "$CONDA_PKG" ]]; then echo "Error: conda package not found at $CONDA_PKG. Run --conda-build first." >&2 exit 1 @@ -127,3 +135,33 @@ if $DO_CONDA_UPLOAD; then echo "---- Uploading conda package to anaconda.org ----" anaconda upload --user tigergraph "$CONDA_PKG" fi + +if $DO_CONDA_FORGE_TEST; then + if [[ -z "$STAGED_RECIPES_DIR" || ! -f "$STAGED_RECIPES_DIR/build-locally.py" ]]; then + echo "Error: staged-recipes not found at '${STAGED_RECIPES_DIR}'." >&2 + echo "Clone it with: git clone https://github.com/conda-forge/staged-recipes.git ../staged-recipes" >&2 + echo "Or set: export STAGED_RECIPES_DIR=/path/to/staged-recipes" >&2 + exit 1 + fi + if [[ ! 
-f "$STAGED_RECIPES_DIR/recipes/$CONDA_FORGE_PKG/meta.yaml" ]]; then + echo "Error: recipe not found at $STAGED_RECIPES_DIR/recipes/$CONDA_FORGE_PKG/meta.yaml" >&2 + echo "Copy your recipe: cp $RECIPE_DIR/meta.yaml $STAGED_RECIPES_DIR/recipes/$CONDA_FORGE_PKG/meta.yaml" >&2 + exit 1 + fi + # Detect the local platform config for build-locally.py + _OS="$(uname -s)" + _ARCH="$(uname -m)" + if [[ "$_OS" == "Darwin" && "$_ARCH" == "arm64" ]]; then + _CONFIG="osx_arm64" + elif [[ "$_OS" == "Darwin" ]]; then + _CONFIG="osx64" + elif [[ "$_OS" == "Linux" && "$_ARCH" == "aarch64" ]]; then + _CONFIG="linux_aarch64" + else + _CONFIG="linux64" + fi + echo "---- Running conda-forge CI simulation (config: $_CONFIG) ----" + echo "Note: build-locally.py builds ALL recipes in staged-recipes/recipes/" + cd "$STAGED_RECIPES_DIR" + python build-locally.py "$_CONFIG" +fi diff --git a/pyTigerGraph/__init__.py b/pyTigerGraph/__init__.py index 82aadf27..18674644 100644 --- a/pyTigerGraph/__init__.py +++ b/pyTigerGraph/__init__.py @@ -7,7 +7,7 @@ try: __version__ = _pkg_version("pyTigerGraph") except PackageNotFoundError: - __version__ = "2.0.1" + __version__ = "2.0.2" __license__ = "Apache 2" diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index ef78b5f0..7039f214 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -122,6 +122,11 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", self.base64_credential = base64.b64encode( "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") + # Pre-build auth header dicts immediately after credentials are set so + # _prep_req can safely use _cached_token_auth/_cached_pwd_auth in any + # subsequent _get()/_req() call (e.g. tgCloud ping, JWT verification). 
+ self._refresh_auth_headers() + # Detect auth mode automatically by checking if jwtToken or apiToken is provided self.authHeader = self._set_auth_header() self.authMode = "token" if (self.jwtToken or self.apiToken) else "pwd" @@ -226,16 +231,63 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"] self.awsIamHeaders["Authorization"] = request.headers["Authorization"] + self.asynchronous = False + if self.jwtToken: self._verify_jwt_token_support() - self.asynchronous = False + logger.debug("exit: __init__") - # Pre-build per-authMode header dicts so _prep_req avoids repeating - # the isinstance/string-comparison chain on every request. - self._refresh_auth_headers() + # -- Scope helpers (mirror GSQL's USE GRAPH / USE GLOBAL) ---------- - logger.debug("exit: __init__") + class _GlobalScope: + """Context manager returned by :meth:`useGlobal` for temporary global scope.""" + + def __init__(self, conn): + self._conn = conn + self._saved = None + + def __enter__(self): + self._saved = self._conn.graphname + self._conn.graphname = "" + return self._conn + + def __exit__(self, *exc): + self._conn.graphname = self._saved + + def useGraph(self, graphName: str = ""): + """Switch this connection to a specific graph's scope. + + Mirrors GSQL's ``USE GRAPH `` command. + After this call, all operations that accept an optional graph name + will target this graph by default. + + If *graphName* is omitted or empty, behaves the same as + :meth:`useGlobal` (switches to global scope). + + Args: + graphName: + Name of the graph to use. Empty or omitted for global scope. + """ + if not graphName: + return self.useGlobal() + self.graphname = graphName + + def useGlobal(self): + """Switch this connection to global scope. + + Mirrors GSQL's ``USE GLOBAL`` command. + After this call, all operations that accept an optional graph name + will target the global scope by default. 
+ + Can also be used as a context manager for temporary global scope:: + + with conn.useGlobal(): + conn.getSchemaChangeJobs() # global + # conn.graphname is restored here + """ + self.graphname = "" + return self._GlobalScope(self) def _set_auth_header(self): """Set the authentication header based on available tokens or credentials.""" diff --git a/pyTigerGraph/common/gsql.py b/pyTigerGraph/common/gsql.py index 917a11cf..b7351fc6 100644 --- a/pyTigerGraph/common/gsql.py +++ b/pyTigerGraph/common/gsql.py @@ -2,6 +2,10 @@ Use GSQL within pyTigerGraph. All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. + +This module also defines the canonical set of GSQL reserved keywords, +serving as the single source of truth for pyTigerGraph, tigergraph-mcp, +and any downstream application. """ import logging import re @@ -157,3 +161,57 @@ def _parse_get_udf(responses, json_out): if rets: return rets[0] return "" + + +# ─── GSQL Reserved Keywords ────────────────────────────────────────────── + +_RESERVED_KEYWORDS: frozenset = frozenset({ + "ACCUM", "ADD", "ALL", "ALLOCATE", "ALTER", "AND", "ANY", "AS", "ASC", + "AVG", "BAG", "BATCH", "BETWEEN", "BIGINT", "BLOB", "BOOL", "BOOLEAN", + "BOTH", "BREAK", "BY", "CALL", "CASCADE", "CASE", "CATCH", "CHAR", + "CHARACTER", "CHECK", "CLOB", "COALESCE", "COMPRESS", "CONST", "CONSTRAINT", + "CONTINUE", "COST", "COUNT", "CREATE", "CURRENT_DATE", "CURRENT_TIME", + "CURRENT_TIMESTAMP", "CURSOR", "KAFKA", "S3", "DATETIME", "DATETIME_ADD", + "DATETIME_SUB", "DAY", "DATETIME_DIFF", "DATETIME_TO_EPOCH", + "DATETIME_FORMAT", "DECIMAL", "DECLARE", "DELETE", "DESC", "DISTRIBUTED", + "DO", "DOUBLE", "DROP", "EDGE", "ELSE", "ELSEIF", "EPOCH_TO_DATETIME", + "END", "ESCAPE", "EXCEPTION", "EXISTS", "FALSE", "FILE", "SYS.FILE_NAME", + "FILTER", "FIXED_BINARY", "FLOAT", "FOR", "FOREACH", "FROM", "GLOBAL", + "GRANTS", "GRAPH", "GROUP", 
"GROUPBYACCUM", "HAVING", "HOUR", "HEADER", + "HEAPACCUM", "IF", "IGNORE", "SYS.INTERNAL_ID", "IN", "INDEX", + "INPUT_LINE_FILTER", "INSERT", "INT", "INTERSECT", "INT8", "INT16", "INT32", + "INT32_T", "INT64_T", "INTEGER", "INTERPRET", "INTO", "IS", "ISEMPTY", + "JOB", "JOIN", "JSONARRAY", "JSONOBJECT", "KEY", "LEADING", "LIKE", "LIMIT", + "LIST", "LOAD", "LOADACCUM", "LOG", "LONG", "MAP", "MINUTE", "NOBODY", + "NOT", "NOW", "NULL", "OFFSET", "ON", "OPENCYPHER", "OR", "ORDER", + "PINNED", "POLICY", "POST_ACCUM", "POST-ACCUM", "PRIMARY", "PRIMARY_ID", + "PRINT", "PROXY", "QUERY", "QUIT", "RAISE", "RANGE", "REDUCE", "REPLACE", + "RESET_COLLECTION_ACCUM", "RETURN", "RETURNS", "ROW", "SAMPLE", "SECOND", + "SELECT", "SELECTVERTEX", "SET", "STATIC", "STRING", "SUM", "TARGET", + "TEMP_TABLE", "THEN", "TO", "TO_CSV", "TO_DATETIME", "TRAILING", + "TRANSLATESQL", "TRIM", "TRUE", "TRY", "TUPLE", "TYPE", "TYPEDEF", "UINT", + "UINT8", "UINT16", "UINT32", "UINT8_T", "UINT32_T", "UINT64_T", "UNION", + "UPDATE", "UPSERT", "USING", "VALUES", "VERTEX", "WHEN", "WHERE", "WHILE", + "WITH", "GSQL_SYS_TAG", "_INTERNAL_ATTR_TAG", +}) + + +def _get_reserved_keywords() -> frozenset: + """Return the full set of GSQL reserved keywords. + + Returns: + A frozenset of uppercase keyword strings. + """ + return _RESERVED_KEYWORDS + + +def _is_reserved_keyword(name: str) -> bool: + """Check whether *name* is a GSQL reserved keyword (case-insensitive). + + Args: + name: The identifier to check. + + Returns: + True if the name is reserved. + """ + return name.upper() in _RESERVED_KEYWORDS diff --git a/pyTigerGraph/pyTigerGraphEdge.py b/pyTigerGraph/pyTigerGraphEdge.py index bab09015..32cba541 100644 --- a/pyTigerGraph/pyTigerGraphEdge.py +++ b/pyTigerGraph/pyTigerGraphEdge.py @@ -722,7 +722,7 @@ def getEdgesDataFrame(self, sourceVertexType: str, sourceVertexId: str, edgeType limit: Maximum number of edge instances to be returned (after sorting). 
timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) edge instances (sorted, limited) as dictionary, diff --git a/pyTigerGraph/pyTigerGraphGSQL.py b/pyTigerGraph/pyTigerGraphGSQL.py index a149fe10..dc8bf831 100644 --- a/pyTigerGraph/pyTigerGraphGSQL.py +++ b/pyTigerGraph/pyTigerGraphGSQL.py @@ -13,7 +13,9 @@ from pyTigerGraph.common.gsql import ( _parse_gsql, _prep_get_udf, - _parse_get_udf + _parse_get_udf, + _get_reserved_keywords, + _is_reserved_keyword, ) from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase @@ -495,3 +497,24 @@ def getGSQLVersion(self, verbose: bool = False) -> dict: logger.debug("exit: getGSQLVersion") return res + + @staticmethod + def getReservedKeywords() -> frozenset: + """Return the full set of GSQL reserved keywords. + + Returns: + A frozenset of uppercase keyword strings. + """ + return _get_reserved_keywords() + + @staticmethod + def isReservedKeyword(name: str) -> bool: + """Check whether *name* is a GSQL reserved keyword (case-insensitive). + + Args: + name: The identifier to check. + + Returns: + True if the name is reserved. + """ + return _is_reserved_keyword(name) diff --git a/pyTigerGraph/pyTigerGraphLoading.py b/pyTigerGraph/pyTigerGraphLoading.py index 2ee54939..c88f5f30 100644 --- a/pyTigerGraph/pyTigerGraphLoading.py +++ b/pyTigerGraph/pyTigerGraphLoading.py @@ -28,6 +28,7 @@ _prep_sample_data_url, ) +from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.common.gsql import _wrap_gsql_result from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase @@ -614,8 +615,9 @@ def dropAllDataSources(self, graphName: str = None) -> dict: Returns: A dict with at least a ``"message"`` key describing the outcome. 
""" + graph = graphName or self.graphname if self._version_greater_than_4_0(): - url = _prep_drop_all_data_sources(self.gsUrl, graphName) + url = _prep_drop_all_data_sources(self.gsUrl, graph) return self._req("DELETE", url, resKey=None) return _wrap_gsql_result(self.gsql("DROP DATA_SOURCE *")) @@ -640,31 +642,46 @@ def previewSampleData(self, dsName: str, path: str, size: int = 10, ) graph = graphName or self.graphname + if not graph: + raise TigerGraphException( + "previewSampleData requires a graph name. " + "Set graphname on the connection or pass graphName explicitly.", 0) url = _prep_sample_data_url(self.gsUrl) body = { "graphName": graph, "dataSource": dsName, "path": path, "size": size, + "parsing": { + "fileFormat": "none", + "eol": "\\n", + }, } - return self._req("POST", url, data=body, jsonData=True, resKey="results") + return self._req("POST", url, data=body, jsonData=True, + authMode="pwd", resKey="results") def getVectorIndexStatus(self, graphName: str = None, vertexType: str = None, vectorName: str = None) -> dict: """Get the rebuild status of vector indexes. - Uses REST++ endpoint ``GET /vector/status/[/[/]]``. + Uses REST++ endpoint + ``GET /vector/status[/[/[/]]]``. Args: graphName: Graph name. Defaults to the connection's current graph. - vertexType: Optionally filter by vertex type. - vectorName: Optionally filter by vector attribute name. + If the connection has no graph set (global scope), + returns status for all graphs. + vertexType: Optionally filter by vertex type (requires graphName). + vectorName: Optionally filter by vector attribute name + (requires vertexType). 
""" graph = graphName or self.graphname - path = f"/vector/status/{graph}" - if vertexType: - path += f"/{vertexType}" - if vectorName: - path += f"/{vectorName}" + path = "/vector/status" + if graph: + path += f"/{graph}" + if vertexType: + path += f"/{vertexType}" + if vectorName: + path += f"/{vectorName}" return self._req("GET", self.restppUrl + path) \ No newline at end of file diff --git a/pyTigerGraph/pyTigerGraphSchema.py b/pyTigerGraph/pyTigerGraphSchema.py index d34668d3..c3fb63d9 100644 --- a/pyTigerGraph/pyTigerGraphSchema.py +++ b/pyTigerGraph/pyTigerGraphSchema.py @@ -393,7 +393,9 @@ def dropVertices(self, vertex_names: Union[str, list], graph: str = None, of strings. Use ``"all"`` to drop all vertices. graph (str, optional): The graph from which vertex types should be dropped. - If not provided, drops global vertex types. + If ``None`` and the connection has a ``graphname`` set, + that graph is used. If neither is set, drops global + vertex types. ignoreErrors (bool): If ``True``, suppress exceptions (e.g. when some vertices do not exist) and return the error as a dict instead. Defaults to ``False``. @@ -425,9 +427,11 @@ def dropVertices(self, vertex_names: Union[str, list], graph: str = None, else: raise TigerGraphException("vertex_names must be a string or list of strings.", 0) + gname = graph or self.graphname + params = {"vertex": vertex_param} - if graph is not None: - params["graph"] = graph + if gname: + params["graph"] = gname if not ignoreErrors: res = self._delete(self.gsUrl + "/gsql/v1/schema/vertices", @@ -488,12 +492,20 @@ def validateGraphSchema(self) -> dict: return res - def createGraph(self, graphName: str) -> dict: - """Creates an empty graph. + def createGraph(self, graphName: str, + vertexTypes: list = None, + edgeTypes: list = None) -> dict: + """Creates a graph, optionally including existing global vertex/edge types. Args: graphName: Name of the graph to create. 
+ vertexTypes: + Optional list of existing global vertex type names to include + in the graph. Pass ``["*"]`` to include all global types. + edgeTypes: + Optional list of existing global edge type names to include + in the graph. Returns: A dict with at least a ``"message"`` key describing the outcome. @@ -504,13 +516,23 @@ def createGraph(self, graphName: str) -> dict: """ logger.debug("entry: createGraph") + type_names = [] + if vertexTypes: + type_names.extend(vertexTypes) + if edgeTypes: + type_names.extend(edgeTypes) + type_list = ", ".join(type_names) + gsql_cmd = f"CREATE GRAPH {graphName}({type_list})" + if self._version_greater_than_4_0(): - data = {"name": graphName} + data = {"gsql": gsql_cmd} res = self._post(self.gsUrl + "/gsql/v1/schema/graphs", data=data, authMode="pwd", resKey=None, - headers={'Content-Type': 'application/json'}, jsonData=True) + params={"gsql": "true", "graphName": graphName}, + headers={"Content-Type": "application/json"}, + jsonData=True) else: - res = _wrap_gsql_result(self.gsql(f"CREATE GRAPH {graphName}()")) + res = _wrap_gsql_result(self.gsql(gsql_cmd)) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -539,7 +561,8 @@ def dropGraph(self, graphName: str) -> dict: authMode="pwd", resKey=None, headers={'Content-Type': 'application/json'}) else: - res = _wrap_gsql_result(self.gsql(f"DROP GRAPH {graphName}")) + res = _wrap_gsql_result( + self.gsql(f"DROP GRAPH {graphName}")) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -571,48 +594,83 @@ def listGraphs(self) -> list: return res - def runSchemaChange(self, gsqlStatements: Union[str, list], graphName: str = None) -> dict: - """Runs schema change statements on a graph. + def runSchemaChange(self, gsqlStatements: Union[str, list, dict], + graphName: str = None, force: bool = False) -> dict: + """Runs schema change statements directly (without creating a named job). + + Supports both local (graph-scoped) and global schema changes. 
Args: gsqlStatements: - GSQL schema change DDL statements (e.g. ``ADD VERTEX ...``, ``ADD EDGE ...``). - Can be a string of semicolon-separated statements or a list of statements. + Schema change specification in one of two formats: + + **dict (JSON format, TG >= 4.0 only)** — Sent directly to + ``POST /gsql/v1/schema/change`` as JSON. Supports keys such as + ``addVertexTypes``, ``dropVertexTypes``, ``alterVertexTypes``, + ``addEdgeTypes``, ``dropEdgeTypes``, ``alterEdgeTypes``. + + **str or list (GSQL DDL)** — Wrapped in a GSQL schema change + job and executed via ``gsql()``. Works on all TigerGraph + versions. + graphName: - Target graph name. Uses connection's graphname if not provided. + Target graph name for a local schema change. If ``None`` and + the connection has no ``graphname`` set, a **global** schema + change is executed instead. + + force: + If ``True``, abort any loading jobs that conflict with the + schema change. Only applies to the JSON (dict) path. Returns: A dict with at least a ``"message"`` key describing the outcome. Endpoints: - - `POST /gsql/v1/schema/change?graph={graphName}` (In TigerGraph versions >= 4.0) - - Falls back to GSQL schema change job for TigerGraph versions < 4.0 + - ``POST /gsql/v1/schema/change`` with JSON body (TG >= 4.0) + - GSQL schema change job via ``gsql()`` (all versions) """ logger.debug("entry: runSchemaChange") gname = graphName or self.graphname - if isinstance(gsqlStatements, list): - gsqlStatements = "\n".join( - s if s.rstrip().endswith(";") else s + ";" - for s in gsqlStatements - ) - - if self._version_greater_than_4_0(): - params = {"graph": gname} + if isinstance(gsqlStatements, dict): + if not self._version_greater_than_4_0(): + raise TigerGraphException( + "JSON-format schema changes require TigerGraph >= 4.0. 
" + "Pass GSQL DDL statements as a string instead.") + params = {} + if gname: + params["graph"] = gname + if force: + params["force"] = "true" res = self._post(self.gsUrl + "/gsql/v1/schema/change", - params=params, data=gsqlStatements, authMode="pwd", resKey=None, - headers={'Content-Type': 'text/plain'}) + params=params, data=json.dumps(gsqlStatements), + authMode="pwd", resKey=None, + headers={'Content-Type': 'application/json'}) else: + if isinstance(gsqlStatements, list): + gsqlStatements = "\n".join( + s if s.rstrip().endswith(";") else s + ";" + for s in gsqlStatements + ) job_name = f"schema_change_{uuid.uuid4().hex[:8]}" - gsql_cmd = ( - f"USE GRAPH {gname}\n" - f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {gname} {{\n" - f" {gsqlStatements}\n" - f"}}\n" - f"RUN SCHEMA_CHANGE JOB {job_name}\n" - f"DROP JOB {job_name}" - ) + if gname: + gsql_cmd = ( + f"USE GRAPH {gname}\n" + f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {gname} {{\n" + f" {gsqlStatements}\n" + f"}}\n" + f"RUN SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) + else: + gsql_cmd = ( + f"CREATE GLOBAL SCHEMA_CHANGE JOB {job_name} {{\n" + f" {gsqlStatements}\n" + f"}}\n" + f"RUN GLOBAL SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) res = _wrap_gsql_result(self.gsql(gsql_cmd)) if logger.level == logging.DEBUG: @@ -620,3 +678,212 @@ def runSchemaChange(self, gsqlStatements: Union[str, list], graphName: str = Non logger.debug("exit: runSchemaChange") return res + + def createSchemaChangeJob(self, jobName: str, statements: Union[str, list, dict], + graphName: str = None) -> dict: + """Creates a named schema change job without running it. + + Args: + jobName: + Name for the schema change job. + + statements: + Schema change specification in one of two formats: + + **dict (JSON format)** — Sent as JSON to + ``POST /gsql/v1/schema/jobs/{jobName}``. 
+ For global jobs the dict should contain a ``"graphs"`` key; + for local jobs it should contain keys such as + ``addVertexTypes``, ``dropVertexTypes``, etc. + + **str or list (GSQL DDL)** — Individual DDL statements + (e.g. ``"ADD VERTEX Foo (...)"``). + They are wrapped in a ``CREATE [GLOBAL] SCHEMA_CHANGE JOB`` + command and sent via the ``?gsql=true`` parameter. + + graphName: + Target graph for a local schema change job. If ``None`` and + the connection has no ``graphname`` set, a **global** job is + created. + + Returns: + The server response dict. + + Endpoint: + - ``POST /gsql/v1/schema/jobs/{jobName}`` + """ + logger.debug("entry: createSchemaChangeJob") + + gname = graphName or self.graphname + + url = self.gsUrl + "/gsql/v1/schema/jobs/" + jobName + + if isinstance(statements, dict): + params = {} + if gname: + params["graph"] = gname + res = self._post(url, params=params, + data=json.dumps(statements), + authMode="pwd", resKey=None, + headers={'Content-Type': 'application/json'}) + else: + if isinstance(statements, list): + statements = "\n".join( + s if s.rstrip().endswith(";") else s + ";" + for s in statements + ) + if gname: + gsql_cmd = ( + f"CREATE SCHEMA_CHANGE JOB {jobName} FOR GRAPH {gname} {{\n" + f" {statements}\n" + f"}}" + ) + params = {"gsql": "true", "graph": gname} + else: + gsql_cmd = ( + f"CREATE GLOBAL SCHEMA_CHANGE JOB {jobName} {{\n" + f" {statements}\n" + f"}}" + ) + params = {"gsql": "true", "type": "global"} + res = self._post(url, params=params, + data=json.dumps({"gsql": gsql_cmd}), + authMode="pwd", resKey=None, + headers={'Content-Type': 'text/plain'}) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: createSchemaChangeJob") + + return res + + def getSchemaChangeJobs(self, jobName: str = None, graphName: str = None, + jsonFormat: bool = True) -> Union[dict, list]: + """Retrieves schema change jobs. + + Args: + jobName: + Name of a specific job to retrieve. 
If ``None``, all schema + change jobs are returned. + + graphName: + Graph whose local jobs to retrieve. If ``None`` and the + connection has no ``graphname`` set, global jobs are returned. + + jsonFormat: + If ``True`` (default), requests JSON-formatted output from the + server. + + Returns: + A dict (single job) or list (all jobs) describing the schema + change job(s). + + Endpoints: + - ``GET /gsql/v1/schema/jobs`` (all jobs) + - ``GET /gsql/v1/schema/jobs/{jobName}`` (single job) + """ + logger.debug("entry: getSchemaChangeJobs") + + gname = graphName or self.graphname + + url = self.gsUrl + "/gsql/v1/schema/jobs" + if jobName: + url += "/" + jobName + + params = {} + if gname: + params["graph"] = gname + if jsonFormat: + params["json"] = "true" + + res = self._get(url, params=params, authMode="pwd", resKey=None) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: getSchemaChangeJobs") + + return res + + def runSchemaChangeJob(self, jobName: str, graphName: str = None, + force: bool = False) -> dict: + """Runs an existing (already created) schema change job. + + Args: + jobName: + Name of the schema change job to run. + + graphName: + Graph on which to run the local job. If ``None`` and the + connection has no ``graphname`` set, runs a global job. + + force: + If ``True``, abort any loading jobs that conflict with the + schema change. + + Returns: + The server response dict. + + Endpoint: + - ``PUT /gsql/v1/schema/jobs/{jobName}`` + """ + logger.debug("entry: runSchemaChangeJob") + + gname = graphName or self.graphname + + query_parts = [] + if gname: + query_parts.append(f"graph={gname}") + if force: + query_parts.append("force=true") + + url = self.gsUrl + "/gsql/v1/schema/jobs/" + jobName + if query_parts: + url += "?" 
+ "&".join(query_parts) + + res = self._put(url, authMode="pwd", resKey=None) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: runSchemaChangeJob") + + return res + + def dropSchemaChangeJobs(self, jobNames: Union[str, list], + graphName: str = None) -> dict: + """Drops one or more schema change jobs. + + Args: + jobNames: + A single job name (str) or a list of job names to drop. + + graphName: + Graph whose local jobs to drop. If ``None`` and the + connection has no ``graphname`` set, drops global jobs. + + Returns: + The server response dict. + + Endpoint: + - ``DELETE /gsql/v1/schema/jobs`` + """ + logger.debug("entry: dropSchemaChangeJobs") + + gname = graphName or self.graphname + + if isinstance(jobNames, list): + job_param = ",".join(jobNames) + else: + job_param = jobNames + + params = {"jobName": job_param} + if gname: + params["graph"] = gname + + res = self._delete(self.gsUrl + "/gsql/v1/schema/jobs", + params=params, authMode="pwd", resKey=None) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: dropSchemaChangeJobs") + + return res diff --git a/pyTigerGraph/pyTigerGraphVertex.py b/pyTigerGraph/pyTigerGraphVertex.py index d6b85c7c..79dadeb0 100644 --- a/pyTigerGraph/pyTigerGraphVertex.py +++ b/pyTigerGraph/pyTigerGraphVertex.py @@ -447,7 +447,7 @@ def getVertices(self, vertexType: str, select: str = "", where: str = "", withType: (When the output format is "df") should the vertex type be included in the dataframe? timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) vertex instances (sorted, limited) as @@ -510,7 +510,7 @@ def getVertexDataFrame(self, vertexType: str, select: str = "", where: str = "", Maximum number of vertex instances to be returned (after sorting). 
Must be used with `sort`. timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) vertex instances (sorted, limited) as pandas @@ -564,7 +564,7 @@ def getVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], sel withType: (If the output format is "df") should the vertex type be included in the dataframe? timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) vertex instances as dictionary, JSON or pandas @@ -715,7 +715,7 @@ def delVertices(self, vertexType: str, where: str = "", limit: str = "", sort: s If true, the deleted vertex IDs can never be inserted back, unless the graph is dropped or the graph store is cleared. timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: A single number of vertices deleted. @@ -762,7 +762,7 @@ def delVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], If true, the deleted vertex IDs can never be inserted back, unless the graph is dropped or the graph store is cleared. timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: A single number of vertices deleted. 
diff --git a/pyTigerGraph/pytgasync/pyTigerGraphEdge.py b/pyTigerGraph/pytgasync/pyTigerGraphEdge.py index 5d8e755b..315e8954 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphEdge.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphEdge.py @@ -458,7 +458,7 @@ async def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: targetVertexType, targetVertexId, attributes) - params = {"vertex_must_exist": vertexMustExist} + params = {"vertex_must_exist": str(vertexMustExist).lower()} ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data, params=params) ret = ret[0]["accepted_edges"] @@ -537,7 +537,7 @@ async def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexTy if atomic: headers["gsql-atomic-level"] = "atomic" - params = {"vertex_must_exist": vertexMustExist} + params = {"vertex_must_exist": str(vertexMustExist).lower()} ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data, params=params, headers=headers) ret = ret[0]["accepted_edges"] @@ -708,7 +708,7 @@ async def getEdgesDataFrame(self, sourceVertexType: str, sourceVertexId: str, ed limit: Maximum number of edge instances to be returned (after sorting). timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. 
Returns: The (selected) details of the (matching) edge instances (sorted, limited) as dictionary, diff --git a/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py b/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py index 27ac33a6..41914c87 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py @@ -14,7 +14,9 @@ from pyTigerGraph.common.gsql import ( _parse_gsql, _prep_get_udf, - _parse_get_udf + _parse_get_udf, + _get_reserved_keywords, + _is_reserved_keyword, ) from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase @@ -411,7 +413,7 @@ async def getGSQLVersion(self, verbose: bool = False) -> dict: params = {} if verbose: - params["verbose"] = verbose + params["verbose"] = str(verbose).lower() res = await self._req("GET", self.gsUrl+"/gsql/v1/version", params=params, authMode="pwd", resKey=None, @@ -422,3 +424,24 @@ async def getGSQLVersion(self, verbose: bool = False) -> dict: logger.debug("exit: getGSQLVersion") return res + + @staticmethod + def getReservedKeywords() -> frozenset: + """Return the full set of GSQL reserved keywords. + + Returns: + A frozenset of uppercase keyword strings. + """ + return _get_reserved_keywords() + + @staticmethod + def isReservedKeyword(name: str) -> bool: + """Check whether *name* is a GSQL reserved keyword (case-insensitive). + + Args: + name: The identifier to check. + + Returns: + True if the name is reserved. 
+ """ + return _is_reserved_keyword(name) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphLoading.py b/pyTigerGraph/pytgasync/pyTigerGraphLoading.py index 8b60392a..08937449 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphLoading.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphLoading.py @@ -27,6 +27,7 @@ _prep_drop_all_data_sources, _prep_sample_data_url, ) +from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.common.gsql import _wrap_gsql_result from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase @@ -619,8 +620,9 @@ async def dropAllDataSources(self, graphName: str = None) -> dict: Returns: A dict with at least a ``"message"`` key describing the outcome. """ + graph = graphName or self.graphname if await self._version_greater_than_4_0(): - url = _prep_drop_all_data_sources(self.gsUrl, graphName) + url = _prep_drop_all_data_sources(self.gsUrl, graph) return await self._req("DELETE", url, resKey=None) return _wrap_gsql_result(await self.gsql("DROP DATA_SOURCE *")) @@ -645,31 +647,46 @@ async def previewSampleData(self, dsName: str, path: str, size: int = 10, ) graph = graphName or self.graphname + if not graph: + raise TigerGraphException( + "previewSampleData requires a graph name. " + "Set graphname on the connection or pass graphName explicitly.", 0) url = _prep_sample_data_url(self.gsUrl) body = { "graphName": graph, "dataSource": dsName, "path": path, "size": size, + "parsing": { + "fileFormat": "none", + "eol": "\\n", + }, } - return await self._req("POST", url, data=body, jsonData=True, resKey="results") + return await self._req("POST", url, data=body, jsonData=True, + authMode="pwd", resKey="results") async def getVectorIndexStatus(self, graphName: str = None, vertexType: str = None, vectorName: str = None) -> dict: """Get the rebuild status of vector indexes. - Uses REST++ endpoint ``GET /vector/status/[/[/]]``. + Uses REST++ endpoint + ``GET /vector/status[/[/[/]]]``. Args: graphName: Graph name. 
Defaults to the connection's current graph. - vertexType: Optionally filter by vertex type. - vectorName: Optionally filter by vector attribute name. + If the connection has no graph set (global scope), + returns status for all graphs. + vertexType: Optionally filter by vertex type (requires graphName). + vectorName: Optionally filter by vector attribute name + (requires vertexType). """ graph = graphName or self.graphname - path = f"/vector/status/{graph}" - if vertexType: - path += f"/{vertexType}" - if vectorName: - path += f"/{vectorName}" + path = "/vector/status" + if graph: + path += f"/{graph}" + if vertexType: + path += f"/{vertexType}" + if vectorName: + path += f"/{vectorName}" return await self._req("GET", self.restppUrl + path) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py index 77fa7a69..7c4f2640 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py @@ -4,6 +4,7 @@ All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. """ +import json import logging import re import uuid @@ -394,7 +395,9 @@ async def dropVertices(self, vertex_names: Union[str, list], graph: str = None, of strings. Use ``"all"`` to drop all vertices. graph (str, optional): The graph from which vertex types should be dropped. - If not provided, drops global vertex types. + If ``None`` and the connection has a ``graphname`` set, + that graph is used. If neither is set, drops global + vertex types. ignoreErrors (bool): If ``True``, suppress exceptions (e.g. when some vertices do not exist) and return the error as a dict instead. Defaults to ``False``. 
@@ -426,9 +429,11 @@ async def dropVertices(self, vertex_names: Union[str, list], graph: str = None, else: raise TigerGraphException("vertex_names must be a string or list of strings.", 0) + gname = graph or self.graphname + params = {"vertex": vertex_param} - if graph is not None: - params["graph"] = graph + if gname: + params["graph"] = gname if not ignoreErrors: res = await self._req("DELETE", self.gsUrl + "/gsql/v1/schema/vertices", @@ -488,12 +493,20 @@ async def validateGraphSchema(self) -> dict: return res - async def createGraph(self, graphName: str) -> dict: - """Creates an empty graph. + async def createGraph(self, graphName: str, + vertexTypes: list = None, + edgeTypes: list = None) -> dict: + """Creates a graph, optionally including existing global vertex/edge types. Args: graphName: Name of the graph to create. + vertexTypes: + Optional list of existing global vertex type names to include + in the graph. Pass ``["*"]`` to include all global types. + edgeTypes: + Optional list of existing global edge type names to include + in the graph. Returns: A dict with at least a ``"message"`` key describing the outcome. 
@@ -504,13 +517,23 @@ async def createGraph(self, graphName: str) -> dict: """ logger.debug("entry: createGraph") + type_names = [] + if vertexTypes: + type_names.extend(vertexTypes) + if edgeTypes: + type_names.extend(edgeTypes) + type_list = ", ".join(type_names) + gsql_cmd = f"CREATE GRAPH {graphName}({type_list})" + if await self._version_greater_than_4_0(): - data = {"name": graphName} + data = {"gsql": gsql_cmd} res = await self._req("POST", self.gsUrl + "/gsql/v1/schema/graphs", data=data, authMode="pwd", resKey=None, - headers={'Content-Type': 'application/json'}, jsonData=True) + params={"gsql": "true", "graphName": graphName}, + headers={"Content-Type": "application/json"}, + jsonData=True) else: - res = _wrap_gsql_result(await self.gsql(f"CREATE GRAPH {graphName}()")) + res = _wrap_gsql_result(await self.gsql(gsql_cmd)) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -539,7 +562,8 @@ async def dropGraph(self, graphName: str) -> dict: authMode="pwd", resKey=None, headers={'Content-Type': 'application/json'}) else: - res = _wrap_gsql_result(await self.gsql(f"DROP GRAPH {graphName}")) + res = _wrap_gsql_result( + await self.gsql(f"DROP GRAPH {graphName}")) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -571,48 +595,83 @@ async def listGraphs(self) -> list: return res - async def runSchemaChange(self, gsqlStatements: Union[str, list], graphName: str = None) -> dict: - """Runs schema change statements on a graph. + async def runSchemaChange(self, gsqlStatements: Union[str, list, dict], + graphName: str = None, force: bool = False) -> dict: + """Runs schema change statements directly (without creating a named job). + + Supports both local (graph-scoped) and global schema changes. Args: gsqlStatements: - GSQL schema change DDL statements (e.g. ``ADD VERTEX ...``, ``ADD EDGE ...``). - Can be a string of semicolon-separated statements or a list of statements. 
+ Schema change specification in one of two formats: + + **dict (JSON format, TG >= 4.0 only)** — Sent directly to + ``POST /gsql/v1/schema/change`` as JSON. Supports keys such as + ``addVertexTypes``, ``dropVertexTypes``, ``alterVertexTypes``, + ``addEdgeTypes``, ``dropEdgeTypes``, ``alterEdgeTypes``. + + **str or list (GSQL DDL)** — Wrapped in a GSQL schema change + job and executed via ``gsql()``. Works on all TigerGraph + versions. + graphName: - Target graph name. Uses connection's graphname if not provided. + Target graph name for a local schema change. If ``None`` and + the connection has no ``graphname`` set, a **global** schema + change is executed instead. + + force: + If ``True``, abort any loading jobs that conflict with the + schema change. Only applies to the JSON (dict) path. Returns: A dict with at least a ``"message"`` key describing the outcome. Endpoints: - - `POST /gsql/v1/schema/change?graph={graphName}` (In TigerGraph versions >= 4.0) - - Falls back to GSQL schema change job for TigerGraph versions < 4.0 + - ``POST /gsql/v1/schema/change`` with JSON body (TG >= 4.0) + - GSQL schema change job via ``gsql()`` (all versions) """ logger.debug("entry: runSchemaChange") gname = graphName or self.graphname - if isinstance(gsqlStatements, list): - gsqlStatements = "\n".join( - s if s.rstrip().endswith(";") else s + ";" - for s in gsqlStatements - ) - - if await self._version_greater_than_4_0(): - params = {"graph": gname} + if isinstance(gsqlStatements, dict): + if not await self._version_greater_than_4_0(): + raise TigerGraphException( + "JSON-format schema changes require TigerGraph >= 4.0. 
" + "Pass GSQL DDL statements as a string instead.") + params = {} + if gname: + params["graph"] = gname + if force: + params["force"] = "true" res = await self._req("POST", self.gsUrl + "/gsql/v1/schema/change", - params=params, data=gsqlStatements, authMode="pwd", resKey=None, - headers={'Content-Type': 'text/plain'}) + params=params, data=json.dumps(gsqlStatements), + authMode="pwd", resKey=None, + headers={'Content-Type': 'application/json'}) else: + if isinstance(gsqlStatements, list): + gsqlStatements = "\n".join( + s if s.rstrip().endswith(";") else s + ";" + for s in gsqlStatements + ) job_name = f"schema_change_{uuid.uuid4().hex[:8]}" - gsql_cmd = ( - f"USE GRAPH {gname}\n" - f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {gname} {{\n" - f" {gsqlStatements}\n" - f"}}\n" - f"RUN SCHEMA_CHANGE JOB {job_name}\n" - f"DROP JOB {job_name}" - ) + if gname: + gsql_cmd = ( + f"USE GRAPH {gname}\n" + f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {gname} {{\n" + f" {gsqlStatements}\n" + f"}}\n" + f"RUN SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) + else: + gsql_cmd = ( + f"CREATE GLOBAL SCHEMA_CHANGE JOB {job_name} {{\n" + f" {gsqlStatements}\n" + f"}}\n" + f"RUN GLOBAL SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) res = _wrap_gsql_result(await self.gsql(gsql_cmd)) if logger.level == logging.DEBUG: @@ -620,3 +679,213 @@ async def runSchemaChange(self, gsqlStatements: Union[str, list], graphName: str logger.debug("exit: runSchemaChange") return res + + async def createSchemaChangeJob(self, jobName: str, statements: Union[str, list, dict], + graphName: str = None) -> dict: + """Creates a named schema change job without running it. + + Args: + jobName: + Name for the schema change job. + + statements: + Schema change specification in one of two formats: + + **dict (JSON format)** — Sent as JSON to + ``POST /gsql/v1/schema/jobs/{jobName}``. 
+ For global jobs the dict should contain a ``"graphs"`` key; + for local jobs it should contain keys such as + ``addVertexTypes``, ``dropVertexTypes``, etc. + + **str or list (GSQL DDL)** — Individual DDL statements + (e.g. ``"ADD VERTEX Foo (...)"``). + They are wrapped in a ``CREATE [GLOBAL] SCHEMA_CHANGE JOB`` + command and sent via the ``?gsql=true`` parameter. + + graphName: + Target graph for a local schema change job. If ``None`` and + the connection has no ``graphname`` set, a **global** job is + created. + + Returns: + The server response dict. + + Endpoint: + - ``POST /gsql/v1/schema/jobs/{jobName}`` + """ + logger.debug("entry: createSchemaChangeJob") + + gname = graphName or self.graphname + + url = self.gsUrl + "/gsql/v1/schema/jobs/" + jobName + + if isinstance(statements, dict): + params = {} + if gname: + params["graph"] = gname + res = await self._req("POST", url, params=params, + data=json.dumps(statements), + authMode="pwd", resKey=None, + headers={'Content-Type': 'application/json'}) + else: + if isinstance(statements, list): + statements = "\n".join( + s if s.rstrip().endswith(";") else s + ";" + for s in statements + ) + if gname: + gsql_cmd = ( + f"CREATE SCHEMA_CHANGE JOB {jobName} FOR GRAPH {gname} {{\n" + f" {statements}\n" + f"}}" + ) + params = {"gsql": "true", "graph": gname} + else: + gsql_cmd = ( + f"CREATE GLOBAL SCHEMA_CHANGE JOB {jobName} {{\n" + f" {statements}\n" + f"}}" + ) + params = {"gsql": "true", "type": "global"} + res = await self._req("POST", url, params=params, + data=json.dumps({"gsql": gsql_cmd}), + authMode="pwd", resKey=None, + headers={'Content-Type': 'text/plain'}) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: createSchemaChangeJob") + + return res + + async def getSchemaChangeJobs(self, jobName: str = None, graphName: str = None, + jsonFormat: bool = True) -> Union[dict, list]: + """Retrieves schema change jobs. 
+ + Args: + jobName: + Name of a specific job to retrieve. If ``None``, all schema + change jobs are returned. + + graphName: + Graph whose local jobs to retrieve. If ``None`` and the + connection has no ``graphname`` set, global jobs are returned. + + jsonFormat: + If ``True`` (default), requests JSON-formatted output from the + server. + + Returns: + A dict (single job) or list (all jobs) describing the schema + change job(s). + + Endpoints: + - ``GET /gsql/v1/schema/jobs`` (all jobs) + - ``GET /gsql/v1/schema/jobs/{jobName}`` (single job) + """ + logger.debug("entry: getSchemaChangeJobs") + + gname = graphName or self.graphname + + url = self.gsUrl + "/gsql/v1/schema/jobs" + if jobName: + url += "/" + jobName + + params = {} + if gname: + params["graph"] = gname + if jsonFormat: + params["json"] = "true" + + res = await self._req("GET", url, params=params, authMode="pwd", + resKey=None) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: getSchemaChangeJobs") + + return res + + async def runSchemaChangeJob(self, jobName: str, graphName: str = None, + force: bool = False) -> dict: + """Runs an existing (already created) schema change job. + + Args: + jobName: + Name of the schema change job to run. + + graphName: + Graph on which to run the local job. If ``None`` and the + connection has no ``graphname`` set, runs a global job. + + force: + If ``True``, abort any loading jobs that conflict with the + schema change. + + Returns: + The server response dict. + + Endpoint: + - ``PUT /gsql/v1/schema/jobs/{jobName}`` + """ + logger.debug("entry: runSchemaChangeJob") + + gname = graphName or self.graphname + + query_parts = [] + if gname: + query_parts.append(f"graph={gname}") + if force: + query_parts.append("force=true") + + url = self.gsUrl + "/gsql/v1/schema/jobs/" + jobName + if query_parts: + url += "?" 
+ "&".join(query_parts) + + res = await self._put(url, authMode="pwd", resKey=None) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: runSchemaChangeJob") + + return res + + async def dropSchemaChangeJobs(self, jobNames: Union[str, list], + graphName: str = None) -> dict: + """Drops one or more schema change jobs. + + Args: + jobNames: + A single job name (str) or a list of job names to drop. + + graphName: + Graph whose local jobs to drop. If ``None`` and the + connection has no ``graphname`` set, drops global jobs. + + Returns: + The server response dict. + + Endpoint: + - ``DELETE /gsql/v1/schema/jobs`` + """ + logger.debug("entry: dropSchemaChangeJobs") + + gname = graphName or self.graphname + + if isinstance(jobNames, list): + job_param = ",".join(jobNames) + else: + job_param = jobNames + + params = {"jobName": job_param} + if gname: + params["graph"] = gname + + res = await self._req("DELETE", self.gsUrl + "/gsql/v1/schema/jobs", + params=params, authMode="pwd", resKey=None) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.debug("exit: dropSchemaChangeJobs") + + return res diff --git a/pyTigerGraph/pytgasync/pyTigerGraphUtils.py b/pyTigerGraph/pytgasync/pyTigerGraphUtils.py index ac126f8b..6b091388 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphUtils.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphUtils.py @@ -214,7 +214,7 @@ async def rebuildGraph(self, threadnum: int = None, vertextype: str = "", segid: if path: params["path"] = path if force: - params["force"] = force + params["force"] = str(force).lower() res = await self._req("GET", self.restppUrl+"/rebuildnow/"+self.graphname, params=params, resKey=None) if not res["error"]: if logger.level == logging.DEBUG: diff --git a/pyTigerGraph/pytgasync/pyTigerGraphVertex.py b/pyTigerGraph/pytgasync/pyTigerGraphVertex.py index d6d5104b..b7ad02b5 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphVertex.py +++ 
b/pyTigerGraph/pytgasync/pyTigerGraphVertex.py @@ -456,7 +456,7 @@ async def getVertices(self, vertexType: str, select: str = "", where: str = "", withType: (When the output format is "df") should the vertex type be included in the dataframe? timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) vertex instances (sorted, limited) as @@ -520,7 +520,7 @@ async def getVertexDataFrame(self, vertexType: str, select: str = "", where: str Maximum number of vertex instances to be returned (after sorting). Must be used with `sort`. timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) vertex instances (sorted, limited) as pandas @@ -574,7 +574,7 @@ async def getVerticesById(self, vertexType: str, vertexIds: Union[int, str, list withType: (If the output format is "df") should the vertex type be included in the dataframe? timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: The (selected) details of the (matching) vertex instances as dictionary, JSON or pandas @@ -725,7 +725,7 @@ async def delVertices(self, vertexType: str, where: str = "", limit: str = "", s If true, the deleted vertex IDs can never be inserted back, unless the graph is dropped or the graph store is cleared. timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: A single number of vertices deleted. 
@@ -773,7 +773,7 @@ async def delVerticesById(self, vertexType: str, vertexIds: Union[int, str, list If true, the deleted vertex IDs can never be inserted back, unless the graph is dropped or the graph store is cleared. timeout: - Time allowed for successful execution (0 = no limit, default). + Time allowed for successful execution in seconds. 0 or omitted applies the system-wide endpoint timeout. Returns: A single number of vertices deleted. diff --git a/pyproject.toml b/pyproject.toml index 5aa7e4c4..766b3b1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pyTigerGraph" -version = "2.0.1" +version = "2.0.2" description = "Library to connect to TigerGraph databases" readme = "README.md" license = { text = "Apache-2.0" } diff --git a/pytigergraph-recipe/recipe/meta.yaml b/pytigergraph-recipe/recipe/meta.yaml index 054025f6..28854685 100644 --- a/pytigergraph-recipe/recipe/meta.yaml +++ b/pytigergraph-recipe/recipe/meta.yaml @@ -1,28 +1,42 @@ +{% set python_min = "3.9" %} + package: name: pytigergraph version: "2.0.1" source: - url: https://pypi.io/packages/source/p/pytigergraph/pytigergraph-2.0.1.tar.gz + url: https://pypi.org/packages/source/p/pytigergraph/pytigergraph-2.0.1.tar.gz + sha256: 6c0834a7abdacf4b00c41e3603c397954a50825677ddadde7a03086eaae479f4 build: + number: 0 noarch: python script: {{ PYTHON }} -m pip install . 
--no-deps -vv requirements: host: - - python >=3.9 + - python {{ python_min }} - pip - setuptools - wheel run: - - python >=3.9 + - python >={{ python_min }} - requests + - aiohttp + - httpx + - validators + +test: + imports: + - pyTigerGraph + requires: + - python {{ python_min }} about: home: https://github.com/tigergraph/pyTigerGraph license: Apache-2.0 license_family: Apache + license_file: LICENSE summary: Python client for TigerGraph extra: diff --git a/tests/test_common_base.py b/tests/test_common_base.py new file mode 100644 index 00000000..cecc6369 --- /dev/null +++ b/tests/test_common_base.py @@ -0,0 +1,116 @@ +"""Unit tests for pyTigerGraph.common.base (PyTigerGraphCore). + +These tests run without a live TigerGraph server by mocking network calls. +They guard against init-ordering bugs where _cached_token_auth is accessed +before _refresh_auth_headers() has been called. +""" + +import unittest +from unittest.mock import MagicMock, patch + +from pyTigerGraph import TigerGraphConnection + + +def _make_conn(**kwargs): + """Create a TigerGraphConnection without any real network calls.""" + defaults = dict( + host="http://127.0.0.1", + graphname="tests", + username="tigergraph", + password="tigergraph", + ) + defaults.update(kwargs) + with patch.object(TigerGraphConnection, "_verify_jwt_token_support", return_value=None): + conn = TigerGraphConnection(**defaults) + return conn + + +class TestRefreshAuthHeadersOrdering(unittest.TestCase): + """_refresh_auth_headers() must be called before any _get()/_req() in __init__. + + Regression test for GML-2041 ordering bug: + _cached_token_auth was set AFTER _verify_jwt_token_support() (and the + tgCloud ping), causing AttributeError swallowed as a JWT error message. 
+ """ + + def test_cached_auth_set_with_username_password(self): + conn = _make_conn() + self.assertTrue(hasattr(conn, "_cached_token_auth")) + self.assertTrue(hasattr(conn, "_cached_pwd_auth")) + self.assertIn("Basic ", conn._cached_token_auth["Authorization"]) + self.assertIn("Basic ", conn._cached_pwd_auth["Authorization"]) + + def test_cached_auth_set_with_api_token(self): + conn = _make_conn(apiToken="myapitoken123") + self.assertIn("Bearer myapitoken123", conn._cached_token_auth["Authorization"]) + self.assertIn("Basic ", conn._cached_pwd_auth["Authorization"]) + + def test_cached_auth_set_with_jwt_token(self): + """Regression: jwtToken must not cause AttributeError during __init__.""" + conn = _make_conn(jwtToken="header.payload.signature") + self.assertIn("Bearer header.payload.signature", conn._cached_token_auth["Authorization"]) + self.assertIn("Bearer header.payload.signature", conn._cached_pwd_auth["Authorization"]) + + def test_jwt_token_calls_verify(self): + """_verify_jwt_token_support() must be called when jwtToken is provided.""" + with patch.object(TigerGraphConnection, "_verify_jwt_token_support") as mock_verify: + TigerGraphConnection( + host="http://127.0.0.1", + jwtToken="header.payload.signature", + ) + mock_verify.assert_called_once() + + def test_no_jwt_skips_verify(self): + """_verify_jwt_token_support() must NOT be called without jwtToken.""" + with patch.object(TigerGraphConnection, "_verify_jwt_token_support") as mock_verify: + TigerGraphConnection(host="http://127.0.0.1") + mock_verify.assert_not_called() + + def test_tgcloud_ping_does_not_crash_without_jwt(self): + """tgCloud _get() ping fires before _verify_jwt_token_support; must not AttributeError.""" + with patch.object(TigerGraphConnection, "_get", return_value="pong") as mock_get: + conn = TigerGraphConnection(host="http://my.tgcloud.io") + # _cached_token_auth must exist at the point _get() was called + self.assertTrue(hasattr(conn, "_cached_token_auth")) + + def 
test_tgcloud_ping_does_not_crash_with_jwt(self): + """tgCloud ping + JWT verification both fire; _cached_token_auth must precede both.""" + with patch.object(TigerGraphConnection, "_get", return_value="pong"): + with patch.object(TigerGraphConnection, "_verify_jwt_token_support"): + conn = TigerGraphConnection( + host="http://my.tgcloud.io", + jwtToken="header.payload.signature", + ) + self.assertIn("Bearer header.payload.signature", conn._cached_token_auth["Authorization"]) + + def test_x_user_agent_header_present(self): + """X-User-Agent must be baked into cached auth dicts.""" + conn = _make_conn() + self.assertEqual(conn._cached_token_auth.get("X-User-Agent"), "pyTigerGraph") + self.assertEqual(conn._cached_pwd_auth.get("X-User-Agent"), "pyTigerGraph") + + +class TestRefreshAuthHeadersUpdate(unittest.TestCase): + """_refresh_auth_headers() must update the cache after credentials change.""" + + def test_refresh_after_get_token(self): + conn = _make_conn() + self.assertIn("Basic ", conn._cached_token_auth["Authorization"]) + + conn.apiToken = "newtoken456" + conn._refresh_auth_headers() + + self.assertIn("Bearer newtoken456", conn._cached_token_auth["Authorization"]) + + def test_refresh_clears_old_token(self): + conn = _make_conn(apiToken="oldtoken") + self.assertIn("Bearer oldtoken", conn._cached_token_auth["Authorization"]) + + conn.apiToken = "" + conn._refresh_auth_headers() + + self.assertIn("Basic ", conn._cached_token_auth["Authorization"]) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_common_query_helpers.py b/tests/test_common_query_helpers.py new file mode 100644 index 00000000..f631d154 --- /dev/null +++ b/tests/test_common_query_helpers.py @@ -0,0 +1,522 @@ +"""Unit tests for pyTigerGraph.common.query helper functions. + +Focuses on _encode_str_for_post and _prep_query_parameters_json to ensure +POST query parameters are never silently corrupted. 
+""" + +import unittest +from datetime import datetime + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.query import ( + _encode_str_for_post, + _parse_query_parameters, + _prep_query_parameters_json, +) + + +class TestEncodeStrForPost(unittest.TestCase): + """_encode_str_for_post: only % → %25, everything else unchanged.""" + + def test_no_percent(self): + self.assertEqual(_encode_str_for_post("hello world"), "hello world") + + def test_single_percent(self): + self.assertEqual(_encode_str_for_post("50%"), "50%25") + + def test_multiple_percents(self): + self.assertEqual(_encode_str_for_post("10% and 20%"), "10%25 and 20%25") + + def test_already_encoded_percent(self): + # %25 in input → %2525 (double-encode is intentional; server decodes once) + self.assertEqual(_encode_str_for_post("%25"), "%2525") + + def test_empty_string(self): + self.assertEqual(_encode_str_for_post(""), "") + + def test_no_mutation_of_unrelated_chars(self): + s = "abc/123?key=val&other=true" + self.assertEqual(_encode_str_for_post(s), s) + + +class TestPrepQueryParametersJson(unittest.TestCase): + """_prep_query_parameters_json: every conversion rule documented in the docstring.""" + + # ------------------------------------------------------------------ + # Guard rails + # ------------------------------------------------------------------ + + def test_none_passthrough(self): + self.assertIsNone(_prep_query_parameters_json(None)) + + def test_empty_dict_passthrough(self): + # empty dict is falsy → returned as-is + result = _prep_query_parameters_json({}) + self.assertEqual(result, {}) + + def test_non_dict_passthrough(self): + # wrong type returns the value unchanged + self.assertEqual(_prep_query_parameters_json("raw_string"), "raw_string") + + # ------------------------------------------------------------------ + # Primitive / scalar values — must not be corrupted + # ------------------------------------------------------------------ + + def 
test_int_unchanged(self): + result = _prep_query_parameters_json({"limit": 100}) + self.assertEqual(result, {"limit": 100}) + + def test_float_unchanged(self): + result = _prep_query_parameters_json({"score": 3.14}) + self.assertEqual(result, {"score": 3.14}) + + def test_bool_unchanged(self): + result = _prep_query_parameters_json({"flag": True}) + self.assertEqual(result, {"flag": True}) + + def test_none_value_unchanged(self): + result = _prep_query_parameters_json({"x": None}) + self.assertEqual(result, {"x": None}) + + # ------------------------------------------------------------------ + # String values + # ------------------------------------------------------------------ + + def test_plain_string_unchanged(self): + result = _prep_query_parameters_json({"name": "Alice"}) + self.assertEqual(result, {"name": "Alice"}) + + def test_string_with_percent_encoded(self): + result = _prep_query_parameters_json({"pattern": "50%"}) + self.assertEqual(result, {"pattern": "50%25"}) + + def test_string_without_percent_not_touched(self): + result = _prep_query_parameters_json({"key": "hello/world"}) + self.assertEqual(result, {"key": "hello/world"}) + + # ------------------------------------------------------------------ + # datetime values + # ------------------------------------------------------------------ + + def test_datetime_converted_to_string(self): + dt = datetime(2024, 6, 15, 12, 30, 45) + result = _prep_query_parameters_json({"ts": dt}) + self.assertEqual(result, {"ts": "2024-06-15 12:30:45"}) + + def test_datetime_format_is_exact(self): + dt = datetime(2000, 1, 1, 0, 0, 0) + result = _prep_query_parameters_json({"ts": dt}) + self.assertEqual(result["ts"], "2000-01-01 00:00:00") + + # ------------------------------------------------------------------ + # Tuple → vertex dict + # ------------------------------------------------------------------ + + def test_tuple_typed_vertex_1tuple(self): + # VERTEX: (id,) → {"id": id} + result = 
_prep_query_parameters_json({"v": ("vid123",)}) + self.assertEqual(result, {"v": {"id": "vid123"}}) + + def test_tuple_typed_vertex_1tuple_int_id(self): + result = _prep_query_parameters_json({"v": (42,)}) + self.assertEqual(result, {"v": {"id": 42}}) + + def test_tuple_untyped_vertex_2tuple(self): + # VERTEX (untyped): (id, "type") → {"id": id, "type": "type"} + result = _prep_query_parameters_json({"v": ("vid123", "Person")}) + self.assertEqual(result, {"v": {"id": "vid123", "type": "Person"}}) + + def test_tuple_untyped_vertex_integer_id(self): + result = _prep_query_parameters_json({"v": (42, "Order")}) + self.assertEqual(result, {"v": {"id": 42, "type": "Order"}}) + + def test_tuple_invalid_3tuple_raises(self): + with self.assertRaises(TigerGraphException): + _prep_query_parameters_json({"v": ("id", "type", "extra")}) + + def test_tuple_type_not_string_raises(self): + # (id, None) — None is not a str → rejected by pyTigerGraph + with self.assertRaises(TigerGraphException): + _prep_query_parameters_json({"v": ("id", None)}) + + def test_tuple_empty_type_string_raises(self): + # (id, "") — empty string slips isinstance check but must be rejected + with self.assertRaises(TigerGraphException): + _prep_query_parameters_json({"v": (1, "")}) + + def test_tuple_empty_type_in_list_raises(self): + with self.assertRaises(TigerGraphException): + _prep_query_parameters_json({"vs": [(1, "")]}) + + # ------------------------------------------------------------------ + # Pre-formatted vertex dict — must pass through unchanged + # ------------------------------------------------------------------ + + def test_vertex_dict_passthrough(self): + vertex = {"id": "vid123", "type": "Person"} + result = _prep_query_parameters_json({"v": vertex}) + self.assertEqual(result["v"], vertex) + + def test_vertex_dict_id_only_passthrough(self): + # Typed vertex pre-formatted with just "id" (type is optional per docs) + vertex = {"id": "vid123"} + result = _prep_query_parameters_json({"v": 
vertex}) + self.assertEqual(result["v"], vertex) + + # ------------------------------------------------------------------ + # MAP values (Python dict without "id" → TigerGraph keylist/valuelist) + # ------------------------------------------------------------------ + + def test_map_dict_converted_to_keylist_valuelist(self): + result = _prep_query_parameters_json({"m": {49: "Alaska", 50: "Hawaii"}}) + self.assertEqual(result["m"], {"keylist": [49, 50], "valuelist": ["Alaska", "Hawaii"]}) + + def test_map_string_keys(self): + result = _prep_query_parameters_json({"m": {"a": 1, "b": 2}}) + self.assertEqual(result["m"], {"keylist": ["a", "b"], "valuelist": [1, 2]}) + + def test_map_empty_dict_converted(self): + # An empty dict has no "id" key, so it becomes an empty MAP structure. + result = _prep_query_parameters_json({"m": {}}) + self.assertEqual(result["m"], {"keylist": [], "valuelist": []}) + + def test_map_does_not_mutate_input(self): + original_map = {"x": 10, "y": 20} + original = {"m": original_map} + _prep_query_parameters_json(original) + self.assertEqual(original["m"], {"x": 10, "y": 20}) + + # ------------------------------------------------------------------ + # List values + # ------------------------------------------------------------------ + + def test_list_of_ints_unchanged(self): + result = _prep_query_parameters_json({"ids": [1, 2, 3]}) + self.assertEqual(result, {"ids": [1, 2, 3]}) + + def test_list_of_strings_percent_encoded(self): + result = _prep_query_parameters_json({"tags": ["50%", "no_percent"]}) + self.assertEqual(result, {"tags": ["50%25", "no_percent"]}) + + def test_list_of_1tuples_typed_vertex_set(self): + # SET>: [(id,), ...] → [{"id": id}, ...] + result = _prep_query_parameters_json({"vs": [("v1",), ("v2",)]}) + self.assertEqual(result, {"vs": [{"id": "v1"}, {"id": "v2"}]}) + + def test_list_of_2tuples_untyped_vertex_set(self): + # SET: [(id, "type"), ...] → [{"id": id, "type": "type"}, ...] 
+ result = _prep_query_parameters_json({ + "vertices": [("v1", "Person"), ("v2", "Person")] + }) + self.assertEqual(result, { + "vertices": [ + {"id": "v1", "type": "Person"}, + {"id": "v2", "type": "Person"}, + ] + }) + + def test_list_of_datetimes_converted(self): + dts = [datetime(2024, 1, 1), datetime(2024, 6, 15, 8, 0, 0)] + result = _prep_query_parameters_json({"times": dts}) + self.assertEqual(result, {"times": ["2024-01-01 00:00:00", "2024-06-15 08:00:00"]}) + + def test_list_of_dicts_passthrough(self): + verts = [{"id": "v1", "type": "Person"}, {"id": "v2", "type": "Order"}] + result = _prep_query_parameters_json({"vset": verts}) + self.assertEqual(result["vset"], verts) + + def test_list_of_invalid_tuples_raises(self): + with self.assertRaises(TigerGraphException): + _prep_query_parameters_json({"v": [("id", "type", "extra")]}) + + def test_empty_list_unchanged(self): + result = _prep_query_parameters_json({"items": []}) + self.assertEqual(result, {"items": []}) + + # ------------------------------------------------------------------ + # Mixed params — keys must not cross-contaminate each other + # ------------------------------------------------------------------ + + def test_multiple_keys_independent(self): + dt = datetime(2024, 3, 10, 9, 0, 0) + result = _prep_query_parameters_json({ + "limit": 10, + "name": "Bob%", + "ts": dt, + "v": ("vid1", "User"), + }) + self.assertEqual(result["limit"], 10) + self.assertEqual(result["name"], "Bob%25") + self.assertEqual(result["ts"], "2024-03-10 09:00:00") + self.assertEqual(result["v"], {"id": "vid1", "type": "User"}) + + def test_original_dict_not_mutated(self): + original = {"name": "Alice%", "limit": 5} + _prep_query_parameters_json(original) + self.assertEqual(original["name"], "Alice%") # input untouched + + +class TestParseQueryParameters(unittest.TestCase): + """_parse_query_parameters: vertex tuple conventions for GET mode.""" + + def test_primitive_int(self): + 
self.assertEqual(_parse_query_parameters({"n": 5}), "n=5") + + def test_primitive_string(self): + self.assertEqual(_parse_query_parameters({"s": "hello"}), "s=hello") + + def test_list_of_primitives(self): + result = _parse_query_parameters({"ids": [1, 2, 3]}) + self.assertEqual(result, "ids=1&ids=2&ids=3") + + def test_typed_vertex_1tuple(self): + # VERTEX: (id,) → k=id + result = _parse_query_parameters({"v": ("Tom",)}) + self.assertEqual(result, "v=Tom") + + def test_typed_vertex_1tuple_int_id(self): + result = _parse_query_parameters({"v": (42,)}) + self.assertEqual(result, "v=42") + + def test_untyped_vertex_2tuple(self): + # VERTEX (untyped): (id, "type") → k=id&k.type=type + result = _parse_query_parameters({"v": ("Tom", "Person")}) + self.assertEqual(result, "v=Tom&v.type=Person") + + def test_typed_vertex_set_list_of_1tuples(self): + # SET>: [(id,), ...] → k=id1&k=id2 (repeated, no index) + result = _parse_query_parameters({"vs": [("Tom",), ("Mary",)]}) + self.assertEqual(result, "vs=Tom&vs=Mary") + + def test_untyped_vertex_set_list_of_2tuples(self): + # SET: [(id,"type"), ...] 
→ k[i]=id&k[i].type=type + result = _parse_query_parameters({"vs": [("Tom", "Person"), ("Mary", "Person")]}) + self.assertEqual(result, "vs[0]=Tom&vs[0].type=Person&vs[1]=Mary&vs[1].type=Person") + + def test_invalid_tuple_raises(self): + with self.assertRaises(TigerGraphException): + _parse_query_parameters({"v": ("id", "type", "extra")}) + + def test_invalid_tuple_in_list_raises(self): + with self.assertRaises(TigerGraphException): + _parse_query_parameters({"vs": [("id", "type", "extra")]}) + + def test_tuple_none_type_raises(self): + # (id, None) — None is not a str → rejected by pyTigerGraph + with self.assertRaises(TigerGraphException): + _parse_query_parameters({"v": (1, None)}) + + def test_tuple_empty_type_string_raises(self): + # (id, "") — empty string must be rejected before sending to TigerGraph + with self.assertRaises(TigerGraphException): + _parse_query_parameters({"v": (1, "")}) + + def test_tuple_empty_type_in_list_raises(self): + with self.assertRaises(TigerGraphException): + _parse_query_parameters({"vs": [(1, "")]}) + + def test_datetime(self): + result = _parse_query_parameters({"ts": datetime(2024, 1, 15, 12, 0, 0)}) + self.assertIn("2024-01-15", result) + + +class TestPostParamRoundTrip(unittest.TestCase): + """Integration tests: verify every _prep_query_parameters_json conversion + survives a real POST round-trip through TigerGraph. + + Uses the pre-installed ``query4_all_param_types`` query (defined in + testserver.gsql) which PRINTs all 13 parameters back in declaration order. + No query creation or teardown required. + + PRINT order → result index: + p01_int[0] p02_uint[1] p03_float[2] p04_double[3] p05_string[4] + p06_bool[5] p07_vertex[6] p08_vertex_vertex4[7] p09_datetime[8] + p10_set_int[9] p11_bag_int[10] p13_set_vertex[11] p14_set_vertex_vertex4[12] + """ + + # Neutral baseline — every test overrides only the param(s) it cares about. 
+ # (id, "type") tuples are used uniformly for all vertex params; the API + # converts them to the correct wire format for each transport. + _BASE = { + "p01_int": 1, + "p02_uint": 1, + "p03_float": 1.0, + "p04_double": 1.0, + "p05_string": "x", + "p06_bool": True, + "p07_vertex": (1, "vertex4"), + "p08_vertex_vertex4": (1, "vertex4"), + "p09_datetime": datetime(2000, 1, 1), + "p10_set_int": [1], + "p11_bag_int": [1], + "p13_set_vertex": [(1, "vertex4")], + "p14_set_vertex_vertex4": [(1, "vertex4")], + } + + @classmethod + def setUpClass(cls): + try: + from pyTigerGraphUnitTest import make_connection + except ImportError: + raise unittest.SkipTest("No test server configuration found") + + try: + cls.conn = make_connection() + except Exception as e: + raise unittest.SkipTest(f"Cannot connect to test server: {e}") + + # Ensure vertex4 instances 1-3 exist (other suites may not have run yet) + for i in range(1, 4): + cls.conn.upsertVertex("vertex4", i, {"a01": i}) + + def _run(self, overrides: dict): + """Merge overrides into the baseline and run query4_all_param_types via POST.""" + params = {**self._BASE, **overrides} + return self.conn.runInstalledQuery("query4_all_param_types", params, usePost=True) + + # ------------------------------------------------------------------ + # Scalar types + # ------------------------------------------------------------------ + + def test_db_int_roundtrip(self): + """INT is left as-is by _prep_query_parameters_json; DB must echo it back exactly.""" + p = 42 + res = self._run({"p01_int": p}) + self.assertEqual(res[0]["p01_int"], p) + + def test_db_uint_roundtrip(self): + p = 7 + res = self._run({"p02_uint": p}) + self.assertEqual(res[1]["p02_uint"], p) + + def test_db_float_roundtrip(self): + p = 1.5 + res = self._run({"p03_float": p}) + self.assertAlmostEqual(res[2]["p03_float"], p, places=4) + + def test_db_double_roundtrip(self): + p = 2.5 + res = self._run({"p04_double": p}) + self.assertAlmostEqual(res[3]["p04_double"], p, 
places=4) + + def test_db_bool_roundtrip(self): + res = self._run({"p06_bool": False}) + self.assertEqual(res[5]["p06_bool"], False) + + # ------------------------------------------------------------------ + # String — the % encoding path is the critical one + # ------------------------------------------------------------------ + + def test_db_plain_string_roundtrip(self): + """Plain string (no %) must arrive at the DB unchanged.""" + p = "hello world" + res = self._run({"p05_string": p}) + self.assertEqual(res[4]["p05_string"], p) + + def test_db_percent_string_roundtrip(self): + """String containing % must be decoded back to % by the server (not stored as %25).""" + p = "50% done" + res = self._run({"p05_string": p}) + self.assertEqual(res[4]["p05_string"], p) + + def test_db_special_chars_string_roundtrip(self): + """Unicode, symbols, and emoji survive the full round-trip unchanged.""" + p = "test <>\"'`\\/{}[]!@£$%^&*-_=+;:|,.§±~` árvíztűrő 👍" + res = self._run({"p05_string": p}) + self.assertEqual(res[4]["p05_string"], p) + + # ------------------------------------------------------------------ + # datetime + # ------------------------------------------------------------------ + + def test_db_datetime_roundtrip(self): + """datetime is converted to 'YYYY-MM-DD HH:MM:SS' and echoed back by the DB.""" + p = datetime(2024, 6, 15, 12, 30, 45) + res = self._run({"p09_datetime": p}) + self.assertEqual(res[8]["p09_datetime"], p.strftime("%Y-%m-%d %H:%M:%S")) + + def test_db_datetime_midnight_roundtrip(self): + """Midnight is formatted with explicit 00:00:00 (no truncation).""" + p = datetime(2000, 1, 1, 0, 0, 0) + res = self._run({"p09_datetime": p}) + self.assertEqual(res[8]["p09_datetime"], "2000-01-01 00:00:00") + + # ------------------------------------------------------------------ + # Vertex (untyped and typed) + # ------------------------------------------------------------------ + + def test_db_vertex_tuple_roundtrip(self): + """Tuple (id, type) → vertex JSON; 
DB echoes back the primary ID.""" + res = self._run({"p07_vertex": (2, "vertex4")}) + self.assertEqual(str(res[6]["p07_vertex"]), "2") + + def test_db_typed_vertex_tuple_roundtrip(self): + """Tuple (id, type) for a typed vertex; DB echoes back the primary ID.""" + res = self._run({"p08_vertex_vertex4": (3, "vertex4")}) + self.assertEqual(str(res[7]["p08_vertex_vertex4"]), "3") + + # ------------------------------------------------------------------ + # SET / BAG of scalars + # ------------------------------------------------------------------ + + def test_db_set_int_deduplicates(self): + """SET: duplicates are removed; remaining values survive unchanged.""" + sent = [1, 2, 3, 2, 3, 3] + res = self._run({"p10_set_int": sent}) + self.assertEqual(sorted(res[9]["p10_set_int"]), sorted(set(sent))) + + def test_db_bag_int_preserves_duplicates(self): + """BAG: all values including duplicates must be echoed back.""" + sent = [1, 2, 3, 2, 3, 3] + res = self._run({"p11_bag_int": sent}) + self.assertEqual(sorted(res[10]["p11_bag_int"]), sorted(sent)) + + # ------------------------------------------------------------------ + # SET of vertices (untyped and typed) + # ------------------------------------------------------------------ + + def test_db_set_vertex_tuples_roundtrip(self): + """List of (id, type) tuples → vertex JSON list; DB echoes the IDs back.""" + sent_ids = [1, 2, 3] + res = self._run({"p13_set_vertex": [(i, "vertex4") for i in sent_ids]}) + returned = sorted(str(v) for v in res[11]["p13_set_vertex"]) + self.assertEqual(returned, [str(i) for i in sorted(sent_ids)]) + + def test_db_set_typed_vertex_tuples_roundtrip(self): + """List of (id, type) tuples for a typed vertex set; DB echoes IDs back.""" + sent_ids = [1, 2, 3] + res = self._run({"p14_set_vertex_vertex4": [(i, "vertex4") for i in sent_ids]}) + returned = sorted(str(v) for v in res[12]["p14_set_vertex_vertex4"]) + self.assertEqual(returned, [str(i) for i in sorted(sent_ids)]) + + # 
------------------------------------------------------------------ + # All conversions in one call + # ------------------------------------------------------------------ + + def test_db_all_conversions_together(self): + """All converted types sent in one call; each must round-trip correctly.""" + p_str = "combined 100% test" + p_dt = datetime(2024, 12, 31, 23, 59, 59) + sent_ids = [1, 2, 3] + + res = self._run({ + "p01_int": 99, + "p05_string": p_str, + "p09_datetime": p_dt, + "p07_vertex": (1, "vertex4"), + "p13_set_vertex": [(i, "vertex4") for i in sent_ids], + "p14_set_vertex_vertex4": [(i, "vertex4") for i in sent_ids], + }) + + self.assertEqual(res[0]["p01_int"], 99) + self.assertEqual(res[4]["p05_string"], p_str) + self.assertEqual(res[8]["p09_datetime"], p_dt.strftime("%Y-%m-%d %H:%M:%S")) + self.assertEqual(str(res[6]["p07_vertex"]), "1") + self.assertEqual(sorted(str(v) for v in res[11]["p13_set_vertex"]), + [str(i) for i in sorted(sent_ids)]) + self.assertEqual(sorted(str(v) for v in res[12]["p14_set_vertex_vertex4"]), + [str(i) for i in sorted(sent_ids)]) + + +if __name__ == "__main__": + unittest.main() From 6302c17de79dfe127c92818c43fc8401455dac3b Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Fri, 27 Mar 2026 16:10:48 -0700 Subject: [PATCH 03/11] GML-2052 Fix jwtToken exception --- CHANGELOG.md | 10 +++------- pyTigerGraph/common/base.py | 14 ++++++++++---- pyTigerGraph/pyTigerGraphEdge.py | 4 ++-- pyTigerGraph/pyTigerGraphGSQL.py | 2 +- pyTigerGraph/pyTigerGraphUtils.py | 2 +- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eada59c7..24738d99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New Features - **Schema Change Job APIs** — `createSchemaChangeJob()`, `getSchemaChangeJobs()`, `runSchemaChangeJob()`, `dropSchemaChangeJobs()` for managing schema change jobs via REST. 
+- **`createGraph()` accepts vertex/edge types** — optional `vertexTypes` and `edgeTypes` parameters to include existing global types when creating a graph (e.g. `createGraph("g", vertexTypes=["Person"], edgeTypes=["Knows"])`). Pass `vertexTypes=["*"]` to include all global types. - **`force` parameter for `runSchemaChange()`** — allows forcing schema changes even when they would cause data loss. Also accepts `dict` (JSON format) for TigerGraph >= 4.0 and supports global schema changes. - **Graph scope control** — `useGraph(graphName)` and `useGlobal()` methods on the connection object, mirroring GSQL's `USE GRAPH` / `USE GLOBAL`. `useGlobal()` doubles as a context manager for temporary global scoping (`with conn.useGlobal(): ...`). - **GSQL reserved keyword helpers** — `getReservedKeywords()` and `isReservedKeyword(name)` static methods to query the canonical set of GSQL reserved keywords. @@ -17,19 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- **`_refresh_auth_headers()` init ordering** — auth header cache is now built before `_verify_jwt_token_support()` and the tgCloud ping, preventing `AttributeError` when using JWT tokens or TigerGraph Cloud hosts. -- **`gsql()` graph scope** — the `POST /gsql/v1/statements` path now falls back to `self.graphname` and prepends `USE GRAPH` automatically when the connection has a graph set. +- **`_refresh_auth_headers()` init ordering** — auth header cache is now built immediately after credentials are set, before the tgCloud ping and JWT verification. Prevents `AttributeError` on `_cached_token_auth` when connecting without a token (e.g. `TigerGraphConnection(host=..., username=..., password=...)`). +- **Boolean query parameters causing `yarl` errors** — `upsertEdge()`, `upsertEdges()` (`vertexMustExist`), `getVersion()` (`verbose`), and `rebuildGraph()` (`force`) now convert boolean values to lowercase strings before passing them as URL query parameters. 
- **`dropVertices()`** now correctly falls back to `self.graphname` when the `graph` parameter is `None`. - **`dropAllDataSources()`** now correctly uses `self.graphname` fallback for the 4.x REST API path. - **`getVectorIndexStatus()`** no longer produces a malformed URL when called without a graph name; now supports global scope (returns status for all graphs). - **`previewSampleData()`** now raises `TigerGraphException` when no graph name is available, instead of sending an empty graph name to the server. - **Docstring fixes** — corrected `timeout` parameter descriptions across vertex and edge query methods. -### Tests - -- Added `test_common_base.py` — unit tests for auth header init ordering and credential refresh. -- Added `test_common_query_helpers.py` — unit tests for POST query parameter encoding (`_encode_str_for_post`, `_prep_query_parameters_json`) and round-trip verification. - --- ## [2.0.1] - 2026-03-23 diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index 7039f214..54828f81 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -241,14 +241,19 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", # -- Scope helpers (mirror GSQL's USE GRAPH / USE GLOBAL) ---------- class _GlobalScope: - """Context manager returned by :meth:`useGlobal` for temporary global scope.""" + """Context manager returned by :meth:`useGlobal` for temporary global scope. + + The original graphname is captured at construction time (not at + ``__enter__``) so that ``useGlobal()`` can set ``graphname = ""`` + before the ``with`` block begins, supporting both the bare-call + and context-manager use cases. 
+ """ def __init__(self, conn): self._conn = conn - self._saved = None + self._saved = conn.graphname def __enter__(self): - self._saved = self._conn.graphname self._conn.graphname = "" return self._conn @@ -286,8 +291,9 @@ def useGlobal(self): conn.getSchemaChangeJobs() # global # conn.graphname is restored here """ + scope = self._GlobalScope(self) self.graphname = "" - return self._GlobalScope(self) + return scope def _set_auth_header(self): """Set the authentication header based on available tokens or credentials.""" diff --git a/pyTigerGraph/pyTigerGraphEdge.py b/pyTigerGraph/pyTigerGraphEdge.py index 32cba541..e97ec4e1 100644 --- a/pyTigerGraph/pyTigerGraphEdge.py +++ b/pyTigerGraph/pyTigerGraphEdge.py @@ -459,7 +459,7 @@ def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: str, attributes ) - params = {"vertex_must_exist": vertexMustExist} + params = {"vertex_must_exist": str(vertexMustExist).lower()} ret = self._post( self.restppUrl + "/graph/" + self.graphname, data=data, @@ -540,7 +540,7 @@ def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexType: st headers = {} if atomic: headers = {"gsql-atomic-level": "atomic"} - params = {"vertex_must_exist": vertexMustExist} + params = {"vertex_must_exist": str(vertexMustExist).lower()} ret = self._post( self.restppUrl + "/graph/" + self.graphname, data=data, diff --git a/pyTigerGraph/pyTigerGraphGSQL.py b/pyTigerGraph/pyTigerGraphGSQL.py index dc8bf831..32fde7d7 100644 --- a/pyTigerGraph/pyTigerGraphGSQL.py +++ b/pyTigerGraph/pyTigerGraphGSQL.py @@ -486,7 +486,7 @@ def getGSQLVersion(self, verbose: bool = False) -> dict: params = {} if verbose: - params["verbose"] = verbose + params["verbose"] = str(verbose).lower() res = self._get(self.gsUrl+"/gsql/v1/version", params=params, authMode="pwd", resKey=None, diff --git a/pyTigerGraph/pyTigerGraphUtils.py b/pyTigerGraph/pyTigerGraphUtils.py index de1bef86..6b581436 100644 --- a/pyTigerGraph/pyTigerGraphUtils.py +++ 
b/pyTigerGraph/pyTigerGraphUtils.py @@ -219,7 +219,7 @@ def rebuildGraph(self, threadnum: int = None, vertextype: str = "", segid: str = if path: params["path"] = path if force: - params["force"] = force + params["force"] = str(force).lower() res = self._get(self.restppUrl+"/rebuildnow/" + self.graphname, params=params, resKey=None) if not res["error"]: From 5a5c9f79cd73c6dd3da02e11693e53b289d52b65 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Tue, 31 Mar 2026 11:38:51 -0700 Subject: [PATCH 04/11] fix tgCloud exception --- pyTigerGraph/common/base.py | 11 ++--------- pyTigerGraph/pyTigerGraphBase.py | 19 +++++++------------ 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index 54828f81..9ceccc82 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -181,15 +181,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", warnings.warn("The `gcp` parameter is deprecated.", DeprecationWarning) self.tgCloud = tgCloud or gcp - if "tgcloud" in self.netloc.lower(): - try: # If get request succeeds, using TG Cloud instance provisioned after 6/20/2022 - self._get(self.host + "/api/ping", resKey="message") - self.tgCloud = True - # If get request fails, using TG Cloud instance provisioned before 6/20/2022, before new firewall config - except requests.exceptions.RequestException: - self.tgCloud = False - except TigerGraphException: - raise (TigerGraphException("Incorrect graphname.")) + if not self.tgCloud and "tgcloud" in self.netloc.lower(): + self.tgCloud = True restppPort = str(restppPort) sslPort = str(sslPort) diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index 3412aca1..794ff20b 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -104,24 +104,19 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", TigerGraphException: In case on invalid URL scheme. 
""" + # Thread-local sessions and failover lock must be created BEFORE + # super().__init__() because the parent __init__ may issue HTTP + # requests (tgCloud ping, JWT verification) that go through + # _session() → self._local. + self._local = threading.local() + self._restpp_failover_lock = threading.Lock() + super().__init__(host=host, graphname=graphname, gsqlSecret=gsqlSecret, username=username, password=password, tgCloud=tgCloud, restppPort=restppPort, gsPort=gsPort, gsqlVersion=gsqlVersion, version=version, apiToken=apiToken, useCert=useCert, certPath=certPath, debug=debug, sslPort=sslPort, gcp=gcp, jwtToken=jwtToken) - # Thread-local sessions — each thread gets its own requests.Session and connection pool. - # A single shared Session serializes all threads via its internal cookie-jar RLock - # (_cookies_lock), which is acquired on every response even when no cookies are set. - # Thread-local sessions eliminate that contention while still benefiting from HTTP - # keep-alive within each thread's sequential request stream. - self._local = threading.local() - - # Lock for the one-time port failover (TG 3.x port 9000 → 4.x port 14240/restpp). - # Without a lock, all parallel threads simultaneously fail and all enter the failover - # block, doubling requests and racing to overwrite self.restppUrl / self.restppPort. - self._restpp_failover_lock = threading.Lock() - if graphname == "MyGraph": warnings.warn( "The default graphname 'MyGraph' is deprecated. 
Please explicitly specify your graph name.", From f0cc58614b8e1f9ef067ff4f3777c340652fe878 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Wed, 1 Apr 2026 12:22:10 -0700 Subject: [PATCH 05/11] Update changelog --- CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24738d99..5a0e0838 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,8 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- **`_refresh_auth_headers()` init ordering** — auth header cache is now built immediately after credentials are set, before the tgCloud ping and JWT verification. Prevents `AttributeError` on `_cached_token_auth` when connecting without a token (e.g. `TigerGraphConnection(host=..., username=..., password=...)`). -- **Boolean query parameters causing `yarl` errors** — `upsertEdge()`, `upsertEdges()` (`vertexMustExist`), `getVersion()` (`verbose`), and `rebuildGraph()` (`force`) now convert boolean values to lowercase strings before passing them as URL query parameters. +- **`_refresh_auth_headers()` called earlier in `__init__`** — auth header cache is now built immediately after credentials are set. +- **tgCloud auto-detection simplified** — removed the HTTP ping to `/api/ping`; detection now relies solely on the hostname containing `"tgcloud"`. +- **`threading.local()` init ordering** — `self._local` and `self._restpp_failover_lock` are now created before `super().__init__()` in `pyTigerGraphBase`. +- **Boolean query parameter conversion** — `upsertEdge()`, `upsertEdges()` (`vertexMustExist`), `getVersion()` (`verbose`), and `rebuildGraph()` (`force`) now convert boolean values to lowercase strings. - **`dropVertices()`** now correctly falls back to `self.graphname` when the `graph` parameter is `None`. - **`dropAllDataSources()`** now correctly uses `self.graphname` fallback for the 4.x REST API path. 
- **`getVectorIndexStatus()`** no longer produces a malformed URL when called without a graph name; now supports global scope (returns status for all graphs). From 22e62e8714f2f4a8ea008bc7b503a0d9f539041c Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Fri, 3 Apr 2026 14:30:23 -0700 Subject: [PATCH 06/11] Improve auth and installQueries for TG 4.x - Unify _cached_token_auth and _cached_pwd_auth into single _cached_auth with JWT > apiToken > Basic fallback, so GSQL endpoints use apiToken when available instead of always falling back to Basic auth - Add `wait` param to installQueries (sync default=True, async default=False) to control blocking/polling behavior - Track token origin (_token_source) to auto-refresh on 401 for tokens generated by getToken(), while raising errors for user-provided tokens --- pyTigerGraph/common/base.py | 25 +++++--------- pyTigerGraph/pyTigerGraphAuth.py | 1 + pyTigerGraph/pyTigerGraphBase.py | 7 ++++ pyTigerGraph/pyTigerGraphQuery.py | 29 ++++++++++++++-- pyTigerGraph/pytgasync/pyTigerGraphAuth.py | 1 + pyTigerGraph/pytgasync/pyTigerGraphBase.py | 8 +++++ pyTigerGraph/pytgasync/pyTigerGraphQuery.py | 37 ++++++++++++--------- tests/test_common_base.py | 37 +++++++++------------ 8 files changed, 90 insertions(+), 55 deletions(-) diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index 9ceccc82..1ae46300 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -119,12 +119,13 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", self.jwtToken = jwtToken self.apiToken = apiToken + self._token_source = "user" if (apiToken or jwtToken) else None self.base64_credential = base64.b64encode( "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") - # Pre-build auth header dicts immediately after credentials are set so - # _prep_req can safely use _cached_token_auth/_cached_pwd_auth in any - # subsequent _get()/_req() call (e.g. tgCloud ping, JWT verification). 
+ # Pre-build the cached auth header dict immediately after credentials + # are set so _prep_req can safely use _cached_auth in any subsequent + # _get()/_req() call (e.g. tgCloud ping, JWT verification). self._refresh_auth_headers() # Detect auth mode automatically by checking if jwtToken or apiToken is provided @@ -298,18 +299,16 @@ def _set_auth_header(self): return {"Authorization": "Basic {0}".format(self.base64_credential)} def _refresh_auth_headers(self) -> None: - """Pre-build per-authMode header dicts used by every request. + """Pre-build the cached auth header dict used by every request. Called once at __init__ and again after getToken() updates the credentials. Eliminates per-request isinstance checks and string formatting in _prep_req's hot path. - Two dicts are kept because authMode can be either "token" or "pwd": - - "token": JWT > apiToken (tuple or str) > Basic - - "pwd": JWT > Basic + Fallback order: JWT > apiToken (tuple or str) > Basic. The "X-User-Agent" header is baked in so _prep_req skips that update too. """ - # ---- token mode ---- + # JWT > apiToken > Basic auth if isinstance(self.jwtToken, str) and self.jwtToken.strip(): token_val = "Bearer " + self.jwtToken elif isinstance(self.apiToken, tuple): @@ -319,11 +318,7 @@ def _refresh_auth_headers(self) -> None: else: token_val = "Basic " + self.base64_credential - # ---- pwd mode ---- - pwd_val = ("Bearer " + self.jwtToken) if self.jwtToken else ("Basic " + self.base64_credential) - - self._cached_token_auth = {"Authorization": token_val, "X-User-Agent": "pyTigerGraph"} - self._cached_pwd_auth = {"Authorization": pwd_val, "X-User-Agent": "pyTigerGraph"} + self._cached_auth = {"Authorization": token_val, "X-User-Agent": "pyTigerGraph"} def _verify_jwt_token_support(self): try: @@ -385,9 +380,7 @@ def _prep_req(self, authMode, headers, url, method, data): # Shallow-copy the pre-built header dict (auth + X-User-Agent already included). 
# _refresh_auth_headers() keeps these current after every getToken() call. - _headers = dict( - self._cached_token_auth if authMode == "token" else self._cached_pwd_auth - ) + _headers = dict(self._cached_auth) if headers: _headers.update(headers) diff --git a/pyTigerGraph/pyTigerGraphAuth.py b/pyTigerGraph/pyTigerGraphAuth.py index 73db6af5..f39e3fb7 100644 --- a/pyTigerGraph/pyTigerGraphAuth.py +++ b/pyTigerGraph/pyTigerGraphAuth.py @@ -423,6 +423,7 @@ def getToken(self, self.apiToken = token self.authHeader = auth_header self.authMode = "token" + self._token_source = "generated" logger.debug("exit: getToken") return token diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index 794ff20b..d1e3c982 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -221,6 +221,13 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N conn_err = e if res is not None: + # Auto-refresh token on 401 if the token was generated by getToken() + if res.status_code == 401 and getattr(self, "_token_source", None) == "generated": + self.getToken() + self._refresh_auth_headers() + _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) + res = self._do_request(method, url, _headers, _data, jsonData, params, http_timeout) + if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: try: self._error_check(json.loads(res.content)) diff --git a/pyTigerGraph/pyTigerGraphQuery.py b/pyTigerGraph/pyTigerGraphQuery.py index b746c643..30c3bf9a 100644 --- a/pyTigerGraph/pyTigerGraphQuery.py +++ b/pyTigerGraph/pyTigerGraphQuery.py @@ -5,6 +5,7 @@ """ import json import logging +import time from datetime import datetime from typing import TYPE_CHECKING, Union, Optional @@ -318,7 +319,7 @@ def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, list, 'pd.Dat return ret - def installQueries(self, queries: Union[str, list], flag: Union[str, list] = None) -> 
str: + def installQueries(self, queries: Union[str, list], flag: Union[str, list] = None, wait: bool = True) -> str: """Installs one or more queries. Args: @@ -326,11 +327,15 @@ def installQueries(self, queries: Union[str, list], flag: Union[str, list] = Non A single query string or a list of query strings to install. Use '*' or 'all' to install all queries. flag: Method to install queries. - - '-single' Install the query in single gpr mode. + - '-single' Install the query in single gpr mode. - '-legacy' Install the query in UDF mode. - '-debug' Present results contains debug info. - '-cost' Present results contains performance consumption. - '-force' Install the query even if it already installed. + wait: + If True, polls the installation status until the job completes before returning. + If False, returns immediately with the server response containing the requestId. + Defaults to True for sync connections. Returns: The response from the server. @@ -360,7 +365,25 @@ def installQueries(self, queries: Union[str, list], flag: Union[str, list] = Non flag = ",".join(flag) params["flag"] = flag - ret = self._req("GET", self.gsUrl + "/gsql/v1/queries/install", params=params, authMode="pwd", resKey=None) + res = self._req("GET", self.gsUrl + "/gsql/v1/queries/install", params=params, authMode="pwd", resKey=None) + + if wait: + # TG 4.1 may respond synchronously (no requestId) or asynchronously (with requestId). + # If a requestId is present, poll until the job completes. 
+ request_id = res.get("requestId") if isinstance(res, dict) else None + if request_id: + ret = None + while not ret: + ret = self._req("GET", self.gsUrl + "/gsql/v1/queries/install/" + str(request_id), authMode="pwd", resKey=None) + if "SUCCESS" in ret["message"] or "FAILED" in ret["message"]: + break + else: + ret = None + time.sleep(1) + else: + ret = res + else: + ret = res if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py index 2f3d9a5a..99d3400d 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py @@ -371,6 +371,7 @@ async def getToken(self, secret: str = None, setToken: bool = True, lifetime: in self.apiToken = token self.authHeader = auth_header self.authMode = "token" + self._token_source = "generated" logger.debug("exit: getToken") return token diff --git a/pyTigerGraph/pytgasync/pyTigerGraphBase.py b/pyTigerGraph/pytgasync/pyTigerGraphBase.py index 8ac92827..2b0eef91 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphBase.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphBase.py @@ -164,6 +164,14 @@ async def _req(self, method: str, url: str, authMode: str = "token", headers: di conn_err = e if resp is not None: + # Auto-refresh token on 401 if the token was generated by getToken() + if status == 401 and getattr(self, "_token_source", None) == "generated": + await self.getToken() + self._refresh_auth_headers() + _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) + status, body, resp = await self._do_request( + method, url, _headers, _data, jsonData, params, http_timeout) + if not skipCheck and not (200 <= status < 300) and status != 404: try: self._error_check(json.loads(body)) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py index 835b1097..3f0889e7 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py +++ 
b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py @@ -3,8 +3,8 @@ The functions on this page run installed or interpret queries in TigerGraph. All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. """ -import logging import asyncio +import logging from typing import TYPE_CHECKING, Union, Optional @@ -316,7 +316,7 @@ async def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, list, ' return ret - async def installQueries(self, queries: Union[str, list], flag: Union[str, list] = None) -> str: + async def installQueries(self, queries: Union[str, list], flag: Union[str, list] = None, wait: bool = False) -> str: """Installs one or more queries. Args: @@ -324,11 +324,15 @@ async def installQueries(self, queries: Union[str, list], flag: Union[str, list] A single query string or a list of query strings to install. Use '*' or 'all' to install all queries. flag: Method to install queries. - - '-single' Install the query in single gpr mode. + - '-single' Install the query in single gpr mode. - '-legacy' Install the query in UDF mode. - '-debug' Present results contains debug info. - '-cost' Present results contains performance consumption. - '-force' Install the query even if it already installed. + wait: + If True, polls the installation status until the job completes before returning. + If False, returns immediately with the server response containing the requestId. + Defaults to False for async connections. Returns: The response from the server. @@ -360,18 +364,21 @@ async def installQueries(self, queries: Union[str, list], flag: Union[str, list] res = await self._req("GET", self.gsUrl + "/gsql/v1/queries/install", params=params, authMode="pwd", resKey=None) - # TG 4.1 may respond synchronously (no requestId) or asynchronously (with requestId). - # If a requestId is present, poll until the job completes. 
- request_id = res.get("requestId") if isinstance(res, dict) else None - if request_id: - ret = None - while not ret: - ret = await self._req("GET", self.gsUrl + "/gsql/v1/queries/install/" + str(request_id), authMode="pwd", resKey=None) - if "SUCCESS" in ret["message"] or "FAILED" in ret["message"]: - break - else: - ret = None - await asyncio.sleep(1) + if wait: + # TG 4.1 may respond synchronously (no requestId) or asynchronously (with requestId). + # If a requestId is present, poll until the job completes. + request_id = res.get("requestId") if isinstance(res, dict) else None + if request_id: + ret = None + while not ret: + ret = await self._req("GET", self.gsUrl + "/gsql/v1/queries/install/" + str(request_id), authMode="pwd", resKey=None) + if "SUCCESS" in ret["message"] or "FAILED" in ret["message"]: + break + else: + ret = None + await asyncio.sleep(1) + else: + ret = res else: ret = res diff --git a/tests/test_common_base.py b/tests/test_common_base.py index cecc6369..7b519998 100644 --- a/tests/test_common_base.py +++ b/tests/test_common_base.py @@ -1,7 +1,7 @@ """Unit tests for pyTigerGraph.common.base (PyTigerGraphCore). These tests run without a live TigerGraph server by mocking network calls. -They guard against init-ordering bugs where _cached_token_auth is accessed +They guard against init-ordering bugs where _cached_auth is accessed before _refresh_auth_headers() has been called. """ @@ -29,27 +29,23 @@ class TestRefreshAuthHeadersOrdering(unittest.TestCase): """_refresh_auth_headers() must be called before any _get()/_req() in __init__. Regression test for GML-2041 ordering bug: - _cached_token_auth was set AFTER _verify_jwt_token_support() (and the + _cached_auth was set AFTER _verify_jwt_token_support() (and the tgCloud ping), causing AttributeError swallowed as a JWT error message. 
""" def test_cached_auth_set_with_username_password(self): conn = _make_conn() - self.assertTrue(hasattr(conn, "_cached_token_auth")) - self.assertTrue(hasattr(conn, "_cached_pwd_auth")) - self.assertIn("Basic ", conn._cached_token_auth["Authorization"]) - self.assertIn("Basic ", conn._cached_pwd_auth["Authorization"]) + self.assertTrue(hasattr(conn, "_cached_auth")) + self.assertIn("Basic ", conn._cached_auth["Authorization"]) def test_cached_auth_set_with_api_token(self): conn = _make_conn(apiToken="myapitoken123") - self.assertIn("Bearer myapitoken123", conn._cached_token_auth["Authorization"]) - self.assertIn("Basic ", conn._cached_pwd_auth["Authorization"]) + self.assertIn("Bearer myapitoken123", conn._cached_auth["Authorization"]) def test_cached_auth_set_with_jwt_token(self): """Regression: jwtToken must not cause AttributeError during __init__.""" conn = _make_conn(jwtToken="header.payload.signature") - self.assertIn("Bearer header.payload.signature", conn._cached_token_auth["Authorization"]) - self.assertIn("Bearer header.payload.signature", conn._cached_pwd_auth["Authorization"]) + self.assertIn("Bearer header.payload.signature", conn._cached_auth["Authorization"]) def test_jwt_token_calls_verify(self): """_verify_jwt_token_support() must be called when jwtToken is provided.""" @@ -70,24 +66,23 @@ def test_tgcloud_ping_does_not_crash_without_jwt(self): """tgCloud _get() ping fires before _verify_jwt_token_support; must not AttributeError.""" with patch.object(TigerGraphConnection, "_get", return_value="pong") as mock_get: conn = TigerGraphConnection(host="http://my.tgcloud.io") - # _cached_token_auth must exist at the point _get() was called - self.assertTrue(hasattr(conn, "_cached_token_auth")) + # _cached_auth must exist at the point _get() was called + self.assertTrue(hasattr(conn, "_cached_auth")) def test_tgcloud_ping_does_not_crash_with_jwt(self): - """tgCloud ping + JWT verification both fire; _cached_token_auth must precede both.""" + """tgCloud 
ping + JWT verification both fire; _cached_auth must precede both.""" with patch.object(TigerGraphConnection, "_get", return_value="pong"): with patch.object(TigerGraphConnection, "_verify_jwt_token_support"): conn = TigerGraphConnection( host="http://my.tgcloud.io", jwtToken="header.payload.signature", ) - self.assertIn("Bearer header.payload.signature", conn._cached_token_auth["Authorization"]) + self.assertIn("Bearer header.payload.signature", conn._cached_auth["Authorization"]) def test_x_user_agent_header_present(self): - """X-User-Agent must be baked into cached auth dicts.""" + """X-User-Agent must be baked into cached auth dict.""" conn = _make_conn() - self.assertEqual(conn._cached_token_auth.get("X-User-Agent"), "pyTigerGraph") - self.assertEqual(conn._cached_pwd_auth.get("X-User-Agent"), "pyTigerGraph") + self.assertEqual(conn._cached_auth.get("X-User-Agent"), "pyTigerGraph") class TestRefreshAuthHeadersUpdate(unittest.TestCase): @@ -95,21 +90,21 @@ class TestRefreshAuthHeadersUpdate(unittest.TestCase): def test_refresh_after_get_token(self): conn = _make_conn() - self.assertIn("Basic ", conn._cached_token_auth["Authorization"]) + self.assertIn("Basic ", conn._cached_auth["Authorization"]) conn.apiToken = "newtoken456" conn._refresh_auth_headers() - self.assertIn("Bearer newtoken456", conn._cached_token_auth["Authorization"]) + self.assertIn("Bearer newtoken456", conn._cached_auth["Authorization"]) def test_refresh_clears_old_token(self): conn = _make_conn(apiToken="oldtoken") - self.assertIn("Bearer oldtoken", conn._cached_token_auth["Authorization"]) + self.assertIn("Bearer oldtoken", conn._cached_auth["Authorization"]) conn.apiToken = "" conn._refresh_auth_headers() - self.assertIn("Basic ", conn._cached_token_auth["Authorization"]) + self.assertIn("Basic ", conn._cached_auth["Authorization"]) if __name__ == "__main__": From fd097aba18afaa475a3f9deaa6d8c49254237ea5 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Fri, 3 Apr 2026 15:03:39 -0700 
Subject: [PATCH 07/11] Add tests and update changelog for auth and installQueries changes --- CHANGELOG.md | 8 +- tests/test_common_base.py | 110 ++++++ tests/test_v202_changes.py | 760 +++++++++++++++++++++++++++++++++++++ 3 files changed, 877 insertions(+), 1 deletion(-) create mode 100644 tests/test_v202_changes.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a0e0838..6ac31746 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [2.0.2] - 2026-03-27 +## [2.0.2] - 2026-04-03 ### New Features @@ -15,6 +15,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Graph scope control** — `useGraph(graphName)` and `useGlobal()` methods on the connection object, mirroring GSQL's `USE GRAPH` / `USE GLOBAL`. `useGlobal()` doubles as a context manager for temporary global scoping (`with conn.useGlobal(): ...`). - **GSQL reserved keyword helpers** — `getReservedKeywords()` and `isReservedKeyword(name)` static methods to query the canonical set of GSQL reserved keywords. - **Conda build support** — `build.sh` now supports `--conda-build`, `--conda-upload`, `--conda-all`, and `--conda-forge-test` for building and validating conda packages. +- **`installQueries()` now supports a `wait` parameter** — controls whether the call blocks until installation completes. Defaults to `True` for sync and `False` for async connections. + +### Improved + +- **Token-based auth for GSQL endpoints** — GSQL endpoints (e.g. `installQueries`) now use API token or JWT when available, instead of always requiring username/password. +- **Automatic token refresh on expiration** — tokens obtained via `getToken()` are automatically refreshed when the server returns a 401; user-provided tokens raise an error instead. 
### Fixed diff --git a/tests/test_common_base.py b/tests/test_common_base.py index 7b519998..0508e952 100644 --- a/tests/test_common_base.py +++ b/tests/test_common_base.py @@ -107,5 +107,115 @@ def test_refresh_clears_old_token(self): self.assertIn("Basic ", conn._cached_auth["Authorization"]) +# ────────────────────────────────────────────────────────────────────── +# _token_source tracking +# ────────────────────────────────────────────────────────────────────── + +class TestTokenSource(unittest.TestCase): + """_token_source tracks whether the token was user-provided or generated.""" + + def test_no_token_source_is_none(self): + conn = _make_conn() + self.assertIsNone(conn._token_source) + + def test_api_token_source_is_user(self): + conn = _make_conn(apiToken="usertoken") + self.assertEqual(conn._token_source, "user") + + def test_jwt_token_source_is_user(self): + conn = _make_conn(jwtToken="header.payload.signature") + self.assertEqual(conn._token_source, "user") + + def test_get_token_sets_source_to_generated(self): + conn = _make_conn() + self.assertIsNone(conn._token_source) + + with patch.object(conn, "_token", return_value=({"token": "newtoken"}, "4")): + conn.getToken() + + self.assertEqual(conn._token_source, "generated") + + def test_get_token_overrides_user_source(self): + conn = _make_conn(apiToken="usertoken") + self.assertEqual(conn._token_source, "user") + + with patch.object(conn, "_token", return_value=({"token": "newtoken"}, "4")): + conn.getToken() + + self.assertEqual(conn._token_source, "generated") + + +# ────────────────────────────────────────────────────────────────────── +# Auto-refresh on 401 +# ────────────────────────────────────────────────────────────────────── + +class TestAutoRefreshOn401(unittest.TestCase): + """Token auto-refresh on 401 for generated tokens; error for user tokens.""" + + def _mock_response(self, status_code=200, content=b'{"results": "ok"}'): + resp = MagicMock() + resp.status_code = status_code + resp.content 
= content + resp.raise_for_status = MagicMock() + if status_code >= 400: + import requests + resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=resp) + return resp + + def test_401_with_generated_token_auto_refreshes(self): + conn = _make_conn() + conn._token_source = "generated" + conn.restppPort = "9000" + + resp_401 = self._mock_response(401, b'{"message": "token expired"}') + resp_200 = self._mock_response(200, b'{"results": "ok"}') + + with patch.object(conn, "_do_request", side_effect=[resp_401, resp_200]) as mock_do, \ + patch.object(conn, "getToken", return_value="newtoken") as mock_get_token: + result = conn._req("GET", "http://127.0.0.1:9000/query/test") + + mock_get_token.assert_called_once() + self.assertEqual(mock_do.call_count, 2) + self.assertEqual(result, "ok") + + def test_401_with_user_token_raises(self): + import requests + conn = _make_conn(apiToken="usertoken") + conn.restppPort = "9000" + + resp_401 = self._mock_response(401, b'{"message": "token expired"}') + + with patch.object(conn, "_do_request", return_value=resp_401): + with self.assertRaises(requests.exceptions.HTTPError): + conn._req("GET", "http://127.0.0.1:9000/query/test") + + def test_401_with_no_token_raises(self): + import requests + conn = _make_conn() + conn.restppPort = "9000" + + resp_401 = self._mock_response(401, b'{"message": "unauthorized"}') + + with patch.object(conn, "_do_request", return_value=resp_401): + with self.assertRaises(requests.exceptions.HTTPError): + conn._req("GET", "http://127.0.0.1:9000/query/test") + + def test_non_401_error_not_refreshed(self): + import requests + conn = _make_conn() + conn._token_source = "generated" + conn.restppPort = "9000" + + resp_500 = self._mock_response(500, b'{"message": "server error"}') + + with patch.object(conn, "_do_request", return_value=resp_500), \ + patch.object(conn, "getToken") as mock_get_token: + with self.assertRaises(requests.exceptions.HTTPError): + conn._req("GET", 
"http://127.0.0.1:9000/query/test") + + mock_get_token.assert_not_called() + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_v202_changes.py b/tests/test_v202_changes.py new file mode 100644 index 00000000..7713f0cf --- /dev/null +++ b/tests/test_v202_changes.py @@ -0,0 +1,760 @@ +"""Unit tests for pyTigerGraph 2.0.2 changes. + +These tests run without a live TigerGraph server by mocking HTTP calls. +They cover: + - createGraph() with vertexTypes/edgeTypes + - Boolean query parameter conversion (upsertEdge, upsertEdges, getGSQLVersion, rebuildGraph) + - dropVertices() graph fallback + - dropAllDataSources() graphname fallback + - getVectorIndexStatus() URL construction + - previewSampleData() graph validation + - Schema change job APIs + - runSchemaChange() with force parameter + - useGraph() / useGlobal() scope control + - Reserved keyword helpers + - installQueries() wait parameter +""" + +import json +import unittest +from unittest.mock import MagicMock, patch, call + +from pyTigerGraph import TigerGraphConnection +from pyTigerGraph.common.exception import TigerGraphException + + +def _make_conn(graphname="testgraph", **kwargs): + """Create a TigerGraphConnection without network calls.""" + defaults = dict( + host="http://127.0.0.1", + graphname=graphname, + username="tigergraph", + password="tigergraph", + ) + defaults.update(kwargs) + with patch.object(TigerGraphConnection, "_verify_jwt_token_support", return_value=None): + conn = TigerGraphConnection(**defaults) + return conn + + +def _make_conn_v4(graphname="testgraph", **kwargs): + """Create a connection that reports TigerGraph >= 4.0.""" + conn = _make_conn(graphname=graphname, **kwargs) + conn._version_greater_than_4_0 = MagicMock(return_value=True) + return conn + + +def _make_conn_v3(graphname="testgraph", **kwargs): + """Create a connection that reports TigerGraph < 4.0.""" + conn = _make_conn(graphname=graphname, **kwargs) + conn._version_greater_than_4_0 = 
MagicMock(return_value=False) + return conn + + +# ────────────────────────────────────────────────────────────────────── +# createGraph() with vertexTypes / edgeTypes +# ────────────────────────────────────────────────────────────────────── + +class TestCreateGraphWithTypes(unittest.TestCase): + + def test_create_graph_no_types_v4(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createGraph("myGraph") + args, kwargs = mock.call_args + self.assertIn("/gsql/v1/schema/graphs", args[0]) + sent_data = kwargs.get("data") or args[1] if len(args) > 1 else kwargs["data"] + self.assertEqual(sent_data["gsql"], "CREATE GRAPH myGraph()") + + def test_create_graph_with_vertex_types_v4(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createGraph("myGraph", vertexTypes=["Person", "Company"]) + sent_data = mock.call_args[1].get("data") or mock.call_args[0][1] + self.assertEqual(sent_data["gsql"], "CREATE GRAPH myGraph(Person, Company)") + + def test_create_graph_with_edge_types_v4(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createGraph("myGraph", edgeTypes=["Knows", "WorksAt"]) + sent_data = mock.call_args[1].get("data") or mock.call_args[0][1] + self.assertEqual(sent_data["gsql"], "CREATE GRAPH myGraph(Knows, WorksAt)") + + def test_create_graph_with_both_types_v4(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createGraph("myGraph", vertexTypes=["Person"], edgeTypes=["Knows"]) + sent_data = mock.call_args[1].get("data") or mock.call_args[0][1] + self.assertEqual(sent_data["gsql"], "CREATE GRAPH myGraph(Person, Knows)") + + def test_create_graph_with_wildcard_v4(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value={"error": 
False, "message": "ok"}) as mock: + conn.createGraph("myGraph", vertexTypes=["*"]) + sent_data = mock.call_args[1].get("data") or mock.call_args[0][1] + self.assertEqual(sent_data["gsql"], "CREATE GRAPH myGraph(*)") + + def test_create_graph_v4_params(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createGraph("myGraph") + _, kwargs = mock.call_args + self.assertEqual(kwargs["params"]["gsql"], "true") + self.assertEqual(kwargs["params"]["graphName"], "myGraph") + + def test_create_graph_v3_fallback(self): + conn = _make_conn_v3() + with patch.object(conn, "gsql", return_value="Graph myGraph created.") as mock: + result = conn.createGraph("myGraph", vertexTypes=["Person"], edgeTypes=["Knows"]) + mock.assert_called_once_with("CREATE GRAPH myGraph(Person, Knows)") + self.assertFalse(result["error"]) + + def test_create_graph_v3_no_types(self): + conn = _make_conn_v3() + with patch.object(conn, "gsql", return_value="Graph myGraph created.") as mock: + conn.createGraph("myGraph") + mock.assert_called_once_with("CREATE GRAPH myGraph()") + + +# ────────────────────────────────────────────────────────────────────── +# Boolean query parameter conversion +# ────────────────────────────────────────────────────────────────────── + +class TestBooleanParamConversion(unittest.TestCase): + + def test_upsert_edge_vertex_must_exist_true(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value=[{"accepted_edges": 1}]) as mock: + conn.upsertEdge("Person", "1", "Knows", "Person", "2", vertexMustExist=True) + _, kwargs = mock.call_args + self.assertEqual(kwargs.get("params", {}).get("vertex_must_exist"), "true") + + def test_upsert_edge_vertex_must_exist_false(self): + conn = _make_conn_v4() + with patch.object(conn, "_post", return_value=[{"accepted_edges": 1}]) as mock: + conn.upsertEdge("Person", "1", "Knows", "Person", "2", vertexMustExist=False) + _, kwargs = mock.call_args + 
self.assertEqual(kwargs.get("params", {}).get("vertex_must_exist"), "false") + + def test_upsert_edges_vertex_must_exist(self): + conn = _make_conn_v4() + edges = [("1", "2", {})] + with patch.object(conn, "_post", return_value=[{"accepted_edges": 1}]) as mock: + conn.upsertEdges("Person", "Knows", "Person", edges, vertexMustExist=True) + _, kwargs = mock.call_args + self.assertEqual(kwargs.get("params", {}).get("vertex_must_exist"), "true") + + def test_get_gsql_version_verbose(self): + conn = _make_conn_v4() + with patch.object(conn, "_get", return_value={"version": "4.2"}) as mock: + conn.getGSQLVersion(verbose=True) + _, kwargs = mock.call_args + self.assertEqual(kwargs.get("params", {}).get("verbose"), "true") + + def test_rebuild_graph_force(self): + conn = _make_conn_v4() + with patch.object(conn, "_get", return_value={"error": False, "message": "ok"}) as mock: + conn.rebuildGraph(force=True) + _, kwargs = mock.call_args + params = kwargs.get("params", {}) + self.assertEqual(params.get("force"), "true") + + def test_rebuild_graph_force_false_omitted(self): + """When force=False (default), the param should not be included.""" + conn = _make_conn_v4() + with patch.object(conn, "_get", return_value={"error": False, "message": "ok"}) as mock: + conn.rebuildGraph(force=False) + _, kwargs = mock.call_args + params = kwargs.get("params", {}) + self.assertNotIn("force", params) + + +# ────────────────────────────────────────────────────────────────────── +# dropVertices() graph fallback +# ────────────────────────────────────────────────────────────────────── + +class TestDropVerticesGraphFallback(unittest.TestCase): + + def test_explicit_graph_param(self): + conn = _make_conn_v4() + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropVertices("MyVertex", graph="explicitGraph") + _, kwargs = mock.call_args + self.assertEqual(kwargs["params"]["graph"], "explicitGraph") + + def test_fallback_to_self_graphname(self): + 
conn = _make_conn_v4(graphname="defaultGraph") + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropVertices("MyVertex") + _, kwargs = mock.call_args + self.assertEqual(kwargs["params"]["graph"], "defaultGraph") + + def test_no_graph_omits_param(self): + conn = _make_conn_v4(graphname="") + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropVertices("MyVertex") + _, kwargs = mock.call_args + self.assertNotIn("graph", kwargs["params"]) + + def test_list_of_vertex_names(self): + conn = _make_conn_v4() + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropVertices(["V1", "V2", "V3"]) + _, kwargs = mock.call_args + self.assertEqual(kwargs["params"]["vertex"], "V1,V2,V3") + + def test_empty_list_raises(self): + conn = _make_conn_v4() + with self.assertRaises(TigerGraphException): + conn.dropVertices([]) + + def test_invalid_type_raises(self): + conn = _make_conn_v4() + with self.assertRaises(TigerGraphException): + conn.dropVertices(123) + + def test_v3_raises(self): + conn = _make_conn_v3() + with self.assertRaises(TigerGraphException): + conn.dropVertices("MyVertex") + + def test_ignore_errors_retry_individually(self): + conn = _make_conn_v4() + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise Exception("batch fail") + vertex = kwargs.get("params", {}).get("vertex", "") + if vertex == "V2": + raise Exception("not found") + return {"error": False, "message": "ok"} + + with patch.object(conn, "_delete", side_effect=side_effect): + result = conn.dropVertices(["V1", "V2"], ignoreErrors=True) + self.assertIn("V1", result["message"]) + self.assertIn("V2", result["message"]) + + +# ────────────────────────────────────────────────────────────────────── +# dropAllDataSources() graphname fallback +# 
────────────────────────────────────────────────────────────────────── + +class TestDropAllDataSourcesGraphFallback(unittest.TestCase): + + def test_uses_explicit_graphname(self): + conn = _make_conn_v4(graphname="defaultG") + with patch.object(conn, "_req", return_value={"message": "ok"}) as mock: + conn.dropAllDataSources(graphName="explicitG") + url = mock.call_args[0][1] + self.assertIn("graph=explicitG", url) + + def test_fallback_to_self_graphname(self): + conn = _make_conn_v4(graphname="defaultG") + with patch.object(conn, "_req", return_value={"message": "ok"}) as mock: + conn.dropAllDataSources() + url = mock.call_args[0][1] + self.assertIn("graph=defaultG", url) + + def test_no_graph_no_query_param(self): + conn = _make_conn_v4(graphname="") + with patch.object(conn, "_req", return_value={"message": "ok"}) as mock: + conn.dropAllDataSources() + url = mock.call_args[0][1] + self.assertNotIn("graph=", url) + self.assertIn("/data-sources/dropAll", url) + + +# ────────────────────────────────────────────────────────────────────── +# getVectorIndexStatus() URL construction +# ────────────────────────────────────────────────────────────────────── + +class TestGetVectorIndexStatus(unittest.TestCase): + + def test_with_all_params(self): + conn = _make_conn_v4() + with patch.object(conn, "_req", return_value={"status": "ready"}) as mock: + conn.getVectorIndexStatus(graphName="g", vertexType="Person", vectorName="emb") + url = mock.call_args[0][1] + self.assertTrue(url.endswith("/vector/status/g/Person/emb")) + + def test_with_graph_only(self): + conn = _make_conn_v4() + with patch.object(conn, "_req", return_value={"status": "ready"}) as mock: + conn.getVectorIndexStatus(graphName="g") + url = mock.call_args[0][1] + self.assertTrue(url.endswith("/vector/status/g")) + + def test_fallback_to_self_graphname(self): + conn = _make_conn_v4(graphname="defaultG") + with patch.object(conn, "_req", return_value={"status": "ready"}) as mock: + conn.getVectorIndexStatus() + 
url = mock.call_args[0][1] + self.assertTrue(url.endswith("/vector/status/defaultG")) + + def test_no_graph_global_scope(self): + """Without a graph, the URL should be /vector/status (no trailing graph segment).""" + conn = _make_conn_v4(graphname="") + with patch.object(conn, "_req", return_value={"status": "ready"}) as mock: + conn.getVectorIndexStatus() + url = mock.call_args[0][1] + self.assertTrue(url.endswith("/vector/status")) + self.assertNotIn("/vector/status/", url) + + def test_vertex_type_ignored_without_graph(self): + """vertexType requires a graph segment; without graph, it's silently ignored.""" + conn = _make_conn_v4(graphname="") + with patch.object(conn, "_req", return_value={"status": "ready"}) as mock: + conn.getVectorIndexStatus(vertexType="Person") + url = mock.call_args[0][1] + self.assertNotIn("Person", url) + + +# ────────────────────────────────────────────────────────────────────── +# previewSampleData() graph validation +# ────────────────────────────────────────────────────────────────────── + +class TestPreviewSampleData(unittest.TestCase): + + def test_raises_without_graph(self): + conn = _make_conn_v4(graphname="") + with self.assertRaises(TigerGraphException) as ctx: + conn.previewSampleData("ds1", "/path/to/file.csv") + self.assertIn("graph name", str(ctx.exception).lower()) + + def test_raises_on_v3(self): + conn = _make_conn_v3() + with self.assertRaises(NotImplementedError): + conn.previewSampleData("ds1", "/path/to/file.csv") + + def test_uses_explicit_graph(self): + conn = _make_conn_v4(graphname="defaultG") + with patch.object(conn, "_req", return_value={"results": []}) as mock: + conn.previewSampleData("ds1", "/file.csv", graphName="explicitG") + data = mock.call_args[1].get("data") or mock.call_args[0][2] + self.assertEqual(data["graphName"], "explicitG") + + def test_fallback_to_self_graphname(self): + conn = _make_conn_v4(graphname="defaultG") + with patch.object(conn, "_req", return_value={"results": []}) as mock: + 
conn.previewSampleData("ds1", "/file.csv") + data = mock.call_args[1].get("data") or mock.call_args[0][2] + self.assertEqual(data["graphName"], "defaultG") + + +# ────────────────────────────────────────────────────────────────────── +# Schema Change Job APIs +# ────────────────────────────────────────────────────────────────────── + +class TestSchemaChangeJobAPIs(unittest.TestCase): + + def test_create_schema_change_job_gsql(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createSchemaChangeJob("job1", "ADD VERTEX V1 (PRIMARY_ID id UINT);") + url = mock.call_args[0][0] + self.assertIn("/gsql/v1/schema/jobs/job1", url) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["gsql"], "true") + self.assertEqual(kwargs["params"]["graph"], "g1") + + def test_create_schema_change_job_gsql_list(self): + conn = _make_conn_v4(graphname="g1") + stmts = [ + "ADD VERTEX V1 (PRIMARY_ID id UINT)", + "ADD VERTEX V2 (PRIMARY_ID id UINT)" + ] + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createSchemaChangeJob("job1", stmts) + sent_data = json.loads(mock.call_args[1].get("data", "{}")) + self.assertIn("ADD VERTEX V1", sent_data["gsql"]) + self.assertIn("ADD VERTEX V2", sent_data["gsql"]) + + def test_create_schema_change_job_global(self): + conn = _make_conn_v4(graphname="") + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createSchemaChangeJob("job1", "ADD VERTEX V1 (PRIMARY_ID id UINT);") + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["gsql"], "true") + self.assertEqual(kwargs["params"]["type"], "global") + sent_data = json.loads(kwargs.get("data", "{}")) + self.assertIn("CREATE GLOBAL SCHEMA_CHANGE JOB", sent_data["gsql"]) + + def test_create_schema_change_job_json(self): + conn = _make_conn_v4(graphname="g1") + json_body = {"some": "config"} + with 
patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.createSchemaChangeJob("job1", json_body) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["graph"], "g1") + self.assertNotIn("gsql", kwargs["params"]) + + def test_get_schema_change_jobs_single(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_get", return_value={"jobs": []}) as mock: + conn.getSchemaChangeJobs(jobName="job1") + url = mock.call_args[0][0] + self.assertIn("/gsql/v1/schema/jobs/job1", url) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["graph"], "g1") + self.assertEqual(kwargs["params"]["json"], "true") + + def test_get_schema_change_jobs_all(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_get", return_value={"jobs": []}) as mock: + conn.getSchemaChangeJobs() + url = mock.call_args[0][0] + self.assertTrue(url.endswith("/gsql/v1/schema/jobs")) + + def test_run_schema_change_job(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChangeJob("job1") + url = mock.call_args[0][0] + self.assertIn("/gsql/v1/schema/jobs/job1", url) + self.assertIn("graph=g1", url) + + def test_run_schema_change_job_with_force(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChangeJob("job1", force=True) + url = mock.call_args[0][0] + self.assertIn("force=true", url) + self.assertIn("graph=g1", url) + + def test_drop_schema_change_jobs_single(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropSchemaChangeJobs("job1") + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["jobName"], "job1") + self.assertEqual(kwargs["params"]["graph"], "g1") + + def test_drop_schema_change_jobs_list(self): + 
conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropSchemaChangeJobs(["job1", "job2"]) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["jobName"], "job1,job2") + + def test_drop_schema_change_jobs_explicit_graph(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_delete", return_value={"error": False, "message": "ok"}) as mock: + conn.dropSchemaChangeJobs("job1", graphName="otherGraph") + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["graph"], "otherGraph") + + +# ────────────────────────────────────────────────────────────────────── +# runSchemaChange() with force parameter +# ────────────────────────────────────────────────────────────────────── + +class TestRunSchemaChangeForce(unittest.TestCase): + + def test_json_path_with_force(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChange({"schema": "change"}, force=True) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["force"], "true") + self.assertEqual(kwargs["params"]["graph"], "g1") + + def test_json_path_without_force(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChange({"schema": "change"}, force=False) + kwargs = mock.call_args[1] + self.assertNotIn("force", kwargs["params"]) + + def test_json_path_raises_on_v3(self): + conn = _make_conn_v3() + with self.assertRaises(TigerGraphException): + conn.runSchemaChange({"schema": "change"}) + + def test_gsql_string_path(self): + conn = _make_conn_v4(graphname="g1") + with patch.object(conn, "gsql", return_value="ok") as mock: + conn.runSchemaChange("ADD VERTEX V1 (PRIMARY_ID id UINT);") + gsql_cmd = mock.call_args[0][0] + self.assertIn("USE GRAPH g1", gsql_cmd) + self.assertIn("CREATE 
SCHEMA_CHANGE JOB", gsql_cmd) + self.assertIn("RUN SCHEMA_CHANGE JOB", gsql_cmd) + self.assertIn("DROP JOB", gsql_cmd) + + def test_gsql_list_path(self): + conn = _make_conn_v4(graphname="g1") + stmts = ["ADD VERTEX V1 (PRIMARY_ID id UINT)", "ADD VERTEX V2 (PRIMARY_ID id UINT)"] + with patch.object(conn, "gsql", return_value="ok") as mock: + conn.runSchemaChange(stmts) + gsql_cmd = mock.call_args[0][0] + self.assertIn("ADD VERTEX V1", gsql_cmd) + self.assertIn("ADD VERTEX V2", gsql_cmd) + + def test_gsql_global_scope(self): + conn = _make_conn_v4(graphname="") + with patch.object(conn, "gsql", return_value="ok") as mock: + conn.runSchemaChange("ADD VERTEX V1 (PRIMARY_ID id UINT);") + gsql_cmd = mock.call_args[0][0] + self.assertIn("CREATE GLOBAL SCHEMA_CHANGE JOB", gsql_cmd) + self.assertIn("RUN GLOBAL SCHEMA_CHANGE JOB", gsql_cmd) + self.assertNotIn("USE GRAPH", gsql_cmd) + + +# ────────────────────────────────────────────────────────────────────── +# useGraph() / useGlobal() scope control +# ────────────────────────────────────────────────────────────────────── + +class TestGraphScopeControl(unittest.TestCase): + + def test_use_graph(self): + conn = _make_conn(graphname="original") + conn.useGraph("newGraph") + self.assertEqual(conn.graphname, "newGraph") + + def test_use_graph_empty_delegates_to_global(self): + conn = _make_conn(graphname="original") + conn.useGraph("") + self.assertEqual(conn.graphname, "") + + def test_use_global(self): + conn = _make_conn(graphname="original") + conn.useGlobal() + self.assertEqual(conn.graphname, "") + + def test_use_global_context_manager(self): + conn = _make_conn(graphname="original") + with conn.useGlobal(): + self.assertEqual(conn.graphname, "") + self.assertEqual(conn.graphname, "original") + + def test_use_global_context_manager_restores_on_exception(self): + conn = _make_conn(graphname="original") + try: + with conn.useGlobal(): + self.assertEqual(conn.graphname, "") + raise ValueError("test error") + except 
ValueError: + pass + self.assertEqual(conn.graphname, "original") + + def test_use_global_context_manager_nested(self): + conn = _make_conn(graphname="original") + with conn.useGlobal(): + self.assertEqual(conn.graphname, "") + conn.useGraph("inner") + self.assertEqual(conn.graphname, "inner") + self.assertEqual(conn.graphname, "original") + + +# ────────────────────────────────────────────────────────────────────── +# Reserved keyword helpers +# ────────────────────────────────────────────────────────────────────── + +class TestReservedKeywords(unittest.TestCase): + + def test_get_reserved_keywords_returns_frozenset(self): + kw = TigerGraphConnection.getReservedKeywords() + self.assertIsInstance(kw, frozenset) + + def test_get_reserved_keywords_not_empty(self): + kw = TigerGraphConnection.getReservedKeywords() + self.assertGreater(len(kw), 50) + + def test_known_keywords_present(self): + kw = TigerGraphConnection.getReservedKeywords() + for word in ["SELECT", "CREATE", "DROP", "VERTEX", "EDGE", "GRAPH", + "FROM", "WHERE", "AND", "OR", "NOT", "INT", "STRING", + "BOOL", "FLOAT", "DOUBLE", "UINT", "PRIMARY_ID"]: + self.assertIn(word, kw, f"{word} should be a reserved keyword") + + def test_is_reserved_keyword_true(self): + self.assertTrue(TigerGraphConnection.isReservedKeyword("SELECT")) + self.assertTrue(TigerGraphConnection.isReservedKeyword("VERTEX")) + + def test_is_reserved_keyword_case_insensitive(self): + self.assertTrue(TigerGraphConnection.isReservedKeyword("select")) + self.assertTrue(TigerGraphConnection.isReservedKeyword("Select")) + self.assertTrue(TigerGraphConnection.isReservedKeyword("VERTEX")) + self.assertTrue(TigerGraphConnection.isReservedKeyword("vertex")) + + def test_is_reserved_keyword_false(self): + self.assertFalse(TigerGraphConnection.isReservedKeyword("myCustomName")) + self.assertFalse(TigerGraphConnection.isReservedKeyword("foobar")) + self.assertFalse(TigerGraphConnection.isReservedKeyword("")) + + +# 
────────────────────────────────────────────────────────────────────── +# _wrap_gsql_result helper +# ────────────────────────────────────────────────────────────────────── + +class TestWrapGsqlResult(unittest.TestCase): + + def test_success_result(self): + from pyTigerGraph.common.gsql import _wrap_gsql_result + result = _wrap_gsql_result("Graph g1 created successfully.") + self.assertFalse(result["error"]) + self.assertEqual(result["message"], "Graph g1 created successfully.") + + def test_error_result_raises(self): + from pyTigerGraph.common.gsql import _wrap_gsql_result + with self.assertRaises(TigerGraphException): + _wrap_gsql_result("Semantic Check Fails: vertex type does not exist") + + def test_error_result_skip_check(self): + from pyTigerGraph.common.gsql import _wrap_gsql_result + result = _wrap_gsql_result( + "Semantic Check Fails: vertex type does not exist", skipCheck=True + ) + self.assertTrue(result["error"]) + + def test_none_result(self): + from pyTigerGraph.common.gsql import _wrap_gsql_result + result = _wrap_gsql_result(None) + self.assertFalse(result["error"]) + self.assertEqual(result["message"], "") + + +# ────────────────────────────────────────────────────────────────────── +# _parse_graph_list helper +# ────────────────────────────────────────────────────────────────────── + +class TestParseGraphList(unittest.TestCase): + + def test_parses_typed_entries(self): + from pyTigerGraph.common.gsql import _parse_graph_list + output = ( + " - Graph g1(Person:v, Company:v, Knows:e, WorksAt:e)\n" + " - Graph g2(V1:v)\n" + ) + result = _parse_graph_list(output) + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["GraphName"], "g1") + self.assertEqual(sorted(result[0]["VertexTypes"]), ["Company", "Person"]) + self.assertEqual(sorted(result[0]["EdgeTypes"]), ["Knows", "WorksAt"]) + self.assertEqual(result[1]["GraphName"], "g2") + self.assertEqual(result[1]["VertexTypes"], ["V1"]) + self.assertEqual(result[1]["EdgeTypes"], []) + + def 
test_empty_output(self): + from pyTigerGraph.common.gsql import _parse_graph_list + result = _parse_graph_list("") + self.assertEqual(result, []) + + def test_none_output(self): + from pyTigerGraph.common.gsql import _parse_graph_list + result = _parse_graph_list(None) + self.assertEqual(result, []) + + +# ────────────────────────────────────────────────────────────────────── +# installQueries() wait parameter +# ────────────────────────────────────────────────────────────────────── + +def _make_conn_v41(graphname="testgraph", **kwargs): + """Create a connection that reports TigerGraph 4.1.0.""" + conn = _make_conn(graphname=graphname, **kwargs) + conn.getVer = MagicMock(return_value="4.1.0") + conn.ver = "4.1.0" + return conn + + +class TestInstallQueriesWaitParam(unittest.TestCase): + + def test_wait_true_polls_until_success(self): + """With wait=True, installQueries polls until SUCCESS.""" + conn = _make_conn_v41() + install_response = {"requestId": "req123", "message": "submitted"} + pending_response = {"requestId": "req123", "message": "RUNNING"} + success_response = {"requestId": "req123", "message": "SUCCESS"} + + with patch.object(conn, "_req", side_effect=[ + install_response, pending_response, success_response + ]) as mock_req, \ + patch("pyTigerGraph.pyTigerGraphQuery.time.sleep"): + result = conn.installQueries("my_query", wait=True) + + self.assertEqual(result, success_response) + # 1 install call + 2 status poll calls + self.assertEqual(mock_req.call_count, 3) + + def test_wait_true_polls_until_failed(self): + """With wait=True, installQueries stops polling on FAILED.""" + conn = _make_conn_v41() + install_response = {"requestId": "req123", "message": "submitted"} + failed_response = {"requestId": "req123", "message": "FAILED: compile error"} + + with patch.object(conn, "_req", side_effect=[ + install_response, failed_response + ]), \ + patch("pyTigerGraph.pyTigerGraphQuery.time.sleep"): + result = conn.installQueries("my_query", wait=True) + + 
self.assertEqual(result, failed_response) + + def test_wait_false_returns_immediately(self): + """With wait=False, installQueries returns the initial response.""" + conn = _make_conn_v41() + install_response = {"requestId": "req123", "message": "submitted"} + + with patch.object(conn, "_req", return_value=install_response) as mock_req: + result = conn.installQueries("my_query", wait=False) + + self.assertEqual(result, install_response) + mock_req.assert_called_once() + + def test_wait_true_no_request_id_returns_directly(self): + """With wait=True, if no requestId in response, returns immediately.""" + conn = _make_conn_v41() + sync_response = {"message": "SUCCESS"} + + with patch.object(conn, "_req", return_value=sync_response) as mock_req: + result = conn.installQueries("my_query", wait=True) + + self.assertEqual(result, sync_response) + mock_req.assert_called_once() + + def test_sync_default_wait_is_true(self): + """Sync installQueries defaults to wait=True.""" + import inspect + from pyTigerGraph.pyTigerGraphQuery import pyTigerGraphQuery + sig = inspect.signature(pyTigerGraphQuery.installQueries) + self.assertTrue(sig.parameters["wait"].default) + + def test_async_default_wait_is_false(self): + """Async installQueries defaults to wait=False.""" + import inspect + from pyTigerGraph.pytgasync.pyTigerGraphQuery import AsyncPyTigerGraphQuery + sig = inspect.signature(AsyncPyTigerGraphQuery.installQueries) + self.assertFalse(sig.parameters["wait"].default) + + def test_queries_list_joined(self): + """A list of query names is joined with commas.""" + conn = _make_conn_v41() + with patch.object(conn, "_req", return_value={"message": "SUCCESS"}) as mock_req: + conn.installQueries(["q1", "q2", "q3"], wait=False) + + _, kwargs = mock_req.call_args + self.assertEqual(kwargs["params"]["queries"], "q1,q2,q3") + + def test_flag_list_joined(self): + """A list of flags is joined with commas.""" + conn = _make_conn_v41() + with patch.object(conn, "_req", 
return_value={"message": "SUCCESS"}) as mock_req: + conn.installQueries("q1", flag=["-single", "-force"], wait=False) + + _, kwargs = mock_req.call_args + self.assertEqual(kwargs["params"]["flag"], "-single,-force") + + def test_rejects_version_below_4_1(self): + """installQueries raises on TigerGraph < 4.1.""" + conn = _make_conn(graphname="testgraph") + conn.getVer = MagicMock(return_value="4.0.0") + conn.ver = "4.0.0" + + with self.assertRaises(TigerGraphException): + conn.installQueries("my_query") + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 93c17d3ece710b34960a806412adc9b886d381b6 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Fri, 3 Apr 2026 15:54:04 -0700 Subject: [PATCH 08/11] Fix polling robustness, 401 recursion guard, and useGlobal context manager - installQueries polling: use ret.get() to avoid KeyError, increase sleep to 10s, add 1-hour timeout with TigerGraphException - 401 auto-refresh: add _refreshing_token guard to prevent infinite recursion when getToken() itself triggers a 401 - useGlobal context manager: re-capture graphname at __enter__ time so deferred use restores correctly --- CHANGELOG.md | 2 + pyTigerGraph/common/base.py | 3 + pyTigerGraph/pyTigerGraphBase.py | 15 ++-- pyTigerGraph/pyTigerGraphQuery.py | 14 ++-- pyTigerGraph/pytgasync/pyTigerGraphBase.py | 17 ++-- pyTigerGraph/pytgasync/pyTigerGraphQuery.py | 14 ++-- tests/test_common_base.py | 22 ++++++ tests/test_v202_changes.py | 86 +++++++++++++++++++++ 8 files changed, 152 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ac31746..446cecfa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **`installQueries()` polling robustness** — handles missing `message` key in server responses and times out after 1 hour instead of hanging indefinitely. 
+- **`useGlobal()` context manager** — now correctly restores the graph name when used as a deferred context manager. - **`_refresh_auth_headers()` called earlier in `__init__`** — auth header cache is now built immediately after credentials are set. - **tgCloud auto-detection simplified** — removed the HTTP ping to `/api/ping`; detection now relies solely on the hostname containing `"tgcloud"`. - **`threading.local()` init ordering** — `self._local` and `self._restpp_failover_lock` are now created before `super().__init__()` in `pyTigerGraphBase`. diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index 1ae46300..2e662790 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -248,6 +248,9 @@ def __init__(self, conn): self._saved = conn.graphname def __enter__(self): + # Re-capture if graphname changed since useGlobal() was called + if self._conn.graphname != "": + self._saved = self._conn.graphname self._conn.graphname = "" return self._conn diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index d1e3c982..ade6d415 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -222,11 +222,16 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N if res is not None: # Auto-refresh token on 401 if the token was generated by getToken() - if res.status_code == 401 and getattr(self, "_token_source", None) == "generated": - self.getToken() - self._refresh_auth_headers() - _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) - res = self._do_request(method, url, _headers, _data, jsonData, params, http_timeout) + if res.status_code == 401 and getattr(self, "_token_source", None) == "generated" \ + and not getattr(self, "_refreshing_token", False): + try: + self._refreshing_token = True + self.getToken() + self._refresh_auth_headers() + _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) + res = 
self._do_request(method, url, _headers, _data, jsonData, params, http_timeout) + finally: + self._refreshing_token = False if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: try: diff --git a/pyTigerGraph/pyTigerGraphQuery.py b/pyTigerGraph/pyTigerGraphQuery.py index 30c3bf9a..fa82d6ae 100644 --- a/pyTigerGraph/pyTigerGraphQuery.py +++ b/pyTigerGraph/pyTigerGraphQuery.py @@ -372,14 +372,18 @@ def installQueries(self, queries: Union[str, list], flag: Union[str, list] = Non # If a requestId is present, poll until the job completes. request_id = res.get("requestId") if isinstance(res, dict) else None if request_id: + max_retries = 360 # 1 hour with 10s sleep ret = None - while not ret: + for _ in range(max_retries): ret = self._req("GET", self.gsUrl + "/gsql/v1/queries/install/" + str(request_id), authMode="pwd", resKey=None) - if "SUCCESS" in ret["message"] or "FAILED" in ret["message"]: + msg = ret.get("message", "") if isinstance(ret, dict) else "" + if "SUCCESS" in msg or "FAILED" in msg: break - else: - ret = None - time.sleep(1) + ret = None + time.sleep(10) + else: + raise TigerGraphException( + "Query installation timed out after polling for 1 hour.", 0) else: ret = res else: diff --git a/pyTigerGraph/pytgasync/pyTigerGraphBase.py b/pyTigerGraph/pytgasync/pyTigerGraphBase.py index 2b0eef91..634933ba 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphBase.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphBase.py @@ -165,12 +165,17 @@ async def _req(self, method: str, url: str, authMode: str = "token", headers: di if resp is not None: # Auto-refresh token on 401 if the token was generated by getToken() - if status == 401 and getattr(self, "_token_source", None) == "generated": - await self.getToken() - self._refresh_auth_headers() - _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) - status, body, resp = await self._do_request( - method, url, _headers, _data, jsonData, params, http_timeout) + if status == 401 and 
getattr(self, "_token_source", None) == "generated" \ + and not getattr(self, "_refreshing_token", False): + try: + self._refreshing_token = True + await self.getToken() + self._refresh_auth_headers() + _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) + status, body, resp = await self._do_request( + method, url, _headers, _data, jsonData, params, http_timeout) + finally: + self._refreshing_token = False if not skipCheck and not (200 <= status < 300) and status != 404: try: diff --git a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py index 3f0889e7..0291597f 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py @@ -369,14 +369,18 @@ async def installQueries(self, queries: Union[str, list], flag: Union[str, list] # If a requestId is present, poll until the job completes. request_id = res.get("requestId") if isinstance(res, dict) else None if request_id: + max_retries = 360 # 1 hour with 10s sleep ret = None - while not ret: + for _ in range(max_retries): ret = await self._req("GET", self.gsUrl + "/gsql/v1/queries/install/" + str(request_id), authMode="pwd", resKey=None) - if "SUCCESS" in ret["message"] or "FAILED" in ret["message"]: + msg = ret.get("message", "") if isinstance(ret, dict) else "" + if "SUCCESS" in msg or "FAILED" in msg: break - else: - ret = None - await asyncio.sleep(1) + ret = None + await asyncio.sleep(10) + else: + raise TigerGraphException( + "Query installation timed out after polling for 1 hour.", 0) else: ret = res else: diff --git a/tests/test_common_base.py b/tests/test_common_base.py index 0508e952..2dce5715 100644 --- a/tests/test_common_base.py +++ b/tests/test_common_base.py @@ -216,6 +216,28 @@ def test_non_401_error_not_refreshed(self): mock_get_token.assert_not_called() + def test_401_refresh_does_not_recurse(self): + """If getToken() itself triggers a 401, it must not recurse infinitely.""" + import requests + conn 
= _make_conn() + conn._token_source = "generated" + conn.restppPort = "9000" + + resp_401 = self._mock_response(401, b'{"message": "token expired"}') + + def fake_get_token(*a, **kw): + # Simulate getToken calling _req which also gets 401. + # The guard flag should prevent recursion. + raise requests.exceptions.HTTPError(response=resp_401) + + with patch.object(conn, "_do_request", return_value=resp_401), \ + patch.object(conn, "getToken", side_effect=fake_get_token): + with self.assertRaises(requests.exceptions.HTTPError): + conn._req("GET", "http://127.0.0.1:9000/query/test") + + # Guard flag must be cleared after the failed refresh + self.assertFalse(getattr(conn, "_refreshing_token", False)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_v202_changes.py b/tests/test_v202_changes.py index 7713f0cf..a6a4aef7 100644 --- a/tests/test_v202_changes.py +++ b/tests/test_v202_changes.py @@ -755,6 +755,92 @@ def test_rejects_version_below_4_1(self): with self.assertRaises(TigerGraphException): conn.installQueries("my_query") + def test_wait_true_times_out(self): + """Polling raises TigerGraphException after max retries.""" + conn = _make_conn_v41() + install_response = {"requestId": "req123", "message": "submitted"} + running_response = {"requestId": "req123", "message": "RUNNING"} + + with patch.object(conn, "_req", side_effect=[install_response] + [running_response] * 360), \ + patch("pyTigerGraph.pyTigerGraphQuery.time.sleep"): + with self.assertRaises(TigerGraphException) as ctx: + conn.installQueries("my_query", wait=True) + self.assertIn("timed out", str(ctx.exception).lower()) + + def test_wait_true_handles_missing_message(self): + """Polling handles responses without 'message' key gracefully.""" + conn = _make_conn_v41() + install_response = {"requestId": "req123", "message": "submitted"} + no_message_response = {"requestId": "req123", "status": "unknown"} + success_response = {"requestId": "req123", "message": "SUCCESS"} + + with 
patch.object(conn, "_req", side_effect=[ + install_response, no_message_response, success_response + ]), \ + patch("pyTigerGraph.pyTigerGraphQuery.time.sleep"): + result = conn.installQueries("my_query", wait=True) + + self.assertEqual(result, success_response) + + +# ────────────────────────────────────────────────────────────────────── +# _GlobalScope context manager +# ────────────────────────────────────────────────────────────────────── + +class TestGlobalScopeDeferredUse(unittest.TestCase): + + def test_deferred_context_manager_restores_correctly(self): + """Deferred use of context manager saves graphname at __enter__ time.""" + conn = _make_conn(graphname="original") + scope = conn.useGlobal() # graphname -> "" + conn.useGraph("changed") # graphname -> "changed" + with scope: + self.assertEqual(conn.graphname, "") + # Should restore "changed" (captured at __enter__), not "original" + self.assertEqual(conn.graphname, "changed") + + +# ────────────────────────────────────────────────────────────────────── +# runSchemaChangeJob URL encoding +# ────────────────────────────────────────────────────────────────────── + +class TestRunSchemaChangeJobUrl(unittest.TestCase): + + def test_url_includes_graph_name(self): + conn = _make_conn_v4(graphname="myGraph") + with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChangeJob("job1") + url = mock.call_args[0][0] + self.assertIn("graph=myGraph", url) + + def test_url_includes_force(self): + conn = _make_conn_v4(graphname="myGraph") + with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChangeJob("job1", force=True) + url = mock.call_args[0][0] + self.assertIn("force=true", url) + + +# ────────────────────────────────────────────────────────────────────── +# createSchemaChangeJob Content-Type +# ────────────────────────────────────────────────────────────────────── + +class 
TestCreateSchemaChangeJobContentType(unittest.TestCase):
+
+    def test_gsql_path_sends_text_plain_content_type(self):
+        conn = _make_conn_v4(graphname="g1")
+        with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock:
+            conn.createSchemaChangeJob("job1", "ADD VERTEX V1 (PRIMARY_ID id UINT);")
+            kwargs = mock.call_args[1]
+            self.assertEqual(kwargs["headers"]["Content-Type"], "text/plain")
+
+    def test_dict_path_sends_json_content_type(self):
+        conn = _make_conn_v4(graphname="g1")
+        with patch.object(conn, "_post", return_value={"error": False, "message": "ok"}) as mock:
+            conn.createSchemaChangeJob("job1", {"some": "config"})
+            kwargs = mock.call_args[1]
+            self.assertEqual(kwargs["headers"]["Content-Type"], "application/json")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

From e57c968f70e52d81560840d7a16b64a625f220d6 Mon Sep 17 00:00:00 2001
From: Chengbiao Jin
Date: Fri, 3 Apr 2026 16:08:53 -0700
Subject: [PATCH 09/11] Update conda recipe version to 2.0.2

---
 pytigergraph-recipe/recipe/meta.yaml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytigergraph-recipe/recipe/meta.yaml b/pytigergraph-recipe/recipe/meta.yaml
index 28854685..7b877372 100644
--- a/pytigergraph-recipe/recipe/meta.yaml
+++ b/pytigergraph-recipe/recipe/meta.yaml
@@ -2,11 +2,11 @@
 
 package:
   name: pytigergraph
-  version: "2.0.1"
+  version: "2.0.2"
 
 source:
-  url: https://pypi.org/packages/source/p/pytigergraph/pytigergraph-2.0.1.tar.gz
-  sha256: 6c0834a7abdacf4b00c41e3603c397954a50825677ddadde7a03086eaae479f4
+  url: https://pypi.org/packages/source/p/pytigergraph/pytigergraph-2.0.2.tar.gz
+  # TODO: add the sha256 checksum after publishing to PyPI
 
 build:
   number: 0

From b23f5dc93f866e2257508b6635583a1a387d5aa4 Mon Sep 17 00:00:00 2001
From: Chengbiao Jin
Date: Fri, 3 Apr 2026 16:30:34 -0700
Subject: [PATCH 10/11] Add params support to _put and use it in runSchemaChangeJob

- Add params argument to _put (sync and async) to match _get/_post/_delete
- 
runSchemaChangeJob now passes query params via _put(params=...) instead of manual query string construction --- pyTigerGraph/pyTigerGraphBase.py | 11 +++++-- pyTigerGraph/pyTigerGraphSchema.py | 10 +++---- pyTigerGraph/pytgasync/pyTigerGraphBase.py | 11 +++++-- pyTigerGraph/pytgasync/pyTigerGraphSchema.py | 10 +++---- tests/test_v202_changes.py | 30 +++++++++++++------- 5 files changed, 45 insertions(+), 27 deletions(-) diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index ade6d415..8ffc68d1 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -362,7 +362,8 @@ def _post(self, url: str, authMode: str = "token", headers: dict = None, return res - def _put(self, url: str, authMode: str = "token", data=None, resKey=None, jsonData=False) -> Union[dict, list]: + def _put(self, url: str, authMode: str = "token", data=None, resKey=None, + jsonData=False, params: Union[dict, list, str] = None) -> Union[dict, list]: """Generic PUT method. Args: @@ -370,6 +371,12 @@ def _put(self, url: str, authMode: str = "token", data=None, resKey=None, jsonDa Complete REST++ API URL including path and parameters. authMode: Authentication mode, either `"token"` (default) or `"pwd"`. + data: + Request payload, typically a JSON document. + resKey: + The JSON subdocument to be returned, default is `None`. + params: + Request URL parameters. Returns: The response from the request (as a dictionary). 
@@ -379,7 +386,7 @@ def _put(self, url: str, authMode: str = "token", data=None, resKey=None, jsonDa logger.debug("params: " + self._locals(locals())) res = self._req("PUT", url, authMode, data=data, - resKey=resKey, jsonData=jsonData) + resKey=resKey, jsonData=jsonData, params=params) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) diff --git a/pyTigerGraph/pyTigerGraphSchema.py b/pyTigerGraph/pyTigerGraphSchema.py index c3fb63d9..12ab9da3 100644 --- a/pyTigerGraph/pyTigerGraphSchema.py +++ b/pyTigerGraph/pyTigerGraphSchema.py @@ -830,17 +830,15 @@ def runSchemaChangeJob(self, jobName: str, graphName: str = None, gname = graphName or self.graphname - query_parts = [] + params = {} if gname: - query_parts.append(f"graph={gname}") + params["graph"] = gname if force: - query_parts.append("force=true") + params["force"] = "true" url = self.gsUrl + "/gsql/v1/schema/jobs/" + jobName - if query_parts: - url += "?" + "&".join(query_parts) - res = self._put(url, authMode="pwd", resKey=None) + res = self._put(url, authMode="pwd", resKey=None, params=params) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphBase.py b/pyTigerGraph/pytgasync/pyTigerGraphBase.py index 634933ba..fdf065f2 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphBase.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphBase.py @@ -303,7 +303,8 @@ async def _post(self, url: str, authMode: str = "token", headers: dict = None, return res - async def _put(self, url: str, authMode: str = "token", data=None, resKey=None, jsonData=False) -> Union[dict, list]: + async def _put(self, url: str, authMode: str = "token", data=None, resKey=None, + jsonData=False, params: Union[dict, list, str] = None) -> Union[dict, list]: """Generic PUT method. Args: @@ -311,6 +312,12 @@ async def _put(self, url: str, authMode: str = "token", data=None, resKey=None, Complete REST++ API URL including path and parameters. 
authMode: Authentication mode, either `"token"` (default) or `"pwd"`. + data: + Request payload, typically a JSON document. + resKey: + The JSON subdocument to be returned, default is `None`. + params: + Request URL parameters. Returns: The response from the request (as a dictionary). @@ -319,7 +326,7 @@ async def _put(self, url: str, authMode: str = "token", data=None, resKey=None, if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - res = await self._req("PUT", url, authMode, data=data, resKey=resKey, jsonData=jsonData) + res = await self._req("PUT", url, authMode, data=data, resKey=resKey, jsonData=jsonData, params=params) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py index 7c4f2640..7370c63c 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py @@ -832,17 +832,15 @@ async def runSchemaChangeJob(self, jobName: str, graphName: str = None, gname = graphName or self.graphname - query_parts = [] + params = {} if gname: - query_parts.append(f"graph={gname}") + params["graph"] = gname if force: - query_parts.append("force=true") + params["force"] = "true" url = self.gsUrl + "/gsql/v1/schema/jobs/" + jobName - if query_parts: - url += "?" 
+ "&".join(query_parts) - res = await self._put(url, authMode="pwd", resKey=None) + res = await self._put(url, authMode="pwd", resKey=None, params=params) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) diff --git a/tests/test_v202_changes.py b/tests/test_v202_changes.py index a6a4aef7..5009ece3 100644 --- a/tests/test_v202_changes.py +++ b/tests/test_v202_changes.py @@ -413,15 +413,16 @@ def test_run_schema_change_job(self): conn.runSchemaChangeJob("job1") url = mock.call_args[0][0] self.assertIn("/gsql/v1/schema/jobs/job1", url) - self.assertIn("graph=g1", url) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["graph"], "g1") def test_run_schema_change_job_with_force(self): conn = _make_conn_v4(graphname="g1") with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: conn.runSchemaChangeJob("job1", force=True) - url = mock.call_args[0][0] - self.assertIn("force=true", url) - self.assertIn("graph=g1", url) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["force"], "true") + self.assertEqual(kwargs["params"]["graph"], "g1") def test_drop_schema_change_jobs_single(self): conn = _make_conn_v4(graphname="g1") @@ -804,21 +805,28 @@ def test_deferred_context_manager_restores_correctly(self): # runSchemaChangeJob URL encoding # ────────────────────────────────────────────────────────────────────── -class TestRunSchemaChangeJobUrl(unittest.TestCase): +class TestRunSchemaChangeJobParams(unittest.TestCase): - def test_url_includes_graph_name(self): + def test_passes_graph_in_params(self): conn = _make_conn_v4(graphname="myGraph") with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: conn.runSchemaChangeJob("job1") - url = mock.call_args[0][0] - self.assertIn("graph=myGraph", url) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["graph"], "myGraph") - def test_url_includes_force(self): + def test_passes_force_in_params(self): conn = 
_make_conn_v4(graphname="myGraph") with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: conn.runSchemaChangeJob("job1", force=True) - url = mock.call_args[0][0] - self.assertIn("force=true", url) + kwargs = mock.call_args[1] + self.assertEqual(kwargs["params"]["force"], "true") + + def test_no_graph_omits_param(self): + conn = _make_conn_v4(graphname="") + with patch.object(conn, "_put", return_value={"error": False, "message": "ok"}) as mock: + conn.runSchemaChangeJob("job1") + kwargs = mock.call_args[1] + self.assertNotIn("graph", kwargs["params"]) # ────────────────────────────────────────────────────────────────────── From 1a43108bc65320370e0c68ace1e91827763f5999 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Fri, 3 Apr 2026 16:37:59 -0700 Subject: [PATCH 11/11] update build.sh to auto-update conda recipe --- build.sh | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/build.sh b/build.sh index f9065979..416a08cf 100755 --- a/build.sh +++ b/build.sh @@ -76,6 +76,32 @@ if $DO_UPLOAD; then echo "---- Uploading to PyPI ----" python3 -m twine upload dist/* + + # Update conda recipe meta.yaml with the new version and sha256 + PKG_VERSION=$(grep "^version" pyproject.toml | awk -F'"' '{print $2}') + TARBALL_URL="https://pypi.org/packages/source/p/$PYPI_PACKAGE/$PYPI_PACKAGE-$PKG_VERSION.tar.gz" + + echo "---- Updating conda recipe to $PKG_VERSION ----" + # Wait briefly for PyPI to make the tarball available + for i in $(seq 1 30); do + HTTP_CODE=$(curl -sL -o /dev/null -w "%{http_code}" "$TARBALL_URL") + if [[ "$HTTP_CODE" == "200" ]]; then + break + fi + echo " Waiting for PyPI tarball to become available... ($i/30)" + sleep 5 + done + if [[ "$HTTP_CODE" != "200" ]]; then + echo "Warning: could not fetch tarball from PyPI. Update meta.yaml manually." 
>&2 + else + NEW_SHA=$(curl -sL "$TARBALL_URL" | sha256sum | awk '{print $1}') + sed -i.bak "s|^ version:.*| version: \"$PKG_VERSION\"|" "$RECIPE_DIR/meta.yaml" + sed -i.bak "s|^ url:.*| url: $TARBALL_URL|" "$RECIPE_DIR/meta.yaml" + sed -i.bak "s|^ sha256:.*| sha256: $NEW_SHA|" "$RECIPE_DIR/meta.yaml" + sed -i.bak "s|^ # sha256:.*| sha256: $NEW_SHA|" "$RECIPE_DIR/meta.yaml" + rm -f "$RECIPE_DIR/meta.yaml.bak" + echo " ✓ Updated meta.yaml: version=$PKG_VERSION sha256=$NEW_SHA" + fi fi # ── Conda ───────────────────────────────────────────────────────────────────