@@ -73,6 +73,37 @@ struct HugeintSumOperation : public BaseSumOperation<SumSetOperation, HugeintAdd
7373 }
7474};
7575
76+ template <class T >
77+ static LogicalType GetValueLogicalType ();
78+
79+ template <>
80+ LogicalType GetValueLogicalType<int64_t >() {
81+ return LogicalType::BIGINT;
82+ }
83+ template <>
84+ LogicalType GetValueLogicalType<hugeint_t >() {
85+ return LogicalType::HUGEINT;
86+ }
87+ template <>
88+ LogicalType GetValueLogicalType<double >() {
89+ return LogicalType::DOUBLE;
90+ }
91+
92+ template <class T >
93+ LogicalType GetSumStateType (const AggregateFunction &function) {
94+ child_list_t <LogicalType> child_types;
95+ child_types.emplace_back (" isset" , LogicalType::BOOLEAN);
96+
97+ LogicalType value_type = GetValueLogicalType<T>();
98+ // Use the return type when its physical representation matches the state type
99+ if (function.return_type .InternalType () == value_type.InternalType ()) {
100+ value_type = function.return_type ;
101+ }
102+ child_types.emplace_back (" value" , value_type);
103+
104+ return LogicalType::STRUCT (std::move (child_types));
105+ }
106+
76107unique_ptr<FunctionData> SumNoOverflowBind (ClientContext &context, AggregateFunction &function,
77108 vector<unique_ptr<Expression>> &arguments) {
78109 throw BinderException (" sum_no_overflow is for internal use only!" );
@@ -98,7 +129,7 @@ AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) {
98129 function.SetBindCallback (SumNoOverflowBind);
99130 function.SetSerializeCallback (SumNoOverflowSerialize);
100131 function.SetDeserializeCallback (SumNoOverflowDeserialize);
101- return function;
132+ return function. SetStructStateExport (GetSumStateType< int64_t >) ;
102133 }
103134 case PhysicalType::INT64: {
104135 auto function = AggregateFunction::UnaryAggregate<SumState<int64_t >, int64_t , hugeint_t , IntegerSumOperation>(
@@ -108,7 +139,7 @@ AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) {
108139 function.SetBindCallback (SumNoOverflowBind);
109140 function.SetSerializeCallback (SumNoOverflowSerialize);
110141 function.SetDeserializeCallback (SumNoOverflowDeserialize);
111- return function;
142+ return function. SetStructStateExport (GetSumStateType< int64_t >) ;
112143 }
113144 default :
114145 throw BinderException (" Unsupported internal type for sum_no_overflow" );
@@ -164,13 +195,13 @@ AggregateFunction GetSumAggregate(PhysicalType type) {
164195 auto function = AggregateFunction::UnaryAggregate<SumState<int64_t >, bool , hugeint_t , IntegerSumOperation>(
165196 LogicalType::BOOLEAN, LogicalType::HUGEINT);
166197 function.SetOrderDependent (AggregateOrderDependent::NOT_ORDER_DEPENDENT);
167- return function;
198+ return function. SetStructStateExport (GetSumStateType< int64_t >) ;
168199 }
169200 case PhysicalType::INT16: {
170201 auto function = AggregateFunction::UnaryAggregate<SumState<int64_t >, int16_t , hugeint_t , IntegerSumOperation>(
171202 LogicalType::SMALLINT, LogicalType::HUGEINT);
172203 function.SetOrderDependent (AggregateOrderDependent::NOT_ORDER_DEPENDENT);
173- return function;
204+ return function. SetStructStateExport (GetSumStateType< int64_t >) ;
174205 }
175206
176207 case PhysicalType::INT32: {
@@ -179,22 +210,22 @@ AggregateFunction GetSumAggregate(PhysicalType type) {
179210 LogicalType::INTEGER, LogicalType::HUGEINT);
180211 function.SetStatisticsCallback (SumPropagateStats);
181212 function.SetOrderDependent (AggregateOrderDependent::NOT_ORDER_DEPENDENT);
182- return function;
213+ return function. SetStructStateExport (GetSumStateType< hugeint_t >) ;
183214 }
184215 case PhysicalType::INT64: {
185216 auto function =
186217 AggregateFunction::UnaryAggregate<SumState<hugeint_t >, int64_t , hugeint_t , SumToHugeintOperation>(
187218 LogicalType::BIGINT, LogicalType::HUGEINT);
188219 function.SetStatisticsCallback (SumPropagateStats);
189220 function.SetOrderDependent (AggregateOrderDependent::NOT_ORDER_DEPENDENT);
190- return function;
221+ return function. SetStructStateExport (GetSumStateType< hugeint_t >) ;
191222 }
192223 case PhysicalType::INT128: {
193224 auto function =
194225 AggregateFunction::UnaryAggregate<SumState<hugeint_t >, hugeint_t , hugeint_t , HugeintSumOperation>(
195226 LogicalType::HUGEINT, LogicalType::HUGEINT);
196227 function.SetOrderDependent (AggregateOrderDependent::NOT_ORDER_DEPENDENT);
197- return function;
228+ return function. SetStructStateExport (GetSumStateType< hugeint_t >) ;
198229 }
199230 default :
200231 throw InternalException (" Unimplemented sum aggregate" );
@@ -283,7 +314,8 @@ AggregateFunctionSet SumFun::GetFunctions() {
283314 sum.AddFunction (GetSumAggregate (PhysicalType::INT64));
284315 sum.AddFunction (GetSumAggregate (PhysicalType::INT128));
285316 sum.AddFunction (AggregateFunction::UnaryAggregate<SumState<double >, double , double , NumericSumOperation>(
286- LogicalType::DOUBLE, LogicalType::DOUBLE));
317+ LogicalType::DOUBLE, LogicalType::DOUBLE)
318+ .SetStructStateExport (GetSumStateType<double >));
287319 sum.AddFunction (AggregateFunction::UnaryAggregate<BignumState, bignum_t , bignum_t , BignumOperation>(
288320 LogicalType::BIGNUM, LogicalType::BIGNUM));
289321 return sum;
@@ -301,9 +333,18 @@ AggregateFunctionSet SumNoOverflowFun::GetFunctions() {
301333 return sum_no_overflow;
302334}
303335
336+ LogicalType GetKahanSumStateType (const AggregateFunction &function) {
337+ child_list_t <LogicalType> children;
338+ children.emplace_back (" isset" , LogicalType::BOOLEAN);
339+ children.emplace_back (" value" , LogicalType::DOUBLE);
340+ children.emplace_back (" err" , LogicalType::DOUBLE);
341+ return LogicalType::STRUCT (std::move (children));
342+ }
343+
304344AggregateFunction KahanSumFun::GetFunction () {
305345 return AggregateFunction::UnaryAggregate<KahanSumState, double , double , KahanSumOperation>(LogicalType::DOUBLE,
306- LogicalType::DOUBLE);
346+ LogicalType::DOUBLE)
347+ .SetStructStateExport (GetKahanSumStateType);
307348}
308349
309350} // namespace duckdb
0 commit comments