diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py index fac7f8bc4bce..061927783baf 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py @@ -274,6 +274,28 @@ def find_nearest( return self._append( stages.FindNearest(field, vector, distance_measure, options) ) + + def let(self, **variables: Expression) -> "_BasePipeline": + """ + Defines variables that can be used in subsequent pipeline stages. + This stage allows you to compute and name values based on existing data + or constants. These variables can then be referenced in later stages, + similarly to how fields are used. + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.let( + ... rating_plus_one=add(Field.of("rating"), 1), + ... has_awards=Field.of("awards").exists() + ... ) + >>> # Later stages can use Variable.of("rating_plus_one") + Args: + **variables: Keyword arguments where keys are the variable names (str) + and values are the `Expression` objects defining them. + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Let(**variables)) def replace_with( self, diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py index b00d923c673c..bde9505eff27 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py @@ -303,6 +303,24 @@ def _pb_options(self) -> dict[str, Value]: return options +class Let(Stage): + """Defines variables for use in subsequent stages.""" + + def __init__(self, **variables: Expression): + super().__init__("let") + self.variables = variables + + def _pb_args(self): + map_val = { + k: v._to_pb() for k, v in self.variables.items() + } + return [Value(map_value={"fields": map_val})] + + def __repr__(self): + vars_str = ", ".join(f"{k}={v!r}" for k, v in self.variables.items()) + return f"{self.__class__.__name__}({vars_str})" + + class RawStage(Stage): """Represents a generic, named stage with parameters.""" diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml index 46a10cd4d1af..75a457d8c27b 100644 --- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml @@ -684,4 +684,69 @@ tests: - args: - fieldReferenceValue: awards - stringValue: full_replace - name: replace_with \ No newline at end of file + name: replace_with + - description: testLetStage + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Let: + my_rating: + Field: rating + author_name: + Field: author + rating_plus_one: + FunctionExpression.add: + - Field: rating + - Constant: 1 + - Select: + - title + - Variable: my_rating + - Variable: author_name + - Variable: rating_plus_one + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + my_rating: 4.8 + author_name: "Douglas Adams" + rating_plus_one: 5.8 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + my_rating: + fieldReferenceValue: rating + author_name: + fieldReferenceValue: author + rating_plus_one: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 1.0 + name: add + name: let + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + my_rating: + variableReferenceValue: my_rating + author_name: + variableReferenceValue: author_name + rating_plus_one: + variableReferenceValue: rating_plus_one + name: select \ No newline at end of file diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py index 5953398709a3..bee37a610daf 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py @@ -390,6 +390,7 @@ def test_pipeline_execute_stream_equivalence(): ("add_fields", (Field.of("n"),), stages.AddFields), ("remove_fields", ("name",), stages.RemoveFields), ("remove_fields", (Field.of("n"),), stages.RemoveFields), + ("let", {"var1": Field.of("n")}, stages.Let), ("select", ("name",), stages.Select), ("select", (Field.of("n"),), stages.Select), ("where", (Field.of("n").exists(),), stages.Where), @@ -422,7 +423,10 @@ def test_pipeline_execute_stream_equivalence(): def test_pipeline_methods(method, args, result_cls): start_ppl = _make_pipeline() method_ptr = getattr(start_ppl, method) - result_ppl = method_ptr(*args) + if method == "let": + result_ppl = method_ptr(**args) + else: + result_ppl = method_ptr(*args) assert result_ppl != start_ppl assert len(start_ppl.stages) == 0 assert len(result_ppl.stages) == 1 diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py index b32a6e5d3f13..770c9d717c3f 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py @@ -421,6 +421,38 @@ def test_to_pb_no_options(self): assert len(result.args) == 3 +class TestLet: + def _make_one(self, *args, **kwargs): + return stages.Let(*args, **kwargs) + + def test_ctor(self): + expr1 = Field.of("field1") + expr2 = Field.of("field2").add(1) + instance = self._make_one(var1=expr1, var2=expr2) + assert instance.variables == {"var1": expr1, "var2": expr2} + assert instance.name == "let" + + def test_repr(self): + expr1 = Field.of("field1") + instance = self._make_one(var1=expr1) + repr_str = repr(instance) + assert repr_str == "Let(var1=Field.of('field1'))" + + def test_to_pb(self): + expr1 = Field.of("field1") + expr2 = Constant.of(5) + instance = self._make_one(var1=expr1, num=expr2) + result = instance._to_pb() + assert result.name == "let" + assert len(result.args) == 1 + expected_map_value = { + "var1": Value(field_reference_value="field1"), + "num": Value(integer_value=5), + } + assert result.args[0].map_value.fields == expected_map_value + assert len(result.options) == 0 + + class TestRawStage: def _make_one(self, *args, **kwargs): return stages.RawStage(*args, **kwargs)