Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dbt/adapters/sqlserver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
adapter=SQLServerAdapter,
credentials=SQLServerCredentials,
include_path=sqlserver.PACKAGE_PATH,
dependencies=["fabric"],
)

__all__ = [
Expand Down
129 changes: 121 additions & 8 deletions dbt/adapters/sqlserver/sqlserver_adapter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,50 @@
from typing import Optional
from typing import List, Optional

import agate
import dbt.exceptions
from dbt.adapters.base import available
from dbt.adapters.base.impl import ConstraintSupport
from dbt.adapters.fabric import FabricAdapter
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.graph.nodes import ConstraintType
from dbt_common.contracts.constraints import ColumnLevelConstraint

from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn
from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager
from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation

COLUMNS_EQUAL_SQL = """
with diff_count as (
SELECT
1 as id,
COUNT(*) as num_missing FROM (
(SELECT {columns} FROM {relation_a} {except_op}
SELECT {columns} FROM {relation_b})
UNION ALL
(SELECT {columns} FROM {relation_b} {except_op}
SELECT {columns} FROM {relation_a})
) as a
), table_a as (
SELECT COUNT(*) as num_rows FROM {relation_a}
), table_b as (
SELECT COUNT(*) as num_rows FROM {relation_b}
), row_count_diff as (
select
1 as id,
table_a.num_rows - table_b.num_rows as difference
from table_a, table_b
)
select
row_count_diff.difference as row_count_difference,
diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count on row_count_diff.id = diff_count.id
""".strip()

class SQLServerAdapter(FabricAdapter):

class SQLServerAdapter(SQLAdapter):
"""
Controls actual implmentation of adapter, and ability to override certain methods.
Controls actual implementation of adapter, and ability to override certain methods.
"""

ConnectionManager = SQLServerConnectionManager
Expand All @@ -27,6 +59,91 @@ class SQLServerAdapter(FabricAdapter):
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

# -- Type conversions (inlined from FabricAdapter) --

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
return "bit"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
return "datetime2(6)"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float" if decimals else "int"

@classmethod
def convert_text_type(cls, agate_table, col_idx):
column = agate_table.columns[col_idx]
lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()]
max_len = max(lens) if lens else 64
length = max_len if max_len > 16 else 16
return "varchar({})".format(length)

@classmethod
def convert_time_type(cls, agate_table, col_idx):
return "time(6)"

@classmethod
def date_function(cls):
return "getdate()"

# -- SQL helpers (inlined from FabricAdapter) --

def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
return f"DATEADD({interval},{number},{add_to})"

def string_add_sql(self, add_to: str, value: str, location="append") -> str:
"""+ is T-SQL's string concatenation operator"""
if location == "append":
return f"{add_to} + '{value}'"
elif location == "prepend":
return f"'{value}' + {add_to}"
else:
raise ValueError(f'Got an unexpected location value of "{location}"')

def get_rows_different_sql(
self,
relation_a: BaseRelation,
relation_b: BaseRelation,
column_names: Optional[List[str]] = None,
except_operator: str = "EXCEPT",
) -> str:
names: List[str]
if column_names is None:
columns = self.get_columns_in_relation(relation_a)
names = sorted((self.quote(c.name) for c in columns))
else:
names = sorted((self.quote(n) for n in column_names))

columns_csv = ", ".join(names)

sql = COLUMNS_EQUAL_SQL.format(
columns=columns_csv,
relation_a=str(relation_a),
relation_b=str(relation_b),
except_op=except_operator,
)
return sql

# -- Constraint rendering --

@available
@classmethod
def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional[str]:
rendered_column_constraint = None
if constraint.type == ConstraintType.not_null:
rendered_column_constraint = "not null"
else:
rendered_column_constraint = ""

if rendered_column_constraint:
rendered_column_constraint = rendered_column_constraint.strip()

return rendered_column_constraint

