diff --git a/dbt/adapters/sqlserver/__init__.py b/dbt/adapters/sqlserver/__init__.py index 879ea74c..30bc8036 100644 --- a/dbt/adapters/sqlserver/__init__.py +++ b/dbt/adapters/sqlserver/__init__.py @@ -11,7 +11,6 @@ adapter=SQLServerAdapter, credentials=SQLServerCredentials, include_path=sqlserver.PACKAGE_PATH, - dependencies=["fabric"], ) __all__ = [ diff --git a/dbt/adapters/sqlserver/sqlserver_adapter.py b/dbt/adapters/sqlserver/sqlserver_adapter.py index 6f05c501..eafd6d1a 100644 --- a/dbt/adapters/sqlserver/sqlserver_adapter.py +++ b/dbt/adapters/sqlserver/sqlserver_adapter.py @@ -1,18 +1,50 @@ -from typing import Optional +from typing import List, Optional +import agate import dbt.exceptions +from dbt.adapters.base import available from dbt.adapters.base.impl import ConstraintSupport -from dbt.adapters.fabric import FabricAdapter +from dbt.adapters.base.relation import BaseRelation +from dbt.adapters.sql import SQLAdapter from dbt.contracts.graph.nodes import ConstraintType +from dbt_common.contracts.constraints import ColumnLevelConstraint from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation +COLUMNS_EQUAL_SQL = """ +with diff_count as ( + SELECT + 1 as id, + COUNT(*) as num_missing FROM ( + (SELECT {columns} FROM {relation_a} {except_op} + SELECT {columns} FROM {relation_b}) + UNION ALL + (SELECT {columns} FROM {relation_b} {except_op} + SELECT {columns} FROM {relation_a}) + ) as a +), table_a as ( + SELECT COUNT(*) as num_rows FROM {relation_a} +), table_b as ( + SELECT COUNT(*) as num_rows FROM {relation_b} +), row_count_diff as ( + select + 1 as id, + table_a.num_rows - table_b.num_rows as difference + from table_a, table_b +) +select + row_count_diff.difference as row_count_difference, + diff_count.num_missing as num_mismatched +from row_count_diff +join diff_count on 
row_count_diff.id = diff_count.id +""".strip() -class SQLServerAdapter(FabricAdapter): + +class SQLServerAdapter(SQLAdapter): """ - Controls actual implmentation of adapter, and ability to override certain methods. + Controls actual implementation of adapter, and ability to override certain methods. """ ConnectionManager = SQLServerConnectionManager @@ -27,6 +59,91 @@ class SQLServerAdapter(FabricAdapter): ConstraintType.foreign_key: ConstraintSupport.ENFORCED, } + # -- Type conversions (inlined from FabricAdapter) -- + + @classmethod + def convert_boolean_type(cls, agate_table, col_idx): + return "bit" + + @classmethod + def convert_datetime_type(cls, agate_table, col_idx): + return "datetime2(6)" + + @classmethod + def convert_number_type(cls, agate_table, col_idx): + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) + return "float" if decimals else "int" + + @classmethod + def convert_text_type(cls, agate_table, col_idx): + column = agate_table.columns[col_idx] + lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()] + max_len = max(lens) if lens else 64 + length = max_len if max_len > 16 else 16 + return "varchar({})".format(length) + + @classmethod + def convert_time_type(cls, agate_table, col_idx): + return "time(6)" + + @classmethod + def date_function(cls): + return "getdate()" + + # -- SQL helpers (inlined from FabricAdapter) -- + + def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: + return f"DATEADD({interval},{number},{add_to})" + + def string_add_sql(self, add_to: str, value: str, location="append") -> str: + """+ is T-SQL's string concatenation operator""" + if location == "append": + return f"{add_to} + '{value}'" + elif location == "prepend": + return f"'{value}' + {add_to}" + else: + raise ValueError(f'Got an unexpected location value of "{location}"') + + def get_rows_different_sql( + self, + relation_a: BaseRelation, + relation_b: BaseRelation, + column_names: Optional[List[str]] 
= None, + except_operator: str = "EXCEPT", + ) -> str: + names: List[str] + if column_names is None: + columns = self.get_columns_in_relation(relation_a) + names = sorted((self.quote(c.name) for c in columns)) + else: + names = sorted((self.quote(n) for n in column_names)) + + columns_csv = ", ".join(names) + + sql = COLUMNS_EQUAL_SQL.format( + columns=columns_csv, + relation_a=str(relation_a), + relation_b=str(relation_b), + except_op=except_operator, + ) + return sql + + # -- Constraint rendering -- + + @available + @classmethod + def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional[str]: + rendered_column_constraint = None + if constraint.type == ConstraintType.not_null: + rendered_column_constraint = "not null" + else: + rendered_column_constraint = "" + + if rendered_column_constraint: + rendered_column_constraint = rendered_column_constraint.strip() + + return rendered_column_constraint + @classmethod def render_model_constraint(cls, constraint) -> Optional[str]: constraint_prefix = "add constraint " @@ -56,10 +173,6 @@ def render_model_constraint(cls, constraint) -> Optional[str]: else: return None - @classmethod - def date_function(cls): - return "getdate()" - def valid_incremental_strategies(self): """The set of standard builtin strategies which this adapter supports out-of-the-box. Not used to validate custom strategies defined by end users. 
diff --git a/dbt/adapters/sqlserver/sqlserver_column.py b/dbt/adapters/sqlserver/sqlserver_column.py index 68ef98e3..a9292c52 100644 --- a/dbt/adapters/sqlserver/sqlserver_column.py +++ b/dbt/adapters/sqlserver/sqlserver_column.py @@ -1,17 +1,74 @@ -from dbt.adapters.fabric import FabricColumn +from typing import Any, ClassVar, Dict +from dbt.adapters.base import Column +from dbt_common.exceptions import DbtRuntimeError + + +class SQLServerColumn(Column): + TYPE_LABELS: ClassVar[Dict[str, str]] = { + "STRING": "VARCHAR(8000)", + "VARCHAR": "VARCHAR(8000)", + "CHAR": "CHAR(1)", + "NCHAR": "CHAR(1)", + "NVARCHAR": "VARCHAR(8000)", + "TIMESTAMP": "DATETIME2(6)", + "DATETIME2": "DATETIME2(6)", + "DATETIME2(6)": "DATETIME2(6)", + "DATE": "DATE", + "TIME": "TIME(6)", + "FLOAT": "FLOAT", + "REAL": "REAL", + "INT": "INT", + "INTEGER": "INT", + "BIGINT": "BIGINT", + "SMALLINT": "SMALLINT", + "TINYINT": "SMALLINT", + "BIT": "BIT", + "BOOLEAN": "BIT", + "DECIMAL": "DECIMAL", + "NUMERIC": "NUMERIC", + "MONEY": "DECIMAL", + "SMALLMONEY": "DECIMAL", + "UNIQUEIDENTIFIER": "UNIQUEIDENTIFIER", + "VARBINARY": "VARBINARY(MAX)", + "BINARY": "BINARY(1)", + } + + @classmethod + def string_type(cls, size: int) -> str: + return f"varchar({size if size > 0 else '8000'})" + + def literal(self, value: Any) -> str: + return "cast('{}' as {})".format(value, self.data_type) + + @property + def data_type(self) -> str: + if self.dtype.lower() == "datetime2": + return "datetime2(6)" + if self.is_string(): + return self.string_type(self.string_size()) + elif self.is_numeric(): + return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) + else: + return self.dtype + + def is_string(self) -> bool: + return self.dtype.lower() in ["varchar", "char"] + + def is_number(self): + return any([self.is_integer(), self.is_numeric(), self.is_float()]) + + def is_float(self): + return self.dtype.lower() in ["float", "real"] -class SQLServerColumn(FabricColumn): def is_integer(self) -> 
bool: return self.dtype.lower() in [ - # real types "smallint", "integer", "bigint", "smallserial", "serial", "bigserial", - # aliases "int2", "int4", "int8", @@ -19,4 +76,21 @@ def is_integer(self) -> bool: "serial4", "serial8", "int", + "tinyint", ] + + def is_numeric(self) -> bool: + return self.dtype.lower() in ["numeric", "decimal", "money", "smallmoney"] + + def string_size(self) -> int: + if not self.is_string(): + raise DbtRuntimeError("Called string_size() on non-string field!") + if self.char_size is None: + return 8000 + else: + return int(self.char_size) + + def can_expand_to(self, other_column: "SQLServerColumn") -> bool: + if not self.is_string() or not other_column.is_string(): + return False + return other_column.string_size() > self.string_size() diff --git a/dbt/adapters/sqlserver/sqlserver_configs.py b/dbt/adapters/sqlserver/sqlserver_configs.py index 35ce4262..c41be608 100644 --- a/dbt/adapters/sqlserver/sqlserver_configs.py +++ b/dbt/adapters/sqlserver/sqlserver_configs.py @@ -1,8 +1,8 @@ from dataclasses import dataclass -from dbt.adapters.fabric import FabricConfigs +from dbt.adapters.protocol import AdapterConfig @dataclass -class SQLServerConfigs(FabricConfigs): +class SQLServerConfigs(AdapterConfig): pass diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index c4424656..b92e5ec0 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -1,73 +1,20 @@ -import dbt_common.exceptions # noqa -import pyodbc -from azure.core.credentials import AccessToken -from azure.identity import ClientSecretCredential, ManagedIdentityCredential -from dbt.adapters.contracts.connection import Connection, ConnectionState +from contextlib import contextmanager +from typing import Any, Optional, Tuple, Union + +import agate +import dbt_common.exceptions +from adbc_driver_manager import dbapi as adbc_dbapi +from dbt.adapters.contracts.connection 
import AdapterResponse, Connection, ConnectionState from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.fabric import FabricConnectionManager -from dbt.adapters.fabric.fabric_connection_manager import ( - AZURE_AUTH_FUNCTIONS as AZURE_AUTH_FUNCTIONS_FABRIC, -) -from dbt.adapters.fabric.fabric_connection_manager import ( - AZURE_CREDENTIAL_SCOPE, - bool_to_connection_string_arg, - get_pyodbc_attrs_before_credentials, -) - -from dbt.adapters.sqlserver import __version__ +from dbt.adapters.sql import SQLConnectionManager +from dbt_common.clients.agate_helper import empty_table + from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials logger = AdapterLogger("sqlserver") -def get_msi_access_token(credentials: SQLServerCredentials) -> AccessToken: - """ - Get an Azure access token from the system's managed identity - - Parameters - ----------- - credentials: SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - token = ManagedIdentityCredential().get_token(AZURE_CREDENTIAL_SCOPE) - return token - - -def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken: - """ - Get an Azure access token using the SP credentials. - - Parameters - ---------- - credentials : SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. 
- """ - token = ClientSecretCredential( - str(credentials.tenant_id), - str(credentials.client_id), - str(credentials.client_secret), - ).get_token(AZURE_CREDENTIAL_SCOPE) - return token - - -AZURE_AUTH_FUNCTIONS = { - **AZURE_AUTH_FUNCTIONS_FABRIC, - "serviceprincipal": get_sp_access_token, - "msi": get_msi_access_token, -} - - -class SQLServerConnectionManager(FabricConnectionManager): +class SQLServerConnectionManager(SQLConnectionManager): TYPE = "sqlserver" @classmethod @@ -76,74 +23,29 @@ def open(cls, connection: Connection) -> Connection: logger.debug("Connection is already open, skipping open.") return connection - credentials = cls.get_credentials(connection.credentials) - if credentials.authentication != "sql": - return super().open(connection) - - # sql login authentication - - con_str = [f"DRIVER={{{credentials.driver}}}"] - - if "\\" in credentials.host: - # If there is a backslash \ in the host name, the host is a - # SQL Server named instance. In this case then port number has to be omitted. 
- con_str.append(f"SERVER={credentials.host}") - else: - con_str.append(f"SERVER={credentials.host},{credentials.port}") - - con_str.append(f"Database={credentials.database}") - - assert credentials.authentication is not None + credentials: SQLServerCredentials = cls.get_credentials(connection.credentials) - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") + uri = credentials.build_adbc_uri() - # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15 - assert credentials.encrypt is not None - assert credentials.trust_cert is not None + # Build a display URI with password masked + display_uri = uri + if credentials.PWD: + from urllib.parse import quote_plus - con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) - con_str.append( - bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert) - ) - - plugin_version = __version__.version - application_name = f"dbt-{credentials.type}/{plugin_version}" - con_str.append(f"APP={application_name}") - - con_str_concat = ";".join(con_str) - - index = [] - for i, elem in enumerate(con_str): - if "pwd=" in elem.lower(): - index.append(i) - - if len(index) != 0: - con_str[index[0]] = "PWD=***" + display_uri = uri.replace(quote_plus(credentials.PWD), "***") - con_str_display = ";".join(con_str) - - retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions - pyodbc.InternalError, # not used according to docs, but defined in PEP-249 - pyodbc.OperationalError, + retryable_exceptions = [ + adbc_dbapi.OperationalError, + adbc_dbapi.InterfaceError, ] - if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: - # Temporary login/token errors fall into this category when using AAD - retryable_exceptions.append(pyodbc.InterfaceError) - def connect(): - logger.debug(f"Using connection string: {con_str_display}") - - attrs_before = 
get_pyodbc_attrs_before_credentials(credentials) - - handle = pyodbc.connect( - con_str_concat, - attrs_before=attrs_before, + logger.debug(f"Using ADBC URI: {display_uri}") + handle = adbc_dbapi.connect( + driver=credentials.driver, + db_kwargs={"uri": uri}, autocommit=True, - timeout=credentials.login_timeout, ) - handle.timeout = credentials.query_timeout logger.debug(f"Connected to db: {credentials.database}") return handle @@ -154,3 +56,95 @@ def connect(): retry_limit=credentials.retries, retryable_exceptions=retryable_exceptions, ) + + @contextmanager + def exception_handler(self, sql): + try: + yield + except adbc_dbapi.DatabaseError as e: + logger.debug(f"Database error: {e}") + self.release() + raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e + except Exception as e: + logger.debug(f"Error running SQL: {sql}") + logger.debug(f"Rolling back due to: {e}") + self.release() + if isinstance(e, dbt_common.exceptions.DbtRuntimeError): + raise + raise dbt_common.exceptions.DbtRuntimeError(str(e)) from e + + def cancel(self, connection: Connection): + logger.debug("Cancel query") + + def add_begin_query(self): + pass # autocommit mode + + def add_commit_query(self): + pass # autocommit mode + + @classmethod + def get_credentials(cls, credentials: SQLServerCredentials) -> SQLServerCredentials: + return credentials + + @classmethod + def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: + """Map ADBC/Arrow type codes to SQL Server type names.""" + code = str(type_code).lower() + + # Arrow type code → SQL Server type name + if code in ("int8", "int16", "int32"): + return "int" + if code == "int64": + return "bigint" + if code in ("float", "float32"): + return "real" + if code in ("double", "float64"): + return "float" + if code in ("string", "large_string", "utf8", "large_utf8"): + return "varchar" + if code == "bool": + return "bit" + if code.startswith("decimal"): + return "decimal" + if code.startswith("date"): + return "date" + if 
code.startswith("time") and "stamp" not in code: + return "time" + if code.startswith("timestamp"): + return "datetime2" + if code == "binary" or code == "large_binary": + return "varbinary" + + return str(type_code) + + @classmethod + def get_response(cls, cursor: Any) -> AdapterResponse: + rows = cursor.rowcount if hasattr(cursor, "rowcount") else -1 + return AdapterResponse(_message="OK", rows_affected=rows) + + def execute( + self, + sql: str, + auto_begin: bool = True, + fetch: bool = False, + limit: Optional[int] = None, + ) -> Tuple[AdapterResponse, agate.Table]: + _, cursor = self.add_query(sql, auto_begin) + response = self.get_response(cursor) + + if fetch: + # ADBC cursors may not support nextset(), so guard it + # Skip result sets without column descriptions (e.g. SET NOCOUNT ON) + if hasattr(cursor, "nextset"): + while cursor.description is None: + if not cursor.nextset(): + break + + if cursor.description is not None: + table = self.get_result_from_cursor(cursor, limit) + else: + table = empty_table() + else: + table = empty_table() + + return response, table diff --git a/dbt/adapters/sqlserver/sqlserver_credentials.py b/dbt/adapters/sqlserver/sqlserver_credentials.py index bf1f5075..f13b8198 100644 --- a/dbt/adapters/sqlserver/sqlserver_credentials.py +++ b/dbt/adapters/sqlserver/sqlserver_credentials.py @@ -1,22 +1,95 @@ from dataclasses import dataclass from typing import Optional +from urllib.parse import quote_plus -from dbt.adapters.fabric import FabricCredentials +from dbt.adapters.contracts.connection import Credentials @dataclass -class SQLServerCredentials(FabricCredentials): +class SQLServerCredentials(Credentials): """ Defines database specific credentials that get added to profiles.yml to connect to new adapter """ - port: Optional[int] = 1433 - authentication: Optional[str] = "sql" + # Connection + host: str = "" + database: str = "" + schema: str = "" + port: int = 1433 + + # Auth (SQL auth only for this spike) + UID: Optional[str] = 
None + PWD: Optional[str] = None + authentication: str = "sql" + + # Connection options + driver: str = "mssql" # ADBC driver name (not ODBC driver) + encrypt: Optional[bool] = True + trust_cert: Optional[bool] = False + retries: int = 3 + login_timeout: int = 0 + query_timeout: int = 0 + + _ALIASES = { + "user": "UID", + "username": "UID", + "pass": "PWD", + "password": "PWD", + "server": "host", + "TrustServerCertificate": "trust_cert", + } + + def build_adbc_uri(self) -> str: + """Construct go-mssqldb connection URI from profile fields.""" + # URL-encode user and password for special characters + user = quote_plus(self.UID) if self.UID else "" + pwd = quote_plus(self.PWD) if self.PWD else "" + + # Build userinfo + if user and pwd: + userinfo = f"{user}:{pwd}@" + elif user: + userinfo = f"{user}@" + else: + userinfo = "" + + # Handle named instances (backslash in host) — omit port + if "\\" in self.host: + host_part = self.host + else: + host_part = f"{self.host}:{self.port}" + + # Build query parameters + params = [] + if self.database: + params.append(f"database={quote_plus(self.database)}") + if self.encrypt is not None: + params.append(f"encrypt={'true' if self.encrypt else 'false'}") + if self.trust_cert is not None: + params.append(f"TrustServerCertificate={'true' if self.trust_cert else 'false'}") + if self.login_timeout: + params.append(f"connection timeout={self.login_timeout}") + + query_string = "&".join(params) + return f"sqlserver://{userinfo}{host_part}?{query_string}" @property def type(self): return "sqlserver" + @property + def unique_field(self): + return self.host + def _connection_keys(self): - return super()._connection_keys() + ("port",) + return ( + "host", + "port", + "database", + "schema", + "UID", + "authentication", + "encrypt", + "trust_cert", + ) diff --git a/dbt/include/sqlserver/macros/adapter/apply_grants.sql b/dbt/include/sqlserver/macros/adapter/apply_grants.sql new file mode 100644 index 00000000..2bdded5f --- /dev/null +++ 
b/dbt/include/sqlserver/macros/adapter/apply_grants.sql @@ -0,0 +1,59 @@ +{% macro sqlserver__get_show_grant_sql(relation) %} + select + GRANTEE as grantee, + PRIVILEGE_TYPE as privilege_type + from INFORMATION_SCHEMA.TABLE_PRIVILEGES {{ information_schema_hints() }} + where TABLE_CATALOG = '{{ relation.database }}' + and TABLE_SCHEMA = '{{ relation.schema }}' + and TABLE_NAME = '{{ relation.identifier }}' +{% endmacro %} + + +{%- macro sqlserver__get_grant_sql(relation, privilege, grantees) -%} + {%- set grantees_safe = [] -%} + {%- for grantee in grantees -%} + {%- set grantee_safe = "[" ~ grantee ~ "]" -%} + {%- do grantees_safe.append(grantee_safe) -%} + {%- endfor -%} + grant {{ privilege }} on {{ relation }} to {{ grantees_safe | join(', ') }} +{%- endmacro -%} + + +{%- macro sqlserver__get_revoke_sql(relation, privilege, grantees) -%} + {%- set grantees_safe = [] -%} + {%- for grantee in grantees -%} + {%- set grantee_safe = "[" ~ grantee ~ "]" -%} + {%- do grantees_safe.append(grantee_safe) -%} + {%- endfor -%} + revoke {{ privilege }} on {{ relation }} from {{ grantees_safe | join(', ') }} +{%- endmacro -%} + + +{% macro sqlserver__apply_grants(relation, grant_config, should_revoke=True) %} + {#-- If grant_config is {} or None, this is a no-op --#} + {% if grant_config %} + {% if should_revoke %} + {#-- We think previous grants may have carried over --#} + {#-- Show current grants and calculate diffs --#} + {% set current_grants_table = run_query(get_show_grant_sql(relation)) %} + {% set current_grants_dict = adapter.standardize_grants_dict(current_grants_table) %} + {% set needs_granting = diff_of_two_dicts(grant_config, current_grants_dict) %} + {% set needs_revoking = diff_of_two_dicts(current_grants_dict, grant_config) %} + {% if not (needs_granting or needs_revoking) %} + {{ log('On ' ~ relation ~': All grants are in place, no revocation or granting needed.')}} + {% endif %} + {% else %} + {#-- Jump straight to granting what the user has configured. 
--#} + {% set needs_revoking = {} %} + {% set needs_granting = grant_config %} + {% endif %} + {% if needs_granting or needs_revoking %} + {% set revoke_statement_list = get_dcl_statement_list(relation, needs_revoking, get_revoke_sql) %} + {% set grant_statement_list = get_dcl_statement_list(relation, needs_granting, get_grant_sql) %} + {% set dcl_statement_list = revoke_statement_list + grant_statement_list %} + {% if dcl_statement_list %} + {{ call_dcl_statements(dcl_statement_list) }} + {% endif %} + {% endif %} + {% endif %} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/columns.sql b/dbt/include/sqlserver/macros/adapter/columns.sql index a98750e7..984ffcaa 100644 --- a/dbt/include/sqlserver/macros/adapter/columns.sql +++ b/dbt/include/sqlserver/macros/adapter/columns.sql @@ -23,6 +23,51 @@ {{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }} {% endmacro %} +{% macro sqlserver__get_columns_in_relation(relation) -%} + {% call statement('get_columns_in_relation', fetch_result=True) %} + {{ get_use_database_sql(relation.database) }} + with mapping as ( + select + row_number() over (partition by object_name(c.object_id) order by c.column_id) as ordinal_position, + c.name collate database_default as column_name, + t.name as data_type, + c.max_length as character_maximum_length, + c.precision as numeric_precision, + c.scale as numeric_scale + from sys.columns c {{ information_schema_hints() }} + inner join sys.types t {{ information_schema_hints() }} + on c.user_type_id = t.user_type_id + where c.object_id = object_id('{{ 'tempdb..' 
~ relation.include(database=false, schema=false) if '#' in relation.identifier else relation }}') + ) + + select + column_name, + data_type, + character_maximum_length, + numeric_precision, + numeric_scale + from mapping + order by ordinal_position + + {% endcall %} + {% set table = load_result('get_columns_in_relation').table %} + {{ return(sql_convert_columns_in_relation(table)) }} +{% endmacro %} + +{% macro sqlserver__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} + {% call statement('add_drop_columns') -%} + {% if add_columns %} + alter {{ relation.type }} {{ relation }} + add {% for column in add_columns %}"{{ column.name }}" {{ column.data_type }}{{ ', ' if not loop.last }}{% endfor %}; + {% endif %} + + {% if remove_columns %} + alter {{ relation.type }} {{ relation }} + drop column {% for column in remove_columns %}"{{ column.name }}"{{ ',' if not loop.last }}{% endfor %}; + {% endif %} + {%- endcall -%} +{% endmacro %} + {% macro sqlserver__alter_column_type(relation, column_name, new_column_type) %} {%- set tmp_column = column_name + "__dbt_alter" -%} diff --git a/dbt/include/sqlserver/macros/adapter/metadata.sql b/dbt/include/sqlserver/macros/adapter/metadata.sql index ac8981c9..d7e65846 100644 --- a/dbt/include/sqlserver/macros/adapter/metadata.sql +++ b/dbt/include/sqlserver/macros/adapter/metadata.sql @@ -1,8 +1,106 @@ -{% macro apply_label() %} - {{ log (config.get('query_tag','dbt-sqlserver'))}} - {%- set query_label = config.get('query_tag','dbt-sqlserver') -%} - OPTION (LABEL = '{{query_label}}'); +{# -- apply_label: no-op for standard SQL Server (OPTION LABEL is Synapse-only) -- #} +{% macro apply_label() %}{% endmacro %} + +{# -- Dispatch wrappers (previously provided by dbt-fabric) -- #} +{% macro information_schema_hints() %} + {{ return(adapter.dispatch('information_schema_hints')()) }} {% endmacro %} {% macro default__information_schema_hints() %}{% endmacro %} {% macro sqlserver__information_schema_hints() 
%}with (nolock){% endmacro %} + +{% macro get_use_database_sql(database) %} + {{ return(adapter.dispatch('get_use_database_sql', 'dbt')(database)) }} +{% endmacro %} + +{%- macro sqlserver__get_use_database_sql(database) -%} + USE [{{database | replace('"', '')}}]; +{%- endmacro -%} + +{% macro sqlserver__information_schema_name(database) -%} + information_schema +{%- endmacro %} + +{% macro sqlserver__list_schemas(database) %} + {% call statement('list_schemas', fetch_result=True, auto_begin=False) -%} + {{ get_use_database_sql(database) }} + select name as [schema] + from sys.schemas {{ information_schema_hints() }} + {% endcall %} + {{ return(load_result('list_schemas').table) }} +{% endmacro %} + +{% macro sqlserver__check_schema_exists(information_schema, schema) -%} + {% call statement('check_schema_exists', fetch_result=True, auto_begin=False) -%} + SELECT count(*) as schema_exist FROM sys.schemas WHERE name = '{{ schema }}' + {%- endcall %} + {{ return(load_result('check_schema_exists').table) }} +{% endmacro %} + +{% macro sqlserver__list_relations_without_caching(schema_relation) -%} + {% call statement('list_relations_without_caching', fetch_result=True) -%} + {{ get_use_database_sql(schema_relation.database) }} + with base as ( + select + DB_NAME() as [database], + t.name as [name], + SCHEMA_NAME(t.schema_id) as [schema], + 'table' as table_type + from sys.tables as t {{ information_schema_hints() }} + union all + select + DB_NAME() as [database], + v.name as [name], + SCHEMA_NAME(v.schema_id) as [schema], + 'view' as table_type + from sys.views as v {{ information_schema_hints() }} + ) + select * from base + where [schema] like '{{ schema_relation.schema }}' + {% endcall %} + {{ return(load_result('list_relations_without_caching').table) }} +{% endmacro %} + +{% macro sqlserver__get_relation_without_caching(schema_relation) -%} + {% call statement('get_relation_without_caching', fetch_result=True) -%} + {{ get_use_database_sql(schema_relation.database) 
}} + with base as ( + select + DB_NAME() as [database], + t.name as [name], + SCHEMA_NAME(t.schema_id) as [schema], + 'table' as table_type + from sys.tables as t {{ information_schema_hints() }} + union all + select + DB_NAME() as [database], + v.name as [name], + SCHEMA_NAME(v.schema_id) as [schema], + 'view' as table_type + from sys.views as v {{ information_schema_hints() }} + ) + select * from base + where [schema] like '{{ schema_relation.schema }}' + and [name] like '{{ schema_relation.identifier }}' + {% endcall %} + {{ return(load_result('get_relation_without_caching').table) }} +{% endmacro %} + +{% macro sqlserver__get_relation_last_modified(information_schema, relations) -%} + {%- call statement('last_modified', fetch_result=True) -%} + select + o.name as [identifier] + , s.name as [schema] + , o.modify_date as last_modified + , current_timestamp as snapshotted_at + from sys.objects o + inner join sys.schemas s on o.schema_id = s.schema_id and [type] = 'U' + where ( + {%- for relation in relations -%} + (upper(s.name) = upper('{{ relation.schema }}') and + upper(o.name) = upper('{{ relation.identifier }}')){%- if not loop.last %} or {% endif -%} + {%- endfor -%} + ) + {%- endcall -%} + {{ return(load_result('last_modified')) }} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/relation.sql b/dbt/include/sqlserver/macros/adapter/relation.sql index 57defbd1..65030df7 100644 --- a/dbt/include/sqlserver/macros/adapter/relation.sql +++ b/dbt/include/sqlserver/macros/adapter/relation.sql @@ -3,3 +3,53 @@ truncate table {{ relation }} {%- endcall %} {% endmacro %} + +{% macro sqlserver__make_temp_relation(base_relation, suffix='__dbt_temp') %} + {%- set temp_identifier = base_relation.identifier ~ suffix -%} + {%- set temp_relation = base_relation.incorporate( + path={"identifier": temp_identifier}) -%} + + {{ return(temp_relation) }} +{% endmacro %} + +{% macro sqlserver__get_drop_sql(relation) -%} + {% if relation.type == 'view' -%} + {% call 
statement('find_references', fetch_result=true) %} + {{ get_use_database_sql(relation.database) }} + select + sch.name as schema_name, + obj.name as view_name + from sys.sql_expression_dependencies refs + inner join sys.objects obj + on refs.referencing_id = obj.object_id + inner join sys.schemas sch + on obj.schema_id = sch.schema_id + where refs.referenced_database_name = '{{ relation.database }}' + and refs.referenced_schema_name = '{{ relation.schema }}' + and refs.referenced_entity_name = '{{ relation.identifier }}' + and refs.referencing_class = 1 + and obj.type = 'V' + {% endcall %} + {% set references = load_result('find_references')['data'] %} + {% for reference in references -%} + -- dropping referenced view {{ reference[0] }}.{{ reference[1] }} + {% do adapter.drop_relation + (api.Relation.create( + identifier = reference[1], schema = reference[0], database = relation.database, type='view' + ))%} + {% endfor %} + {% elif relation.type == 'table'%} + {% set object_id_type = 'U' %} + {%- else -%} + {{ exceptions.raise_not_implemented('Invalid relation being dropped: ' ~ relation) }} + {% endif %} + {{ get_use_database_sql(relation.database) }} + EXEC('DROP {{ relation.type }} IF EXISTS {{ relation.include(database=False) }};'); +{% endmacro %} + +{% macro sqlserver__rename_relation(from_relation, to_relation) -%} + {% call statement('rename_relation') -%} + {{ get_use_database_sql(from_relation.database) }} + EXEC sp_rename '{{ from_relation.schema }}.{{ from_relation.identifier }}', '{{ to_relation.identifier }}' + {%- endcall %} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/schemas.sql b/dbt/include/sqlserver/macros/adapter/schemas.sql index 8317d6cb..2b496da8 100644 --- a/dbt/include/sqlserver/macros/adapter/schemas.sql +++ b/dbt/include/sqlserver/macros/adapter/schemas.sql @@ -1,5 +1,42 @@ - {% macro sqlserver__drop_schema_named(schema_name) %} {% set schema_relation = api.Relation.create(schema=schema_name) %} {{ 
adapter.drop_schema(schema_relation) }} {% endmacro %} + +{% macro sqlserver__create_schema(relation) -%} + {% call statement('create_schema') -%} + USE [{{ relation.database }}]; + IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') + BEGIN + EXEC('CREATE SCHEMA [{{ relation.schema }}]') + END + {% endcall %} +{% endmacro %} + +{% macro sqlserver__create_schema_with_authorization(relation, schema_authorization) -%} + {% call statement('create_schema') -%} + {{ get_use_database_sql(relation.database) }} + IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') + BEGIN + EXEC('CREATE SCHEMA [{{ relation.schema }}] AUTHORIZATION [{{ schema_authorization }}]') + END + {% endcall %} +{% endmacro %} + +{% macro sqlserver__drop_schema(relation) -%} + {%- set relations_in_schema = list_relations_without_caching(relation) %} + + {% for row in relations_in_schema %} + {%- set schema_relation = api.Relation.create(database=relation.database, + schema=relation.schema, + identifier=row[1], + type=row[3] + ) -%} + {% do adapter.drop_relation(schema_relation) %} + {%- endfor %} + + {% call statement('drop_schema') -%} + {{ get_use_database_sql(relation.database) }} + EXEC('DROP SCHEMA IF EXISTS {{ relation.schema }}') + {% endcall %} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/show.sql b/dbt/include/sqlserver/macros/adapter/show.sql new file mode 100644 index 00000000..14337a73 --- /dev/null +++ b/dbt/include/sqlserver/macros/adapter/show.sql @@ -0,0 +1,12 @@ +{% macro sqlserver__get_limit_sql(sql, limit) %} + {%- if limit == -1 or limit is none -%} + {{ sql }} + {#- Special processing if the last non-blank line starts with order by -#} + {%- elif sql.strip().splitlines()[-1].strip().lower().startswith('order by') -%} + {{ sql }} + offset 0 rows fetch first {{ limit }} rows only + {%- else -%} + {{ sql }} + order by (select null) offset 0 rows fetch first {{ limit }} rows only + {%- endif -%} +{% endmacro %} 
{# -- file: dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql (new file) -- #}

{% macro sqlserver__get_incremental_default_sql(arg_dict) %}
    {# Default incremental strategy: delete+insert when a unique_key is
       configured, plain append otherwise. #}
    {% if arg_dict["unique_key"] %}
        -- Delete + Insert Strategy
        {% do return(get_incremental_delete_insert_sql(arg_dict)) %}
    {% else %}
        -- Incremental Append will insert data into target table.
        {% do return(get_incremental_append_sql(arg_dict)) %}
    {% endif %}

{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql -- #}

{% macro sqlserver__get_merge_sql(target, source, unique_key, dest_columns, incremental_predicates=none) %}
    {# Delegates to the default merge SQL; the trailing semicolon terminates
       the MERGE statement, which T-SQL requires. #}
    {{ default__get_merge_sql(target, source, unique_key, dest_columns, incremental_predicates) }};
{% endmacro %}

{% macro sqlserver__get_insert_overwrite_merge_sql(target, source, dest_columns, predicates, include_sql_header) %}
    {# Delegates to the default insert-overwrite merge; semicolon-terminated
       for the same reason as above. #}
    {{ default__get_insert_overwrite_merge_sql(target, source, dest_columns, predicates, include_sql_header) }};
{% endmacro %}

{% macro sqlserver__get_delete_insert_merge_sql(target, source, unique_key, dest_columns, incremental_predicates=none) %}
    {# Emulates merge as DELETE matching rows, then INSERT everything from
       source. unique_key may be a single column expression or a list. #}
    {%- set dest_cols_csv = get_quoted_csv(dest_columns | map(attribute="name")) -%}

    {% if unique_key %}
        {% if unique_key is sequence and unique_key is not string %}
            {# composite key: correlated EXISTS over every key column #}
            delete from {{ target }}
            where exists (
                select null
                from {{ source }}
                where
                {% for key in unique_key %}
                    {{ source }}.{{ key }} = {{ target }}.{{ key }}
                    {{ "and " if not loop.last }}
                {% endfor %}
            )
            {% if incremental_predicates %}
                {% for predicate in incremental_predicates %}
                    and {{ predicate }}
                {% endfor %}
            {% endif %}
        {% else %}
            {# scalar key: simple IN-subquery #}
            delete from {{ target }}
            where (
                {{ unique_key }}) in (
                select ({{ unique_key }})
                from {{ source }}
            )
            {%- if incremental_predicates %}
                {% for predicate in incremental_predicates %}
                    and {{ predicate }}
                {% endfor %}
            {%- endif -%}
        {% endif %}
    {% endif %}

    insert into {{ target }} ({{ dest_cols_csv }})
    (
        select {{ dest_cols_csv }}
        from {{ source }}
    )
{% endmacro %}

{% macro sqlserver__get_incremental_microbatch_sql(arg_dict) %}
    {%- set target = arg_dict["target_relation"] -%}
    {%- set source = arg_dict["temp_relation"] -%}
    {# NOTE(review): remainder of this macro is unchanged context outside this
       diff hunk and is not visible here. #}

{# -- file: dbt/include/sqlserver/macros/materializations/models/table/columns_spec_ddl.sql (new file) -- #}

{% macro build_columns_constraints(relation) %}
    {# Dispatch shim so adapters can override column-constraint DDL. #}
    {{ return(adapter.dispatch('build_columns_constraints', 'dbt')(relation)) }}
{% endmacro %}

{% macro sqlserver__build_columns_constraints(relation) %}
    {# loop through user_provided_columns to create DDL with data types and constraints #}
    {%- set raw_column_constraints = adapter.render_raw_columns_constraints(raw_columns=model['columns']) -%}
    (
        {% for c in raw_column_constraints -%}
            {{ c }}{{ "," if not loop.last }}
        {% endfor %}
    )
{% endmacro %}

{% macro build_model_constraints(relation) %}
    {# Dispatch shim so adapters can override model-constraint DDL. #}
    {{ return(adapter.dispatch('build_model_constraints', 'dbt')(relation)) }}
{% endmacro %}

{% macro sqlserver__build_model_constraints(relation) %}
    {# loop through user_provided_columns to create DDL with data types and constraints #}
    {# Each model-level constraint is applied via its own ALTER TABLE batch. #}
    {%- set raw_model_constraints = adapter.render_raw_model_constraints(raw_constraints=model['constraints']) -%}
    {% for c in raw_model_constraints -%}
        {% set alter_table_script %}
            alter table {{ relation.include(database=False) }} {{c}};
        {%endset%}
        {% call statement('alter_table_add_constraint') -%}
            {{alter_table_script}}
        {%- endcall %}
    {% endfor -%}
{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/materializations/snapshot/helpers.sql -- #}

{% macro sqlserver__post_snapshot(staging_relation) %}
    -- Clean up the snapshot temp table
    {% do drop_relation_if_exists(staging_relation) %}
{% endmacro %}

{# NOTE(review): sqlserver__create_columns is pre-existing context around this
   hunk; its body is not fully visible here and is left untouched. #}

{% macro sqlserver__get_true_sql() %}
    {# T-SQL has no boolean literal; 1=1 is the conventional always-true predicate. #}
    {{ return('1=1') }}
{% endmacro %}

{% macro sqlserver__build_snapshot_table(strategy, relation) %}
    {# Initial snapshot build: source rows plus the dbt SCD bookkeeping columns. #}
    {% set columns = config.get('snapshot_table_column_names') or get_snapshot_table_column_names() %}
    select *,
        {{ strategy.scd_id }} as {{ columns.dbt_scd_id }},
        {{ strategy.updated_at }} as {{ columns.dbt_updated_at }},
        {{ strategy.updated_at }} as {{ columns.dbt_valid_from }},
        {{ get_dbt_valid_to_current(strategy, columns) }}
        {%- if strategy.hard_deletes == 'new_record' -%}
        , 'False' as {{ columns.dbt_is_deleted }}
        {% endif -%}
    from (
        select * from {{ relation }}
    ) sbq

{% endmacro %}

{% macro sqlserver__snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) -%}
    {# Builds the staging query that classifies each source row as an
       insert / update / delete (and, for hard_deletes == 'new_record',
       a deletion record) against the currently-valid snapshot rows. #}
    {% set columns = config.get('snapshot_table_column_names') or get_snapshot_table_column_names() %}

    with snapshot_query as (
        select * from {{ temp_snapshot_relation }}
    ),
    snapshotted_data as (
        {# only the currently-valid rows of the snapshot participate #}
        select *,
            {{ unique_key_fields(strategy.unique_key) }}
        from {{ target_relation }}
        where
            {% if config.get('dbt_valid_to_current') %}
                ( {{ columns.dbt_valid_to }} = {{ config.get('dbt_valid_to_current') }} or {{ columns.dbt_valid_to }} is null)
            {% else %}
                {{ columns.dbt_valid_to }} is null
            {% endif %}
    ),
    insertions_source_data as (
        select *,
            {{ unique_key_fields(strategy.unique_key) }},
            {{ strategy.updated_at }} as {{ columns.dbt_updated_at }},
            {{ strategy.updated_at }} as {{ columns.dbt_valid_from }},
            {{ get_dbt_valid_to_current(strategy, columns) }},
            {{ strategy.scd_id }} as {{ columns.dbt_scd_id }}
        from snapshot_query
    ),
    updates_source_data as (
        select *,
            {{ unique_key_fields(strategy.unique_key) }},
            {{ strategy.updated_at }} as {{ columns.dbt_updated_at }},
            {{ strategy.updated_at }} as {{ columns.dbt_valid_from }},
            {{ strategy.updated_at }} as {{ columns.dbt_valid_to }}
        from snapshot_query
    ),
    {%- if strategy.hard_deletes == 'invalidate' or strategy.hard_deletes == 'new_record' %}
    deletes_source_data as (
        select *, {{ unique_key_fields(strategy.unique_key) }}
        from snapshot_query
    ),
    {% endif %}
    insertions as (
        {# brand-new keys, or keys whose row re-appeared with changes #}
        select 'insert' as dbt_change_type, source_data.*
        {%- if strategy.hard_deletes == 'new_record' -%}
            ,'False' as {{ columns.dbt_is_deleted }}
        {%- endif %}
        from insertions_source_data as source_data
        left outer join snapshotted_data
            on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }}
        where {{ unique_key_is_null(strategy.unique_key, "snapshotted_data") }}
            or ({{ unique_key_is_not_null(strategy.unique_key, "snapshotted_data") }} and ({{ strategy.row_changed }}))
    ),
    updates as (
        {# existing keys whose tracked columns changed #}
        select 'update' as dbt_change_type, source_data.*,
            snapshotted_data.{{ columns.dbt_scd_id }}
        {%- if strategy.hard_deletes == 'new_record' -%}
            , snapshotted_data.{{ columns.dbt_is_deleted }}
        {%- endif %}
        from updates_source_data as source_data
        join snapshotted_data
            on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }}
        where ({{ strategy.row_changed }})
    )
    {%- if strategy.hard_deletes == 'invalidate' or strategy.hard_deletes == 'new_record' %}
    ,
    deletes as (
        {# snapshotted keys no longer present in the source #}
        select 'delete' as dbt_change_type,
            source_data.*,
            {{ snapshot_get_time() }} as {{ columns.dbt_valid_from }},
            {{ snapshot_get_time() }} as {{ columns.dbt_updated_at }},
            {{ snapshot_get_time() }} as {{ columns.dbt_valid_to }},
            snapshotted_data.{{ columns.dbt_scd_id }}
        {%- if strategy.hard_deletes == 'new_record' -%}
            , snapshotted_data.{{ columns.dbt_is_deleted }}
        {%- endif %}
        from snapshotted_data
        left join deletes_source_data as source_data
            on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }}
        where {{ unique_key_is_null(strategy.unique_key, "source_data") }}
    )
    {%- endif %}
    {%- if strategy.hard_deletes == 'new_record' %}
    {# 'new_record' additionally inserts a tombstone row flagged dbt_is_deleted #}
    {%set source_query = "select * from "~temp_snapshot_relation%}
    {% set source_sql_cols = get_column_schema_from_query(source_query) %}
    ,
    deletion_records as (

        select
            'insert' as dbt_change_type,
            {%- for col in source_sql_cols -%}
            snapshotted_data.{{ adapter.quote(col.column) }},
            {% endfor -%}
            {%- if strategy.unique_key | is_list -%}
                {%- for key in strategy.unique_key -%}
            snapshotted_data.{{ key }} as dbt_unique_key_{{ loop.index }},
                {% endfor -%}
            {%- else -%}
            snapshotted_data.dbt_unique_key as dbt_unique_key,
            {% endif -%}
            {{ snapshot_get_time() }} as {{ columns.dbt_valid_from }},
            {{ snapshot_get_time() }} as {{ columns.dbt_updated_at }},
            snapshotted_data.{{ columns.dbt_valid_to }} as {{ columns.dbt_valid_to }},
            snapshotted_data.{{ columns.dbt_scd_id }},
            'True' as {{ columns.dbt_is_deleted }}
        from snapshotted_data
        left join deletes_source_data as source_data
            on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }}
        where {{ unique_key_is_null(strategy.unique_key, "source_data") }}
    )
    {%- endif %}
    select * from insertions
    union all
    select * from updates
    {%- if strategy.hard_deletes == 'invalidate' or strategy.hard_deletes == 'new_record' %}
    union all
    select * from deletes
    {%- endif %}
    {%- if strategy.hard_deletes == 'new_record' %}
    union all
    select * from deletion_records
    {%- endif %}

{%- endmacro %}

{% macro build_snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %}
    {% set temp_relation = make_temp_relation(target_relation) %}
    {{ adapter.drop_relation(temp_relation) }}
    {# NOTE(review): remainder of this macro is unchanged context outside this
       diff hunk and is not visible here. #}

{# -- file: dbt/include/sqlserver/macros/materializations/snapshot/strategies.sql (new file) -- #}

{% macro sqlserver__snapshot_hash_arguments(args) %}
    {# MD5 over the pipe-joined args; CONVERT(..., 2) renders hex without '0x'. #}
    CONVERT(VARCHAR(32), HashBytes('MD5', {% for arg in args %}
        coalesce(cast({{ arg }} as varchar(8000)), '') {% if not loop.last %} + '|' + {% endif %}
    {% endfor %}), 2)
{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/materializations/unit_test/unit_test_table.sql (hunk) -- #}
{# Change: pass column_name_to_quoted through to get_expected_sql now that it
   is supported. Surrounding lines are unchanged context. #}
    {% if not expected_sql %}
        {% set expected_sql = get_expected_sql(expected_rows, column_name_to_data_types, column_name_to_quoted) %}
    {% endif %}
    {% set unit_test_sql = get_unit_test_sql(sql, expected_sql, expected_column_names_quoted) %}
{# -- file: dbt/include/sqlserver/macros/relations/seeds/helpers.sql -- #}

{% macro sqlserver__load_csv_rows(model, agate_table) %}
    {#
      ADBC's mssql driver does not support parameterized queries with '?'
      placeholders. Instead, we inline literal values directly into the
      INSERT statements. Strings are escaped by doubling single quotes.
    #}
    {% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %}
    {% set batch_size = calc_batch_size(agate_table.column_names|length) %}
    {% set statements = [] %}

    {{ log("Inserting batches of " ~ batch_size ~ " records") }}

    {% for chunk in agate_table.rows | batch(batch_size) %}
        {% set sql %}
            insert into {{ this.render() }} ({{ cols_sql }})
            values
            {% for row in chunk -%}
                ({%- for value in row -%}
                    {%- if value is none -%}
                        null
                    {%- elif value is sameas true -%}
                        1
                    {%- elif value is sameas false -%}
                        0
                    {%- elif value is number -%}
                        {{ value }}
                    {%- elif value is string -%}
                        '{{ value | replace("'", "''") }}'
                    {%- else -%}
                        {# dates/times and anything else: quote their string form #}
                        '{{ value | replace("'", "''") }}'
                    {%- endif -%}
                    {%- if not loop.last%},{%- endif %}
                {%- endfor -%})
                {%- if not loop.last%},{%- endif %}
            {%- endfor %}
        {% endset %}

        {% do adapter.add_query(sql, abridge_sql_log=True) %}

        {% if loop.index0 == 0 %}
            {% do statements.append(sql) %}
        {% endif %}
    {% endfor %}
    {# NOTE(review): remainder of this macro is unchanged context outside this hunk. #}

{# -- file: dbt/include/sqlserver/macros/utils/any_value.sql (new file) -- #}

{% macro sqlserver__any_value(expression) -%}
    {# T-SQL lacks ANY_VALUE; MIN is a deterministic stand-in. #}
    min({{ expression }})

{%- endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/array_construct.sql (new file) -- #}

{% macro sqlserver__array_construct(inputs, data_type) -%}
    {# Arrays are represented as JSON arrays on SQL Server. #}
    JSON_ARRAY({{ inputs|join(' , ') }})
{%- endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql (new file) -- #}

{% macro sqlserver__cast_bool_to_text(field) %}
    case {{ field }}
        when 1 then 'true'
        when 0 then 'false'
        else null
    end
{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/concat.sql (new file) -- #}

{% macro sqlserver__concat(fields) -%}
    {# Single field: no CONCAT call needed (CONCAT requires >= 2 args on SQL Server). #}
    {%- if fields|length < 2 -%}
    {{ fields[0] }}
    {%- else -%}
    concat({{ fields|join(', ') }})
    {%- endif -%}
{%- endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/date_trunc.sql (new file) -- #}

{% macro sqlserver__date_trunc(datepart, date) %}
    {#
      Truncate via the DATEADD(DATEDIFF(...)) idiom.
      Fix: only cast to DATE for day-or-coarser dateparts; the unconditional
      CAST(... AS DATE) silently discarded the time component when truncating
      to hour/minute/second.
    #}
    {%- if datepart | lower in ['day', 'week', 'month', 'quarter', 'year'] -%}
    CAST(DATEADD({{datepart}}, DATEDIFF({{datepart}}, 0, {{date}}), 0) AS DATE)
    {%- else -%}
    DATEADD({{datepart}}, DATEDIFF({{datepart}}, 0, {{date}}), 0)
    {%- endif -%}
{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/dateadd.sql (new file) -- #}

{% macro sqlserver__dateadd(datepart, interval, from_date_or_timestamp) %}
    {# Cast to datetime2(6) so string literals and dates are accepted uniformly. #}
    dateadd(
        {{ datepart }},
        {{ interval }},
        cast({{ from_date_or_timestamp }} as datetime2(6))
    )

{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/get_tables_by_pattern.sql (new file) -- #}

{% macro sqlserver__get_tables_by_pattern_sql(schema_pattern, table_pattern, exclude='', database=target.database) %}

    select distinct
        table_schema as {{ adapter.quote('table_schema') }},
        table_name as {{ adapter.quote('table_name') }},
        {{ dbt_utils.get_table_types_sql() }}
    from {{ database }}.INFORMATION_SCHEMA.TABLES
    where table_schema like '{{ schema_pattern }}'
        and table_name like '{{ table_pattern }}'
        and table_name not like '{{ exclude }}'

{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/hash.sql (new file) -- #}

{% macro sqlserver__hash(field) %}
    {# MD5 hex digest, lower-cased; CONVERT(..., 2) renders hex without '0x'. #}
    lower(convert(varchar(50), hashbytes('md5', coalesce(convert(varchar(8000), {{field}}), '')), 2))
{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/last_day.sql (new file) -- #}

{% macro sqlserver__last_day(date, datepart) -%}

    {%- if datepart == 'quarter' -%}
    CAST(DATEADD(QUARTER, DATEDIFF(QUARTER, 0, {{ date }}) + 1, -1) AS DATE)
    {%- elif datepart == 'month' -%}
    {# EOMONTH is the built-in for month boundaries #}
    EOMONTH ( {{ date }})
    {%- elif datepart == 'year' -%}
    CAST(DATEADD(YEAR, DATEDIFF(year, 0, {{ date }}) + 1, -1) AS DATE)
    {%- else -%}
    {{dbt_utils.default_last_day(date, datepart)}}
    {%- endif -%}

{%- endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/length.sql (new file) -- #}

{% macro sqlserver__length(expression) %}
    {# NOTE: T-SQL LEN ignores trailing spaces. #}
    len( {{ expression }} )

{%- endmacro -%}

{# -- file: dbt/include/sqlserver/macros/utils/listagg.sql (new file) -- #}

{% macro sqlserver__listagg(measure, delimiter_text, order_by_clause, limit_num) -%}
    {# NOTE(review): limit_num is accepted for cross-database signature
       compatibility but is not applied — STRING_AGG has no native limit. #}
    string_agg({{ measure }}, {{ delimiter_text }})
    {%- if order_by_clause != None %}
    within group ({{ order_by_clause }})
    {%- endif %}

{%- endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/position.sql (new file) -- #}

{% macro sqlserver__position(substring_text, string_text) %}

    CHARINDEX(
        {{ substring_text }},
        {{ string_text }}
    )

{%- endmacro -%}

{# -- file: dbt/include/sqlserver/macros/utils/safe_cast.sql (new file) -- #}

{% macro sqlserver__safe_cast(field, type) %}
    {# TRY_CAST returns NULL instead of raising on conversion failure. #}
    try_cast({{field}} as {{type}})
{% endmacro %}

{# -- file: dbt/include/sqlserver/macros/utils/timestamps.sql (new file) -- #}

{% macro sqlserver__current_timestamp() -%}
    CAST(SYSDATETIME() AS DATETIME2(6))
{%- endmacro %}

{% macro sqlserver__snapshot_string_as_time(timestamp) -%}
    {%- set result = "CONVERT(DATETIME2(6), '" ~ timestamp ~ "')" -%}
    {{ return(result) }}
{%- endmacro %}
# --- setup.py (hunk, context truncated by the diff): install_requires drops
#     "dbt-fabric==1.9.6" and adds "adbc-driver-manager>=1.9.0" and
#     "pyarrow>=20.0.0" alongside the existing dbt-core/common/adapters pins.
# --- tests/__init__.py: entire module deleted (pyodbc/AzureCliCredential
#     tests removed along with the pyodbc connection path).
# --- tests/conftest.py (hunk): the "driver" key is removed from
#     _all_profiles_base(), since the ADBC URI no longer takes an ODBC driver.

# --- tests/unit/adapters/mssql/test_sqlserver_connection_manager.py (post-patch) ---
import pytest

from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials


@pytest.fixture
def credentials() -> SQLServerCredentials:
    """Minimal credentials: no user/password, no explicit port."""
    return SQLServerCredentials(
        host="fake.sql.sqlserver.net",
        database="dbt",
        schema="sqlserver",
    )


@pytest.fixture
def sql_auth_credentials() -> SQLServerCredentials:
    """SQL-auth credentials whose password contains '!' (needs URL-encoding)."""
    return SQLServerCredentials(
        host="127.0.0.1",
        port=1433,
        database="TestDB",
        schema="dbo",
        UID="SA",
        PWD="L0calTesting!",
        authentication="sql",
        encrypt=True,
        trust_cert=True,
    )


class TestBuildAdbcUri:
    """Pins the exact ADBC connection URI built from SQLServerCredentials."""

    def test_basic_sql_auth(self, sql_auth_credentials: SQLServerCredentials) -> None:
        """Full URI: user:pass@host:port plus database/encrypt/trust params."""
        uri = sql_auth_credentials.build_adbc_uri()
        assert uri == (
            "sqlserver://SA:L0calTesting%21@127.0.0.1:1433"
            "?database=TestDB&encrypt=true&TrustServerCertificate=true"
        )

    def test_named_instance_omits_port(self) -> None:
        """A host\\instance name must not also carry an explicit :port."""
        creds = SQLServerCredentials(
            host=r"myserver\SQLEXPRESS",
            database="TestDB",
            schema="dbo",
            UID="SA",
            PWD="test",
        )
        uri = creds.build_adbc_uri()
        assert r"myserver\SQLEXPRESS" in uri
        assert ":1433" not in uri

    def test_special_chars_in_password(self) -> None:
        """URI-reserved characters in the password are percent-encoded."""
        creds = SQLServerCredentials(
            host="localhost",
            port=1433,
            database="TestDB",
            schema="dbo",
            UID="SA",
            PWD="p@ss:word/with+special=chars&more",
            encrypt=False,
            trust_cert=False,
        )
        uri = creds.build_adbc_uri()
        # Password should be URL-encoded
        assert "p%40ss%3Aword%2Fwith%2Bspecial%3Dchars%26more" in uri
        # Original password should NOT appear unencoded
        assert "p@ss:word" not in uri

    def test_encrypt_false(self) -> None:
        """encrypt/trust_cert=False render as lowercase 'false' params."""
        creds = SQLServerCredentials(
            host="localhost",
            port=1433,
            database="TestDB",
            schema="dbo",
            encrypt=False,
            trust_cert=False,
        )
        uri = creds.build_adbc_uri()
        assert "encrypt=false" in uri
        assert "TrustServerCertificate=false" in uri

    def test_login_timeout(self) -> None:
        """A positive login_timeout becomes a 'connection timeout' param."""
        creds = SQLServerCredentials(
            host="localhost",
            port=1433,
            database="TestDB",
            schema="dbo",
            login_timeout=30,
        )
        uri = creds.build_adbc_uri()
        assert "connection timeout=30" in uri

    def test_no_login_timeout_when_zero(self) -> None:
        """login_timeout=0 (the 'unset' value) adds no timeout param."""
        creds = SQLServerCredentials(
            host="localhost",
            port=1433,
            database="TestDB",
            schema="dbo",
            login_timeout=0,
        )
        uri = creds.build_adbc_uri()
        assert "connection timeout" not in uri

    def test_no_user_no_password(self, credentials: SQLServerCredentials) -> None:
        """Without UID/PWD the authority part is just host:port."""
        uri = credentials.build_adbc_uri()
        assert "sqlserver://fake.sql.sqlserver.net:1433?" in uri
        assert "database=dbt" in uri


class TestCredentialProperties:
    """Smoke tests for the dbt credential-contract properties."""

    def test_type(self, credentials: SQLServerCredentials) -> None:
        assert credentials.type == "sqlserver"

    def test_unique_field(self, credentials: SQLServerCredentials) -> None:
        # unique_field is the host, used by dbt for anonymous aggregation
        assert credentials.unique_field == "fake.sql.sqlserver.net"

    def test_connection_keys(self, credentials: SQLServerCredentials) -> None:
        keys = credentials._connection_keys()
        assert "host" in keys
        assert "port" in keys
        assert "database" in keys
        assert "schema" in keys