33 infer_types ,
44 named_parameters_to_dbsqlparams_v1 ,
55 named_parameters_to_dbsqlparams_v2 ,
6+ calculate_decimal_cast_string ,
7+ DbsqlDynamicDecimalType
68)
79from databricks .sql .thrift_api .TCLIService .ttypes import (
810 TSparkParameter ,
1113from databricks .sql .utils import DbSqlParameter , DbSqlType
1214import pytest
1315
16+ from decimal import Decimal
17+
18+ from typing import List
19+
1420
1521class TestTSparkParameterConversion (object ):
1622 def test_conversion_e2e (self ):
@@ -31,7 +37,7 @@ def test_conversion_e2e(self):
3137 name = "" , type = "FLOAT" , value = TSparkParameterValue (stringValue = "1.0" )
3238 ),
3339 TSparkParameter (
34- name = "" , type = "DECIMAL" , value = TSparkParameterValue (stringValue = "1.0" )
40+ name = "" , type = "DECIMAL(2,1) " , value = TSparkParameterValue (stringValue = "1.0" )
3541 ),
3642 ]
3743
@@ -53,23 +59,64 @@ def test_basic_conversions_v2(self):
5359 DbSqlParameter ("3" , "foo" ),
5460 ]
5561
56- def test_type_inference (self ):
62+ def test_infer_types_none (self ):
5763 with pytest .raises (ValueError ):
5864 infer_types ([DbSqlParameter ("" , None )])
65+
66+ def test_infer_types_dict (self ):
5967 with pytest .raises (ValueError ):
6068 infer_types ([DbSqlParameter ("" , {1 : 1 })])
61- assert infer_types ([DbSqlParameter ("" , 1 )]) == [
62- DbSqlParameter ("" , "1" , DbSqlType .INTEGER )
63- ]
64- assert infer_types ([DbSqlParameter ("" , True )]) == [
65- DbSqlParameter ("" , "True" , DbSqlType .BOOLEAN )
66- ]
67- assert infer_types ([DbSqlParameter ("" , 1.0 )]) == [
68- DbSqlParameter ("" , "1.0" , DbSqlType .FLOAT )
69- ]
70- assert infer_types ([DbSqlParameter ("" , "foo" )]) == [
71- DbSqlParameter ("" , "foo" , DbSqlType .STRING )
72- ]
73- assert infer_types ([DbSqlParameter ("" , 1.0 , DbSqlType .DECIMAL )]) == [
74- DbSqlParameter ("" , "1.0" , DbSqlType .DECIMAL )
75- ]
69+
70+ def test_infer_types_integer (self ):
71+ input = DbSqlParameter ("" , 1 )
72+ output = infer_types ([input ])
73+ assert output == [DbSqlParameter ("" , "1" , DbSqlType .INTEGER )]
74+
75+ def test_infer_types_boolean (self ):
76+ input = DbSqlParameter ("" , True )
77+ output = infer_types ([input ])
78+ assert output == [DbSqlParameter ("" , "True" , DbSqlType .BOOLEAN )]
79+
80+ def test_infer_types_float (self ):
81+ input = DbSqlParameter ("" , 1.0 )
82+ output = infer_types ([input ])
83+ assert output == [DbSqlParameter ("" , "1.0" , DbSqlType .FLOAT )]
84+
85+ def test_infer_types_string (self ):
86+ input = DbSqlParameter ("" , "foo" )
87+ output = infer_types ([input ])
88+ assert output == [DbSqlParameter ("" , "foo" , DbSqlType .STRING )]
89+
90+ def test_infer_types_decimal (self ):
91+ # The output decimal will have a dynamically calculated decimal type with a value of DECIMAL(2,1)
92+ input = DbSqlParameter ("" , Decimal ("1.0" ))
93+ output : List [DbSqlParameter ] = infer_types ([input ])
94+
95+ x = output [0 ]
96+
97+ assert x .value == "1.0"
98+ assert isinstance (x .type , DbsqlDynamicDecimalType )
99+ assert x .type .value == "DECIMAL(2,1)"
100+
101+
102+ class TestCalculateDecimalCast (object ):
103+
104+ def test_38_38 (self ):
105+ input = Decimal (".12345678912345678912345678912345678912" )
106+ output = calculate_decimal_cast_string (input )
107+ assert output == "DECIMAL(38,38)"
108+
109+ def test_18_9 (self ):
110+ input = Decimal ("123456789.123456789" )
111+ output = calculate_decimal_cast_string (input )
112+ assert output == "DECIMAL(18,9)"
113+
114+ def test_38_0 (self ):
115+ input = Decimal ("12345678912345678912345678912345678912" )
116+ output = calculate_decimal_cast_string (input )
117+ assert output == "DECIMAL(38,0)"
118+
119+ def test_6_2 (self ):
120+ input = Decimal ("1234.56" )
121+ output = calculate_decimal_cast_string (input )
122+ assert output == "DECIMAL(6,2)"
0 commit comments