diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 638f0c28c8..c8843fcc5e 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1062,6 +1062,7 @@ class BigQueryConnectionConfig(ConnectionConfig): job_retry_deadline_seconds: t.Optional[int] = None priority: t.Optional[BigQueryPriority] = None maximum_bytes_billed: t.Optional[int] = None + reservation: t.Optional[str] = None concurrent_tasks: int = 1 register_comments: bool = True @@ -1171,6 +1172,7 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]: "job_retry_deadline_seconds", "priority", "maximum_bytes_billed", + "reservation", } } diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 59a56b6ace..6e5ae11a61 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -140,8 +140,10 @@ def _job_params(self) -> t.Dict[str, t.Any]: "priority", BigQueryPriority.INTERACTIVE.bigquery_constant ), } - if self._extra_config.get("maximum_bytes_billed"): + if self._extra_config.get("maximum_bytes_billed") is not None: params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed") + if self._extra_config.get("reservation") is not None: + params["reservation"] = self._extra_config.get("reservation") if self.correlation_id: # BigQuery label keys must be lowercase key = self.correlation_id.job_type.value.lower() @@ -1106,7 +1108,9 @@ def _execute( else [] ) + # Create job config job_config = QueryJobConfig(**self._job_params, connection_properties=connection_properties) + self._query_job = self._db_call( self.client.query, query=sql, diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 480d186fa1..2dfb9c6313 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -1951,7 +1951,7 @@ def test_init_dbt_template(runner: CliRunner, tmp_path: Path): def test_init_project_engine_configs(tmp_path): engine_type_to_config = { "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ", - "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", + "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: \n # reservation: ", "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ", diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index dd979a2551..2ff95525f7 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1131,6 +1131,27 @@ def test_bigquery(make_config): assert config.get_catalog() == "project" assert config.is_recommended_for_state_sync is False + # Test reservation + config_with_reservation = make_config( + type="bigquery", + project="project", + reservation="projects/my-project/locations/us-central1/reservations/my-reservation", + check_import=False, + ) + assert isinstance(config_with_reservation, BigQueryConnectionConfig) + assert ( + config_with_reservation.reservation + == "projects/my-project/locations/us-central1/reservations/my-reservation" + ) + + # Test that reservation is included in _extra_engine_config + extra_config = config_with_reservation._extra_engine_config + assert "reservation" in extra_config + assert ( + extra_config["reservation"] + == "projects/my-project/locations/us-central1/reservations/my-reservation" + ) + with pytest.raises(ConfigError, match="you must also specify the `project` field"): make_config(type="bigquery", execution_project="execution_project", check_import=False)