diff --git a/cassandra/query.py b/cassandra/query.py index 246c0565ba..6c6878fdb4 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -761,6 +761,7 @@ class BatchStatement(Statement): _statements_and_parameters = None _session = None + _is_lwt = False def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None, serial_consistency_level=None, @@ -845,6 +846,8 @@ def add(self, statement, parameters=None): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) self._update_state(bound_statement) + if statement.is_lwt(): + self._is_lwt = True self._add_statement_and_params(True, query_id, bound_statement.values) elif isinstance(statement, BoundStatement): if parameters: @@ -852,6 +855,8 @@ def add(self, statement, parameters=None): "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") self._update_state(statement) + if statement.is_lwt(): + self._is_lwt = True self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) else: # it must be a SimpleStatement @@ -860,6 +865,8 @@ def add(self, statement, parameters=None): encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) self._update_state(statement) + if statement.is_lwt(): + self._is_lwt = True self._add_statement_and_params(False, query_string, ()) return self @@ -893,6 +900,9 @@ def _update_state(self, statement): self._maybe_set_routing_attributes(statement) self._update_custom_payload(statement) + def is_lwt(self): + return self._is_lwt + def __len__(self): return len(self._statements_and_parameters) diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 29c800b99c..6b0ebe690e 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -14,7 +14,7 @@ import unittest -from cassandra.query import BatchStatement, SimpleStatement +from cassandra.query import BatchStatement, PreparedStatement, SimpleStatement class BatchStatementTest(unittest.TestCase): @@ -68,3 +68,50 @@ def test_len(self): batch.add_all(statements=['%s'] * n, parameters=[(i,) for i in range(n)]) assert len(batch) == n + + def _make_prepared_statement(self, is_lwt=False): + return PreparedStatement( + column_metadata=[], + query_id=b"query-id", + routing_key_indexes=[], + query="INSERT INTO test.table (id) VALUES (1)", + keyspace=None, + protocol_version=4, + result_metadata=[], + result_metadata_id=None, + is_lwt=is_lwt, + ) + + def test_is_lwt_false_for_non_lwt_statements(self): + batch = BatchStatement() + batch.add(self._make_prepared_statement(is_lwt=False)) + batch.add(self._make_prepared_statement(is_lwt=False).bind(())) + batch.add(SimpleStatement("INSERT INTO test.table (id) VALUES (3)")) + batch.add("INSERT INTO test.table (id) VALUES (4)") + assert batch.is_lwt() is False + + def test_is_lwt_propagates_from_statements(self): + batch = BatchStatement() + batch.add(self._make_prepared_statement(is_lwt=False)) + assert batch.is_lwt() is False + + batch.add(self._make_prepared_statement(is_lwt=True)) + assert batch.is_lwt() is True + + bound_lwt = self._make_prepared_statement(is_lwt=True).bind(()) + batch_with_bound = BatchStatement() + batch_with_bound.add(bound_lwt) + assert batch_with_bound.is_lwt() is True + + class LwtSimpleStatement(SimpleStatement): + def __init__(self): + super(LwtSimpleStatement, self).__init__( + "INSERT INTO test.table (id) VALUES (2) IF NOT EXISTS" + ) + + def is_lwt(self): + return True + + batch_with_simple = BatchStatement() + batch_with_simple.add(LwtSimpleStatement()) + assert batch_with_simple.is_lwt() is True