diff --git a/sdks/python/apache_beam/yaml/examples/transforms/blueprint/write_to_jdbc.yaml b/sdks/python/apache_beam/yaml/examples/transforms/blueprint/write_to_jdbc.yaml new file mode 100644 index 000000000000..0097c73c6709 --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/blueprint/write_to_jdbc.yaml @@ -0,0 +1,45 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is an example of a Beam YAML pipeline that creates sample records and +# writes them to an AlloyDB table with WriteToJdbc. 
The username and password +# config values are KMS ciphertexts that are decrypted via the encryption block. + +pipeline: + type: chain + transforms: + # Step 1: Generate sample data + - type: Create + name: CreateData + config: + elements: + - {id: 3, name: 'Smith', email: 'smith@example.com', zip: 'NY'} + - {id: 4, name: 'Beamberg', email: 'beamberg@example.com', zip: 'NY'} + # Step 2: Write records out to AlloyDB + - type: WriteToJdbc + name: WriteToAlloyDBTable + config: + location: "users" + driver_class_name: "org.postgresql.Driver" + jdbc_url: "jdbc:postgresql:///db?socketFactory=com.google.cloud.alloydb.SocketFactory&alloydbInstanceName=projects/apache-beam-testing/locations/us-central1/clusters/alloydb-yaml-test/instances/alloydb-yaml-test-primary&alloydbIpType=PUBLIC" + username: "{{ ALLOYDB_USERNAME }}" + password: "{{ ALLOYDB_PASSWORD }}" + encryption: + key: "projects/apache-beam-testing/locations/global/keyRings/tarun_test/cryptoKeys/tarun_test_username" + fields: + - username + - password \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/kms.py b/sdks/python/apache_beam/yaml/kms.py new file mode 100644 index 000000000000..747f3d096403 --- /dev/null +++ b/sdks/python/apache_beam/yaml/kms.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def decrypt_value(ciphertext: str, key_name: str) -> str:
  """Decrypts a base64-encoded ciphertext using Google Cloud KMS.

  Both arguments are validated before the optional ``google-cloud-kms``
  dependency is checked and before a client is created, so malformed
  configuration is reported without needing the library or credentials.

  Args:
    ciphertext: The base64 encoded ciphertext to decrypt.
    key_name: The resource name of the CryptoKey to use for decryption.
      Format: projects/*/locations/*/keyRings/*/cryptoKeys/*

  Returns:
    The decrypted plaintext, decoded as UTF-8.

  Raises:
    ValueError: If key_name does not start with "projects/" or the
      ciphertext cannot be base64 decoded.
    ImportError: If google-cloud-kms is not installed.
    Exception: Any error raised while calling KMS or decoding the plaintext
      is logged and re-raised unchanged.
  """
  if not key_name.startswith('projects/'):
    raise ValueError(f'Key name must start with "projects/", got {key_name}')

  # Decode up front: this needs neither the KMS library nor credentials.
  # binascii.Error (bad padding/alphabet) is a ValueError subclass; TypeError
  # covers non-string values such as numbers coming from a YAML config.
  try:
    ciphertext_bytes = base64.b64decode(ciphertext)
  except (ValueError, TypeError) as e:
    raise ValueError(f"Failed to base64 decode ciphertext: {e}") from e

  if kms is None:
    raise ImportError(
        'google-cloud-kms is required for encryption. '
        'Please install apache-beam[gcp] or `pip install google-cloud-kms`.')

  client = kms.KeyManagementServiceClient()
  request = {
      "name": key_name,
      "ciphertext": ciphertext_bytes,
  }

  try:
    response = client.decrypt(request=request)
    return response.plaintext.decode('utf-8')
  except Exception as e:
    # Lazy %-formatting: the message is only built if the record is emitted.
    _LOGGER.error("Failed to decrypt value with key %s: %s", key_name, e)
    raise