@classmethod
def render_model_constraint(cls, constraint) -> Optional[str]:
constraint_prefix = "add constraint "
Expand Down Expand Up @@ -56,10 +173,6 @@ def render_model_constraint(cls, constraint) -> Optional[str]:
else:
return None

@classmethod
def date_function(cls):
return "getdate()"

def valid_incremental_strategies(self):
"""The set of standard builtin strategies which this adapter supports out-of-the-box.
Not used to validate custom strategies defined by end users.
Expand Down
82 changes: 78 additions & 4 deletions dbt/adapters/sqlserver/sqlserver_column.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,96 @@
from dbt.adapters.fabric import FabricColumn
from typing import Any, ClassVar, Dict

from dbt.adapters.base import Column
from dbt_common.exceptions import DbtRuntimeError


class SQLServerColumn(Column):
TYPE_LABELS: ClassVar[Dict[str, str]] = {
"STRING": "VARCHAR(8000)",
"VARCHAR": "VARCHAR(8000)",
"CHAR": "CHAR(1)",
"NCHAR": "CHAR(1)",
"NVARCHAR": "VARCHAR(8000)",
"TIMESTAMP": "DATETIME2(6)",
"DATETIME2": "DATETIME2(6)",
"DATETIME2(6)": "DATETIME2(6)",
"DATE": "DATE",
"TIME": "TIME(6)",
"FLOAT": "FLOAT",
"REAL": "REAL",
"INT": "INT",
"INTEGER": "INT",
"BIGINT": "BIGINT",
"SMALLINT": "SMALLINT",
"TINYINT": "SMALLINT",
"BIT": "BIT",
"BOOLEAN": "BIT",
"DECIMAL": "DECIMAL",
"NUMERIC": "NUMERIC",
"MONEY": "DECIMAL",
"SMALLMONEY": "DECIMAL",
"UNIQUEIDENTIFIER": "UNIQUEIDENTIFIER",
"VARBINARY": "VARBINARY(MAX)",
"BINARY": "BINARY(1)",
}

@classmethod
def string_type(cls, size: int) -> str:
return f"varchar({size if size > 0 else '8000'})"

def literal(self, value: Any) -> str:
return "cast('{}' as {})".format(value, self.data_type)

@property
def data_type(self) -> str:
if self.dtype.lower() == "datetime2":
return "datetime2(6)"
if self.is_string():
return self.string_type(self.string_size())
elif self.is_numeric():
return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale)
else:
return self.dtype

def is_string(self) -> bool:
return self.dtype.lower() in ["varchar", "char"]

def is_number(self):
return any([self.is_integer(), self.is_numeric(), self.is_float()])

def is_float(self):
return self.dtype.lower() in ["float", "real"]

class SQLServerColumn(FabricColumn):
def is_integer(self) -> bool:
return self.dtype.lower() in [
# real types
"smallint",
"integer",
"bigint",
"smallserial",
"serial",
"bigserial",
# aliases
"int2",
"int4",
"int8",
"serial2",
"serial4",
"serial8",
"int",
"tinyint",
]

def is_numeric(self) -> bool:
return self.dtype.lower() in ["numeric", "decimal", "money", "smallmoney"]

def string_size(self) -> int:
if not self.is_string():
raise DbtRuntimeError("Called string_size() on non-string field!")
if self.char_size is None:
return 8000
else:
return int(self.char_size)

def can_expand_to(self, other_column: "SQLServerColumn") -> bool:
if not self.is_string() or not other_column.is_string():
return False
return other_column.string_size() > self.string_size()
4 changes: 2 additions & 2 deletions dbt/adapters/sqlserver/sqlserver_configs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass

from dbt.adapters.fabric import FabricConfigs
from dbt.adapters.protocol import AdapterConfig


@dataclass
class SQLServerConfigs(FabricConfigs):
class SQLServerConfigs(AdapterConfig):
pass
Loading
Loading