diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index fad1f27b1..c1bcf7c99 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -208,14 +208,19 @@ def _validate_column(col: str) -> str: name, type_name = parts validate_identifier(name, "Column name") validate_identifier(type_name, "Column type") - return f"{name} {type_name}" + # Only the column name is double-quoted. The type name is left + # unquoted so PostgreSQL applies its default identifier folding + # for type names in column definitions. Double-quoting would + # make the type name case-sensitive and could change type + # resolution in surprising ways for user-defined types. + return f'"{name}" {type_name}' else: validate_identifier(col, "Column name") - return f"{col} agtype" + return f'"{col}" agtype' def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str: - if graphName == None: + if graphName is None: raise _EXCEPTION_GraphNotSet columnExp=[] @@ -225,16 +230,18 @@ def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str: if validated: columnExp.append(validated) else: - columnExp.append('v agtype') + columnExp.append('"v" agtype') # Design note: String concatenation is used here instead of # psycopg.sql.Identifier() because column specifications are - # "name type" pairs (e.g. "v agtype") that don't map directly to + # "name type" pairs (e.g. '"v" agtype') that don't map directly to # sql.Identifier(). Each component has already been validated by # _validate_column() → validate_identifier(), which restricts - # names to ^[A-Za-z_][A-Za-z0-9_]*$ and max 63 chars. The - # graphName and cypherStmt are NOT embedded here — this template - # only contains the validated column list and static SQL keywords. + # names to ^[A-Za-z_][A-Za-z0-9_]*$ and max 63 chars. Column names + # are always double-quoted to avoid conflicts with PostgreSQL + # reserved words (e.g. "count", "order", "type"). The graphName + # and cypherStmt are NOT embedded here — this template only + # contains the validated column list and static SQL keywords. stmtArr = [] stmtArr.append("SELECT * from cypher(NULL,NULL) as (") stmtArr.append(','.join(columnExp)) diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index f904fb9e3..77bde2b9d 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -13,11 +13,17 @@ # specific language governing permissions and limitations # under the License. import json +import re from age.models import Vertex import unittest import decimal import age +# _validate_column is private but tested directly because its quoting +# behavior is security-relevant and the public surface (buildCypher) +# makes it difficult to isolate quoting assertions. +from age.age import buildCypher, _validate_column +from age.exceptions import InvalidIdentifier import argparse TEST_HOST = "localhost" @@ -28,6 +34,72 @@ TEST_GRAPH_NAME = "test_graph" +class TestBuildCypher(unittest.TestCase): + """Unit tests for buildCypher() and _validate_column() — no DB required.""" + + def test_simple_column(self): + result = buildCypher("g", "MATCH (n) RETURN n", ["n"]) + self.assertIn('"n" agtype', result) + + def test_column_with_type(self): + result = buildCypher("g", "MATCH (n) RETURN n", ["n agtype"]) + self.assertIn('"n" agtype', result) + + def test_reserved_word_count(self): + """Issue #2370: 'count' is a PostgreSQL reserved word.""" + result = buildCypher("g", "MATCH (n) RETURN count(n)", ["count"]) + self.assertIn('"count" agtype', result) + # Verify 'count' never appears unquoted as a column name + self.assertIsNone( + re.search(r'(?() RETURN type(r)", ["type"]) + self.assertIn('"type" agtype', result) + + def test_reserved_word_select(self): + """Issue #2370: 'select' is a PostgreSQL reserved word.""" + result = buildCypher("g", "MATCH (n) RETURN n", ["select"]) + self.assertIn('"select" agtype', result) + + def test_reserved_word_group(self): + """Issue #2370: 'group' is a PostgreSQL reserved word.""" + result = buildCypher("g", "MATCH (n) RETURN n", ["group"]) + self.assertIn('"group" agtype', result) + + def test_multiple_columns(self): + result = buildCypher("g", "MATCH (n) RETURN n.name, count(n)", ["name", "count"]) + self.assertIn('"name" agtype', result) + self.assertIn('"count" agtype', result) + + def test_default_column(self): + result = buildCypher("g", "MATCH (n) RETURN n", None) + self.assertIn('"v" agtype', result) + + def test_invalid_column_rejected(self): + with self.assertRaises(InvalidIdentifier): + buildCypher("g", "MATCH (n) RETURN n", ["invalid;col"]) + + def test_reserved_word_in_name_type_pair(self): + """Quoting applies even when the column is specified as 'name type'.""" + result = buildCypher("g", "MATCH (n) RETURN n.order", ["order agtype"]) + self.assertIn('"order" agtype', result) + + def test_validate_column_quoting(self): + self.assertEqual(_validate_column("v"), '"v" agtype') + self.assertEqual(_validate_column("v agtype"), '"v" agtype') + self.assertEqual(_validate_column("count"), '"count" agtype') + self.assertEqual(_validate_column("my_col"), '"my_col" agtype') + + class TestAgeBasic(unittest.TestCase): ag = None args: argparse.Namespace = argparse.Namespace( diff --git a/drivers/python/test_security.py b/drivers/python/test_security.py index 55347868e..01afee793 100644 --- a/drivers/python/test_security.py +++ b/drivers/python/test_security.py @@ -167,10 +167,10 @@ class TestColumnValidation(unittest.TestCase): """Test _validate_column prevents injection through column specs.""" def test_plain_column_name(self): - self.assertEqual(_validate_column('v'), 'v agtype') + self.assertEqual(_validate_column('v'), '"v" agtype') def test_column_with_type(self): - self.assertEqual(_validate_column('n agtype'), 'n agtype') + self.assertEqual(_validate_column('n agtype'), '"n" agtype') def test_empty_column(self): self.assertEqual(_validate_column(''), '') @@ -198,20 +198,20 @@ class TestBuildCypher(unittest.TestCase): def test_default_column(self): result = buildCypher('test_graph', 'MATCH (n) RETURN n', None) - self.assertIn('v agtype', result) + self.assertIn('"v" agtype', result) def test_single_column(self): result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n']) - self.assertIn('n agtype', result) + self.assertIn('"n" agtype', result) def test_typed_column(self): result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n agtype']) - self.assertIn('n agtype', result) + self.assertIn('"n" agtype', result) def test_multiple_columns(self): result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['a', 'b']) - self.assertIn('a agtype', result) - self.assertIn('b agtype', result) + self.assertIn('"a" agtype', result) + self.assertIn('"b" agtype', result) def test_rejects_injection_in_column(self): with self.assertRaises(InvalidIdentifier):