class KmsTest(unittest.TestCase):
  """Tests for ``kms.decrypt_value`` and ``yaml_transform.preprocess_encryption``."""
  @mock.patch('apache_beam.yaml.kms.kms')
  def test_decrypt_value(self, fake_kms_module):
    """The ciphertext is base64-decoded, sent to KMS, and returned as text."""
    fake_client = mock.Mock()
    fake_client.decrypt.return_value = mock.Mock(plaintext=b'my_secret')
    fake_kms_module.KeyManagementServiceClient.return_value = fake_client

    key_name = 'projects/p/locations/l/keyRings/k/cryptoKeys/c'
    encoded = base64.b64encode(b'encrypted_secret').decode('utf-8')

    self.assertEqual(kms.decrypt_value(encoded, key_name), 'my_secret')

    fake_client.decrypt.assert_called_once()
    _, call_kwargs = fake_client.decrypt.call_args
    sent_request = call_kwargs['request']
    self.assertEqual(sent_request['name'], key_name)
    self.assertEqual(sent_request['ciphertext'], b'encrypted_secret')

  @mock.patch(
      'apache_beam.yaml.kms.decrypt_value', return_value='decrypted_password')
  def test_preprocess_encryption(self, fake_decrypt):
    """Listed fields are decrypted and the encryption block is dropped."""
    key_name = 'projects/p/locations/l/keyRings/k/cryptoKeys/c'
    spec = {
        'type': 'MyTransform',
        'config': {
            'username': 'user',
            'password': 'encrypted_password',
        },
        'encryption': {
            'key': key_name,
            'fields': ['password'],
        },
    }

    result = yaml_transform.preprocess_encryption(spec)

    self.assertNotIn('encryption', result)
    self.assertEqual(result['config']['password'], 'decrypted_password')
    fake_decrypt.assert_called_once_with('encrypted_password', key_name)

  def test_preprocess_encryption_missing_key(self):
    """An encryption block without a key is rejected."""
    spec = {
        'type': 'MyTransform',
        'config': {'p': 'v'},
        'encryption': {'fields': ['p']},
    }
    with self.assertRaisesRegex(ValueError, "Encryption block missing 'key'"):
      yaml_transform.preprocess_encryption(spec)

  def test_preprocess_encryption_missing_field(self):
    """Naming a field that is absent from the config is rejected."""
    spec = {
        'type': 'MyTransform',
        'config': {'other': 'v'},
        'encryption': {'key': 'k', 'fields': ['missing_field']},
    }
    with self.assertRaisesRegex(
        ValueError, "Encrypted field 'missing_field' not found"):
      yaml_transform.preprocess_encryption(spec)
def preprocess_encryption(spec):
  """Decrypts the config fields named in a transform's ``encryption`` block.

  The input spec is never modified: decryption happens on a copy of the
  config, so a failure part-way through cannot leave the caller holding a
  spec that mixes plaintext and ciphertext while still carrying its
  ``encryption`` block.

  Args:
    spec: A transform spec mapping. If it has an ``encryption`` entry, that
      entry must provide ``key`` (the KMS key name) and may provide
      ``fields`` (names of ``config`` entries to decrypt). A single string
      is accepted for ``fields`` and treated as one field name.

  Returns:
    ``spec`` itself when it has no ``encryption`` entry; otherwise a new dict
    without ``encryption`` whose ``config`` holds the decrypted values.

  Raises:
    ValueError: If ``key`` is missing or empty, a listed field is absent from
      ``config``, or decrypting a field fails.
  """
  if 'encryption' not in spec:
    return spec

  # `encryption:` / `config:` with no value parse as None in YAML.
  enc_spec = spec['encryption'] or {}
  key = enc_spec.get('key')
  if not key:
    raise ValueError(
        f"Encryption block missing 'key' in {identify_object(spec)}")

  fields = enc_spec.get('fields') or []
  if isinstance(fields, str):
    # Iterating a bare string would yield characters, not a field name.
    fields = [fields]

  config = dict(spec.get('config') or {})
  for field in fields:
    if field not in config:
      raise ValueError(
          f"Encrypted field '{field}' not found in config of "
          f"{identify_object(spec)}")
    try:
      config[field] = kms.decrypt_value(config[field], key)
    except Exception as e:
      raise ValueError(
          f"Failed to decrypt field '{field}' in {identify_object(spec)}: {e}"
      ) from e

  result = {k: v for k, v in spec.items() if k != 'encryption'}
  if fields:
    # Only swap in the copy when something was actually decrypted.
    result['config'] = config
  return result