From e2b6268caecaa14c10d09cf6df1960cce050a38a Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Sun, 1 Feb 2026 14:35:22 +0800 Subject: [PATCH 1/7] Add correlation and covariance aggregate functions to the query engine --- .../BuiltinAggregationFunctionEnum.java | 5 +- .../aggregation/AccumulatorFactory.java | 18 ++ .../aggregation/CorrelationAccumulator.java | 249 ++++++++++++++++ .../SlidingWindowAggregatorFactory.java | 3 + .../aggregation/AccumulatorFactory.java | 36 +++ .../TableCorrelationAccumulator.java | 269 +++++++++++++++++ .../GroupedCorrelationAccumulator.java | 272 ++++++++++++++++++ .../plan/analyze/ExpressionTypeAnalyzer.java | 20 ++ .../queryengine/plan/parser/ASTVisitor.java | 3 + .../plan/parameter/AggregationDescriptor.java | 9 + .../analyzer/ExpressionAnalyzer.java | 16 ++ .../metadata/TableMetadataImpl.java | 20 ++ .../apache/iotdb/db/utils/SchemaUtils.java | 21 ++ .../iotdb/db/utils/TypeInferenceUtils.java | 13 +- .../iotdb/db/utils/constant/SqlConstant.java | 3 + .../builtin/BuiltinAggregationFunction.java | 11 +- .../TableBuiltinAggregationFunction.java | 8 +- .../src/main/thrift/common.thrift | 5 +- 18 files changed, 976 insertions(+), 5 deletions(-) create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java index 7c2c283d30e3a..e45af9710f83c 100644 --- 
a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java @@ -42,7 +42,10 @@ public enum BuiltinAggregationFunctionEnum { AVG("avg"), SUM("sum"), MAX_BY("max_by"), - MIN_BY("min_by"); + MIN_BY("min_by"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"); private final String functionName; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java index 24a998f54a917..5976f4f1dd53f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java @@ -69,6 +69,9 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) switch (aggregationType) { case MAX_BY: case MIN_BY: + case CORR: + case COVAR_POP: + case COVAR_SAMP: return true; default: return false; @@ -84,6 +87,21 @@ public static Accumulator createBuiltinMultiInputAccumulator( case MIN_BY: checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); return new MinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1)); + case CORR: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.COVAR_POP); + case COVAR_SAMP: + checkState(inputDataTypes.size() == 2, 
"Wrong inputDataTypes size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.COVAR_SAMP); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java new file mode 100644 index 0000000000000..d5f77c017084b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import static com.google.common.base.Preconditions.checkArgument; + +import java.nio.ByteBuffer; +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +public class CorrelationAccumulator implements Accumulator { + + public enum CorrelationType { + CORR, + COVAR_POP, + COVAR_SAMP + } + + private final TSDataType[] seriesDataTypes; + private final CorrelationType correlationType; + + private long count; + private double meanX; + private double meanY; + private double m2X; // sum((x - meanX)^2) + private double m2Y; // sum((y - meanY)^2) + private double c2; // sum((x - meanX) * (y - meanY)) + + public CorrelationAccumulator(TSDataType[] seriesDataTypes, CorrelationType correlationType) { + this.seriesDataTypes = seriesDataTypes; + this.correlationType = correlationType; + } + + @Override + public void addInput(Column[] columns, BitMap bitMap) { + // columns[0] is time column + // columns[1] and columns[2] are the two data columns + int size = columns[0].getPositionCount(); + for (int i = 0; i < size; i++) { + if (bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i) || columns[2].isNull(i)) { + continue; + } + + double x = getDoubleValue(columns[1], i, seriesDataTypes[0]); + double y = getDoubleValue(columns[2], i, seriesDataTypes[1]); + + update(x, y); + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + return column.getInt(position); + case INT64: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new IllegalArgumentException("Unsupported data type: 
" + dataType); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + + meanX += deltaX / newCount; + meanY += deltaY / newCount; + + // Welford's algorithm for covariance and variance + // C2_new = C2_old + (x - meanX_old) * (y - meanY_new) + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + m2Y += deltaY * (y - meanY); + + count = newCount; + } + + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of Correlation should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + + private void merge( + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + m2Y = otherM2Y; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + // Merge formulas + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, 
"partialResult of Correlation should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(m2Y); + buffer.putDouble(c2); + columnBuilders[0].writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + switch (correlationType) { + case CORR: + if (count < 2) { + // Not enough data to calculate correlation + columnBuilder.appendNull(); + } else if (m2X == 0 || m2Y == 0) { + // If either variable has zero variance (all values the same), correlation is 0 + columnBuilder.writeDouble(0.0); + } else { + columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y)); + } + break; + case COVAR_POP: + if (count == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / count); + } + break; + case COVAR_SAMP: + if (count < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / (count - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + } + + @Override + public void removeIntermediate(Column[] input) { + // Optional: sliding window logic implementation if needed, otherwise throw exception + throw new UnsupportedOperationException("Remove not implemented for Correlation"); + } + + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void setFinal(Column finalResult) { + // No-op for this accumulator typically + } + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + m2Y = 0; + c2 = 0; + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + + 
@Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java index 572d41d518486..38e072da41705 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java @@ -200,6 +200,9 @@ public static SlidingWindowAggregator createSlidingWindowAggregator( case VARIANCE: case VAR_POP: case VAR_SAMP: + case CORR: + case COVAR_POP: + case COVAR_SAMP: case UDAF: // Currently UDAF belongs to SmoothQueueSlidingWindowAggregator return new SmoothQueueSlidingWindowAggregator(accumulator, inputLocationList, step); case MAX_VALUE: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java index 3ff20974168be..a8715bfade117 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -21,6 +21,7 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; 
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BinaryGroupedApproxMostFrequentAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BlobGroupedApproxMostFrequentAccumulator; @@ -30,6 +31,7 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedApproxCountDistinctAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCorrelationAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAllAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountIfAccumulator; @@ -256,6 +258,21 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( return new GroupedApproxCountDistinctAccumulator(inputDataTypes.get(0)); case APPROX_MOST_FREQUENT: return getGroupedApproxMostFrequentAccumulator(inputDataTypes.get(0)); + case CORR: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_POP); + case COVAR_SAMP: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_SAMP); default: throw new IllegalArgumentException("Invalid Aggregation function: " + 
aggregationType); } @@ -325,6 +342,21 @@ public static TableAccumulator createBuiltinAccumulator( return new ApproxCountDistinctAccumulator(inputDataTypes.get(0)); case APPROX_MOST_FREQUENT: return getApproxMostFrequentAccumulator(inputDataTypes.get(0)); + case CORR: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_POP); + case COVAR_SAMP: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_SAMP); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -385,6 +417,10 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) case MAX_BY: case MIN_BY: return true; + case CORR: + case COVAR_POP: + case COVAR_SAMP: + return true; default: return false; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java new file mode 100644 index 0000000000000..9ab44df27a512 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java @@ -0,0 +1,269 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class TableCorrelationAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TableCorrelationAccumulator.class); + private final TSDataType xDataType; + private final TSDataType yDataType; + private final CorrelationAccumulator.CorrelationType correlationType; + + private long count; + private double meanX; + private double meanY; + private double m2X; // sum((x - meanX)^2) + private double m2Y; // sum((y - meanY)^2) + private double c2; // sum((x - meanX) * (y - meanY)) + + public TableCorrelationAccumulator( + TSDataType xDataType, + TSDataType yDataType, + CorrelationAccumulator.CorrelationType correlationType) { + 
this.xDataType = xDataType; + this.yDataType = yDataType; + this.correlationType = correlationType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new TableCorrelationAccumulator(xDataType, yDataType, correlationType); + } + + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double x = getDoubleValue(arguments[0], i, xDataType); + double y = getDoubleValue(arguments[1], i, yDataType); + update(x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double x = getDoubleValue(arguments[0], position, xDataType); + double y = getDoubleValue(arguments[1], position, yDataType); + update(x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Correlation Aggregation: %s", dataType)); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + + meanX += deltaX / newCount; + meanY += deltaY / newCount; + + // Welford's algorithm for covariance and variance + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + m2Y += deltaY * (y - meanY); + + count = newCount; + } + + @Override + 
public void removeInput(Column[] arguments) { + throw new UnsupportedOperationException("Remove not implemented for Correlation Accumulator"); + } + + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + } + + private void merge( + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + m2Y = otherM2Y; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + // Merge formulas + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (count == 0) { + 
columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(m2Y); + buffer.putDouble(c2); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + switch (correlationType) { + case CORR: + if (count < 2) { + columnBuilder.appendNull(); + } else if (m2X == 0 || m2Y == 0) { + columnBuilder.writeDouble(0.0); + } else { + columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y)); + } + break; + case COVAR_POP: + if (count == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / count); + } + break; + case COVAR_SAMP: + if (count < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / (count - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + m2Y = 0; + c2 = 0; + } + + @Override + public boolean removable() { + return false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java new file mode 100644 index 0000000000000..beb474b97c7b9 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java @@ -0,0 +1,272 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class 
GroupedCorrelationAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedCorrelationAccumulator.class); + private final TSDataType xDataType; + private final TSDataType yDataType; + private final CorrelationAccumulator.CorrelationType correlationType; + + private final LongBigArray counts = new LongBigArray(); + private final DoubleBigArray meanXs = new DoubleBigArray(); + private final DoubleBigArray meanYs = new DoubleBigArray(); + private final DoubleBigArray m2Xs = new DoubleBigArray(); + private final DoubleBigArray m2Ys = new DoubleBigArray(); + private final DoubleBigArray c2s = new DoubleBigArray(); + + public GroupedCorrelationAccumulator( + TSDataType xDataType, + TSDataType yDataType, + CorrelationAccumulator.CorrelationType correlationType) { + this.xDataType = xDataType; + this.yDataType = yDataType; + this.correlationType = correlationType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE + + counts.sizeOf() + + meanXs.sizeOf() + + meanYs.sizeOf() + + m2Xs.sizeOf() + + m2Ys.sizeOf() + + c2s.sizeOf(); + } + + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + meanXs.ensureCapacity(groupCount); + meanYs.ensureCapacity(groupCount); + m2Xs.ensureCapacity(groupCount); + m2Ys.ensureCapacity(groupCount); + c2s.ensureCapacity(groupCount); + } + + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double x = getDoubleValue(arguments[0], i, xDataType); + double y = getDoubleValue(arguments[1], i, yDataType); + update(groupIds[i], x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position 
= selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double x = getDoubleValue(arguments[0], position, xDataType); + double y = getDoubleValue(arguments[1], position, yDataType); + update(groupIds[position], x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Correlation Aggregation: %s", dataType)); + } + } + + private void update(int groupId, double x, double y) { + long newCount = counts.get(groupId) + 1; + double deltaX = x - meanXs.get(groupId); + double deltaY = y - meanYs.get(groupId); + + meanXs.add(groupId, deltaX / newCount); + meanYs.add(groupId, deltaY / newCount); + + // Welford's algorithm for covariance and variance + c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); + m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId))); + m2Ys.add(groupId, deltaY * (y - meanYs.get(groupId))); + + counts.set(groupId, newCount); + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + 
double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + } + + private void merge( + int groupId, + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } + if (counts.get(groupId) == 0) { + counts.set(groupId, otherCount); + meanXs.set(groupId, otherMeanX); + meanYs.set(groupId, otherMeanY); + m2Xs.set(groupId, otherM2X); + m2Ys.set(groupId, otherM2Y); + c2s.set(groupId, otherC2); + } else { + long newCount = counts.get(groupId) + otherCount; + double deltaX = otherMeanX - meanXs.get(groupId); + double deltaY = otherMeanY - meanYs.get(groupId); + + // Merge formulas + c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount); + m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount); + m2Ys.add(groupId, otherM2Y + deltaY * deltaY * counts.get(groupId) * otherCount / newCount); + + meanXs.add(groupId, deltaX * otherCount / newCount); + meanYs.add(groupId, deltaY * otherCount / newCount); + counts.set(groupId, newCount); + } + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(counts.get(groupId)); + buffer.putDouble(meanXs.get(groupId)); + buffer.putDouble(meanYs.get(groupId)); + buffer.putDouble(m2Xs.get(groupId)); + buffer.putDouble(m2Ys.get(groupId)); + buffer.putDouble(c2s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + switch 
(correlationType) { + case CORR: + if (counts.get(groupId) < 2) { + columnBuilder.appendNull(); + } else if (m2Xs.get(groupId) == 0 || m2Ys.get(groupId) == 0) { + columnBuilder.writeDouble(0.0); + } else { + columnBuilder.writeDouble( + c2s.get(groupId) / Math.sqrt(m2Xs.get(groupId) * m2Ys.get(groupId))); + } + break; + case COVAR_POP: + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2s.get(groupId) / counts.get(groupId)); + } + break; + case COVAR_SAMP: + if (counts.get(groupId) < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2s.get(groupId) / (counts.get(groupId) - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + counts.reset(); + meanXs.reset(); + meanYs.reset(); + m2Xs.reset(); + m2Ys.reset(); + c2s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index adc80c7bb1522..286f2fb199971 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -361,6 +361,23 @@ public TSDataType visitFunctionExpression( } if (functionExpression.isBuiltInAggregationFunctionExpression()) { + // Additional type check for multi-input aggregation functions + String funcName = functionExpression.getFunctionName().toLowerCase(); + if (funcName.equals(SqlConstant.CORR) + || funcName.equals(SqlConstant.COVAR_POP) + || funcName.equals(SqlConstant.COVAR_SAMP)) { + // Check both input parameters are numeric + if (inputExpressions.size() >= 2) { + TSDataType secondInputType = 
expressionTypes.get(NodeRef.of(inputExpressions.get(1))); + if (secondInputType != null && !secondInputType.isNumeric()) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]", + functionExpression.getFunctionName().toUpperCase())); + } + } + } + return setExpressionType( functionExpression, TypeInferenceUtils.getBuiltinAggregationDataType( @@ -542,6 +559,9 @@ private TSDataType getInputExpressionTypeForAggregation( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java index 706d14f052cd2..86bcdd5c803d9 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java @@ -3194,6 +3194,9 @@ private void checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.COUNT_IF: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java index ac30dcf505afd..e3737e7cbeb6c 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java @@ -187,6 +187,15 @@ public List getActualAggregationNames(boolean isPartial) { case VAR_SAMP: outputAggregationNames.add(addPartialSuffix(SqlConstant.VAR_SAMP)); break; + case CORR: + outputAggregationNames.add(addPartialSuffix(SqlConstant.CORR)); + break; + case COVAR_POP: + outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_POP)); + break; + case COVAR_SAMP: + outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_SAMP)); + break; case MAX_BY: outputAggregationNames.add(addPartialSuffix(SqlConstant.MAX_BY)); break; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java index 6157612b42382..1df371c9a2141 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java @@ -1066,6 +1066,22 @@ protected Type visitFunctionCall( } } + // Check argument count for specific aggregation functions before calling + // getFunctionReturnType + if (isAggregation) { + String lowerFuncName = functionName.toLowerCase(); + if (lowerFuncName.equals("corr") + || lowerFuncName.equals("covar_pop") + || lowerFuncName.equals("covar_samp")) { + if (argumentTypes.size() != 2) { + throw new SemanticException( + String.format( + "Error size of input expressions. 
expression: %s, actual size: %s, expected size: [2].", + node, argumentTypes.size())); + } + } + } + Type type = metadata.getFunctionReturnType(functionName, argumentTypes); FunctionKind functionKind = FunctionKind.SCALAR; if (isAggregation) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 8934de172e9c5..99ea1c5d30faa 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -647,6 +647,23 @@ && isIntegerNumber(argumentTypes.get(2)))) { "Second argument of Aggregate functions [%s] should be orderable", functionName)); } + break; + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + // Argument count is already checked in ExpressionAnalyzer + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]", + functionName.toUpperCase())); + } else if (!isSupportedMathNumericType(argumentTypes.get(1))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]", + functionName.toUpperCase())); + } + break; case SqlConstant.APPROX_COUNT_DISTINCT: if (argumentTypes.size() != 1 && argumentTypes.size() != 2) { @@ -701,6 +718,9 @@ && isIntegerNumber(argumentTypes.get(2)))) { case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: return DOUBLE; case SqlConstant.APPROX_MOST_FREQUENT: return STRING; diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java index 773de36a067fd..27ab21580fa8a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java @@ -88,6 +88,9 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: return TSDataType.DOUBLE; // Partial aggregation names case SqlConstant.STDDEV + "_partial": @@ -96,6 +99,9 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.VARIANCE + "_partial": case SqlConstant.VAR_POP + "_partial": case SqlConstant.VAR_SAMP + "_partial": + case SqlConstant.CORR + "_partial": + case SqlConstant.COVAR_POP + "_partial": + case SqlConstant.COVAR_SAMP + "_partial": case SqlConstant.MAX_BY + "_partial": case SqlConstant.MIN_BY + "_partial": return TSDataType.TEXT; @@ -163,6 +169,12 @@ public static String getBuiltinAggregationName(TAggregationType aggregationType) return SqlConstant.VAR_POP; case VAR_SAMP: return SqlConstant.VAR_SAMP; + case CORR: + return SqlConstant.CORR; + case COVAR_POP: + return SqlConstant.COVAR_POP; + case COVAR_SAMP: + return SqlConstant.COVAR_SAMP; default: return null; } @@ -198,6 +210,9 @@ public static boolean isConsistentWithScanOrder( case VAR_SAMP: case MAX_BY: case MIN_BY: + case CORR: + case COVAR_POP: + case COVAR_SAMP: case UDAF: return true; default: @@ -232,6 +247,12 @@ public static List splitPartialBuiltinAggregation(TAggregationType aggre return Collections.singletonList(addPartialSuffix(SqlConstant.VAR_POP)); case VAR_SAMP: return Collections.singletonList(addPartialSuffix(SqlConstant.VAR_SAMP)); + case CORR: + return 
Collections.singletonList(addPartialSuffix(SqlConstant.CORR)); + case COVAR_POP: + return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_POP)); + case COVAR_SAMP: + return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_SAMP)); case MAX_BY: return Collections.singletonList(addPartialSuffix(SqlConstant.MAX_BY)); case MIN_BY: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index 8fc1d647dc36e..6a83f59644d30 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -154,6 +154,9 @@ public static TSDataType getBuiltinAggregationDataType( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: return TSDataType.DOUBLE; default: throw new IllegalArgumentException( @@ -186,11 +189,16 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: if (dataType.isNumeric()) { return; } throw new SemanticException( - "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, VARIANCE, VAR_POP, VAR_SAMP] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]"); + "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, " + + "VARIANCE, VAR_POP, VAR_SAMP, CORR, COVAR_POP, COVAR_SAMP] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE]"); case SqlConstant.COUNT: case SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: @@ -245,6 +253,9 @@ public static void bindTypeForBuiltinAggregationNonSeriesInputExpressions( case SqlConstant.VARIANCE: case 
SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java index 8120aff6059ba..49180bb13e2c9 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java @@ -74,6 +74,9 @@ protected SqlConstant() { public static final String VARIANCE = "variance"; public static final String VAR_POP = "var_pop"; public static final String VAR_SAMP = "var_samp"; + public static final String CORR = "corr"; + public static final String COVAR_POP = "covar_pop"; + public static final String COVAR_SAMP = "covar_samp"; public static final String COUNT_TIME = "count_time"; public static final String COUNT_TIME_HEADER = "count_time(*)"; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java index 1c6b25ef53aaf..0f25a1ad6e282 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java @@ -46,7 +46,10 @@ public enum BuiltinAggregationFunction { VAR_POP("var_pop"), VAR_SAMP("var_samp"), MAX_BY("max_by"), - MIN_BY("min_by"); + MIN_BY("min_by"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"); private final String functionName; @@ -97,6 +100,9 @@ public static boolean canUseStatistics(String name) { case "var_samp": case "max_by": case "min_by": + case "corr": + case "covar_pop": + case 
"covar_samp": return false; default: throw new IllegalArgumentException("Invalid Aggregation function: " + name); @@ -131,6 +137,9 @@ public static boolean canSplitToMultiPhases(String name) { case "var_samp": case "max_by": case "min_by": + case "corr": + case "covar_pop": + case "covar_samp": return true; case "count_if": case "count_time": diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 39f7cde84c490..69320e2aa00e5 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -58,7 +58,10 @@ public enum TableBuiltinAggregationFunction { VAR_POP("var_pop"), VAR_SAMP("var_samp"), APPROX_COUNT_DISTINCT("approx_count_distinct"), - APPROX_MOST_FREQUENT("approx_most_frequent"); + APPROX_MOST_FREQUENT("approx_most_frequent"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"); private final String functionName; @@ -103,6 +106,9 @@ public static Type getIntermediateType(String name, List originalArgumentT case "variance": case "var_pop": case "var_samp": + case "corr": + case "covar_pop": + case "covar_samp": case "approx_count_distinct": return RowType.anonymous(Collections.emptyList()); case "extreme": diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 3287f35b9bb92..214d4a6dce4ec 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -293,7 +293,10 @@ enum TAggregationType { MAX, COUNT_ALL, APPROX_COUNT_DISTINCT, - APPROX_MOST_FREQUENT + APPROX_MOST_FREQUENT, 
+ CORR, + COVAR_POP, + COVAR_SAMP } struct TShowConfigurationTemplateResp { From 5212024224be8a80e1d6ce8b92764c387dd22ea3 Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Mon, 2 Feb 2026 17:30:20 +0800 Subject: [PATCH 2/7] Enhance aggregate functions to support TIMESTAMP data type --- .../aggregation/CorrelationAccumulator.java | 8 +++++--- .../plan/analyze/ExpressionTypeAnalyzer.java | 19 ++++++++++++++++--- .../metadata/TableMetadataImpl.java | 10 +++++++--- .../iotdb/db/utils/TypeInferenceUtils.java | 14 ++++++++++---- 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java index d5f77c017084b..722659c6581c8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java @@ -19,9 +19,6 @@ package org.apache.iotdb.db.queryengine.execution.aggregation; -import static com.google.common.base.Preconditions.checkArgument; - -import java.nio.ByteBuffer; import org.apache.tsfile.block.column.Column; import org.apache.tsfile.block.column.ColumnBuilder; import org.apache.tsfile.enums.TSDataType; @@ -29,6 +26,10 @@ import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + public class CorrelationAccumulator implements Accumulator { public enum CorrelationType { @@ -77,6 +78,7 @@ private double getDoubleValue(Column column, int position, TSDataType dataType) case INT32: return column.getInt(position); case INT64: + case TIMESTAMP: return column.getLong(position); case FLOAT: return column.getFloat(position); diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index 286f2fb199971..796a5265c0b3f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -366,13 +366,26 @@ public TSDataType visitFunctionExpression( if (funcName.equals(SqlConstant.CORR) || funcName.equals(SqlConstant.COVAR_POP) || funcName.equals(SqlConstant.COVAR_SAMP)) { - // Check both input parameters are numeric + // Check both input parameters are numeric or timestamp + if (inputExpressions.size() >= 1) { + TSDataType firstInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(0))); + if (firstInputType != null + && !firstInputType.isNumeric() + && firstInputType != TSDataType.TIMESTAMP) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionExpression.getFunctionName().toUpperCase())); + } + } if (inputExpressions.size() >= 2) { TSDataType secondInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(1))); - if (secondInputType != null && !secondInputType.isNumeric()) { + if (secondInputType != null + && !secondInputType.isNumeric() + && secondInputType != TSDataType.TIMESTAMP) { throw new SemanticException( String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]", + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", functionExpression.getFunctionName().toUpperCase())); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 99ea1c5d30faa..6d481dd6662b2 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -655,12 +655,12 @@ && isIntegerNumber(argumentTypes.get(2)))) { if (!isSupportedMathNumericType(argumentTypes.get(0))) { throw new SemanticException( String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]", + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", functionName.toUpperCase())); } else if (!isSupportedMathNumericType(argumentTypes.get(1))) { throw new SemanticException( String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]", + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", functionName.toUpperCase())); } @@ -991,7 +991,11 @@ public static boolean isBool(Type type) { } public static boolean isSupportedMathNumericType(Type type) { - return DOUBLE.equals(type) || FLOAT.equals(type) || INT32.equals(type) || INT64.equals(type); + return DOUBLE.equals(type) + || FLOAT.equals(type) + || INT32.equals(type) + || INT64.equals(type) + || TIMESTAMP.equals(type); } public static boolean isNumericType(Type type) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index 6a83f59644d30..18bba010e8ad8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -189,16 +189,22 @@ private static void 
verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: - case SqlConstant.CORR: - case SqlConstant.COVAR_POP: - case SqlConstant.COVAR_SAMP: if (dataType.isNumeric()) { return; } throw new SemanticException( "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, " - + "VARIANCE, VAR_POP, VAR_SAMP, CORR, COVAR_POP, COVAR_SAMP] only support " + + "VARIANCE, VAR_POP, VAR_SAMP] only support " + "numeric data types [INT32, INT64, FLOAT, DOUBLE]"); + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { + return; + } + throw new SemanticException( + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); case SqlConstant.COUNT: case SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: From f5759b31c9c9dec899a084382af507f522a93f3e Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Tue, 10 Mar 2026 18:36:47 +0800 Subject: [PATCH 3/7] Add regression aggregate functions: REGR_SLOPE and REGR_INTERCEPT --- .../BuiltinAggregationFunctionEnum.java | 4 +- .../aggregation/AccumulatorFactory.java | 12 + .../aggregation/RegressionAccumulator.java | 229 ++++++++++++++++ .../SlidingWindowAggregatorFactory.java | 2 + .../aggregation/AccumulatorFactory.java | 24 ++ .../TableRegressionAccumulator.java | 230 ++++++++++++++++ .../grouped/GroupedRegressionAccumulator.java | 250 ++++++++++++++++++ .../plan/analyze/ExpressionTypeAnalyzer.java | 6 +- .../queryengine/plan/parser/ASTVisitor.java | 2 + .../plan/parameter/AggregationDescriptor.java | 6 + .../analyzer/ExpressionAnalyzer.java | 4 +- .../metadata/TableMetadataImpl.java | 18 +- .../apache/iotdb/db/utils/SchemaUtils.java | 14 + .../iotdb/db/utils/TypeInferenceUtils.java | 12 + .../iotdb/db/utils/constant/SqlConstant.java | 2 + 
.../builtin/BuiltinAggregationFunction.java | 8 +- .../TableBuiltinAggregationFunction.java | 6 +- .../src/main/thrift/common.thrift | 4 +- 18 files changed, 826 insertions(+), 7 deletions(-) create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java index e45af9710f83c..0e758b39f9b53 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java @@ -45,7 +45,9 @@ public enum BuiltinAggregationFunctionEnum { MIN_BY("min_by"), CORR("corr"), COVAR_POP("covar_pop"), - COVAR_SAMP("covar_samp"); + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"); private final String functionName; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java index 5976f4f1dd53f..12f45494ed057 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java @@ -72,6 +72,8 @@ public static boolean 
isMultiInputAggregation(TAggregationType aggregationType) case CORR: case COVAR_POP: case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: return true; default: return false; @@ -102,6 +104,16 @@ public static Accumulator createBuiltinMultiInputAccumulator( return new CorrelationAccumulator( new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, CorrelationAccumulator.CorrelationType.COVAR_SAMP); + case REGR_SLOPE: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new RegressionAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new RegressionAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + RegressionAccumulator.RegressionType.REGR_INTERCEPT); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java new file mode 100644 index 0000000000000..70db4ac33e357 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java @@ -0,0 +1,229 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class RegressionAccumulator implements Accumulator { + + public enum RegressionType { + REGR_SLOPE, + REGR_INTERCEPT + } + + private final TSDataType[] seriesDataTypes; + private final RegressionType regressionType; + + // 状态变量 (不需要 m2Y) + private long count; + private double meanX; + private double meanY; + private double m2X; // Sum((x - meanX)^2) + private double c2; // Sum((x - meanX) * (y - meanY)) + + public RegressionAccumulator(TSDataType[] seriesDataTypes, RegressionType regressionType) { + this.seriesDataTypes = seriesDataTypes; + this.regressionType = regressionType; + } + + @Override + public void addInput(Column[] columns, BitMap bitMap) { + // Tree 模型: columns[0] 是 Time + // REGR_SLOPE(y, x) -> columns[1] 是 y, columns[2] 是 x + + int size = columns[1].getPositionCount(); + for (int i = 0; i < size; i++) { + if (bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i) || columns[2].isNull(i)) { + continue; + } + + double y = getDoubleValue(columns[1], i, seriesDataTypes[0]); // Arg1: Y (因变量) + double x = getDoubleValue(columns[2], i, seriesDataTypes[1]); // Arg2: X (自变量) + + update(x, y); + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return 
column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + /** Folds one (x, y) sample into count, meanX, meanY, c2 and m2X. */ + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + + meanX += deltaX / newCount; + meanY += deltaY / newCount; + + // Welford-style update: delta taken before the mean update times the residual taken after it + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + + count = newCount; + } + /** Merges one serialized partial state, read as a long count followed by meanX, meanY, m2X and c2 as doubles; a null entry is ignored. */ + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of Regression should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherC2); + } + /** Combines another partial state into this one; an empty other state is a no-op and an empty own state simply adopts the other values. */ + private void merge( + long otherCount, double otherMeanX, double otherMeanY, double otherM2X, double otherC2) { + if (otherCount == 0) { + return; + } + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + // Merge logic: add the other partial sums and a correction term for the gap between the two means + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + /** Writes the partial state to the single intermediate column, or null when no sample has been accumulated. */ + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, "partialResult of Regression should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + // Serialize 5 fields:
long(8) + 4 * double(8) = 40 bytes + byte[] bytes = new byte[40]; + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(c2); + columnBuilders[0].writeBinary(new Binary(bytes)); + } + } + /** Writes the slope (c2 / m2X) or the intercept (meanY - slope * meanX) depending on regressionType; writes null when count is 0 or m2X is 0. */ + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + if (count == 0) { + columnBuilder.appendNull(); + return; + } + + // If X has no variation (m2X == 0) the slope cannot be computed (division by zero), so return NULL + if (m2X == 0) { + columnBuilder.appendNull(); + return; + } + + double slope = c2 / m2X; + + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + // Intercept = MeanY - Slope * MeanX + columnBuilder.writeDouble(meanY - slope * meanX); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + + // Remaining interface methods that must be implemented + @Override + public void removeIntermediate(Column[] input) { + throw new UnsupportedOperationException(); + } + /** Not supported: always throws, with this class name as the message. */ + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + /** Intentionally empty: a supplied final result is ignored. */ + @Override + public void setFinal(Column finalResult) {} + /** Clears count, both means, m2X and c2 back to zero. */ + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + c2 = 0; + } + /** Always false. */ + @Override + public boolean hasFinalResult() { + return false; + } + /** The partial state is exposed as a single TEXT column. */ + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + /** The final result is a DOUBLE. */ + @Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java index 38e072da41705..84419a11bedf7 100644 ---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java @@ -203,6 +203,8 @@ public static SlidingWindowAggregator createSlidingWindowAggregator( case CORR: case COVAR_POP: case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: case UDAF: // Currently UDAF belongs to SmoothQueueSlidingWindowAggregator return new SmoothQueueSlidingWindowAggregator(accumulator, inputLocationList, step); case MAX_VALUE: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java index a8715bfade117..eeedb97a28f43 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -22,6 +22,7 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BinaryGroupedApproxMostFrequentAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BlobGroupedApproxMostFrequentAccumulator; @@ -45,6 +46,7 @@ import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedRegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator; @@ -273,6 +275,16 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( inputDataTypes.get(0), inputDataTypes.get(1), CorrelationAccumulator.CorrelationType.COVAR_SAMP); + case REGR_SLOPE: + return new GroupedRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + return new GroupedRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_INTERCEPT); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -357,6 +369,16 @@ public static TableAccumulator createBuiltinAccumulator( inputDataTypes.get(0), inputDataTypes.get(1), CorrelationAccumulator.CorrelationType.COVAR_SAMP); + case REGR_SLOPE: + return new TableRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + return new TableRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_INTERCEPT); default: 
throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -420,6 +442,8 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) case CORR: case COVAR_POP: case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: return true; default: return false; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java new file mode 100644 index 0000000000000..e0dd15b37438b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java @@ -0,0 +1,230 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + /** TableAccumulator implementation backing REGR_SLOPE and REGR_INTERCEPT: keeps a running count, the means of X and Y, m2X and c2, and derives the slope as c2 / m2X. */ +public class TableRegressionAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TableRegressionAccumulator.class); + + private final TSDataType yDataType; + private final TSDataType xDataType; + private final RegressionAccumulator.RegressionType regressionType; + /* Running state: sample count, means of X and Y, m2X (accumulated deltaX times the X residual) and c2 (accumulated deltaX times the Y residual). */ + private long count; + private double meanX; + private double meanY; + private double m2X; + private double c2; + /** Creates an empty accumulator; yDataType describes the first argument (Y) and xDataType the second (X). */ + public TableRegressionAccumulator( + TSDataType yDataType, + TSDataType xDataType, + RegressionAccumulator.RegressionType regressionType) { + this.yDataType = yDataType; + this.xDataType = xDataType; + this.regressionType = regressionType; + } + /** Returns the shallow instance size only. */ + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + /** Returns a new accumulator with the same types and regression kind; the accumulated state is not copied. */ + @Override + public TableAccumulator copy() { + return new TableRegressionAccumulator(yDataType, xDataType, regressionType); + } + /** Feeds every selected row whose Y and X are both non-null into update; rows with a null in either argument are skipped. */ + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + // arguments[0] -> Y, arguments[1] -> X + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) ||
arguments[1].isNull(i)) { + continue; + } + double y = getDoubleValue(arguments[0], i, yDataType); + double x = getDoubleValue(arguments[1], i, xDataType); + update(x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double y = getDoubleValue(arguments[0], position, yDataType); + double x = getDoubleValue(arguments[1], position, xDataType); + update(x, y); + } + } + } + /** Reads the value at position as a double: INT32 and DATE via getInt, INT64 and TIMESTAMP via getLong, FLOAT via getFloat, DOUBLE via getDouble; any other type throws UnSupportedDataTypeException. */ + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Regression Aggregation: %s", dataType)); + } + } + /** Folds one (x, y) sample into the running state; c2 and m2X use the delta taken before the mean update times the residual taken after it. */ + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + meanX += deltaX / newCount; + meanY += deltaY / newCount; + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + count = newCount; + } + /** Merges every non-null serialized partial state in the column; each entry is read as a long count followed by meanX, meanY, m2X and c2 as doubles. */ + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn)); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) continue; + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount,
otherMeanX, otherMeanY, otherM2X, otherC2); + } + } + /** Combines another partial state into this one: an empty other state is a no-op, an empty own state adopts the other values, otherwise the partial sums are added together with a correction term for the gap between the two means. */ + private void merge( + long otherCount, double otherMeanX, double otherMeanY, double otherM2X, double otherC2) { + if (otherCount == 0) return; + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + /** Appends null when count is 0; otherwise writes the state as 40 bytes (an 8-byte count followed by four 8-byte doubles: meanX, meanY, m2X, c2). */ + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + if (count == 0) { + columnBuilder.appendNull(); + } else { + byte[] bytes = new byte[40]; + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(c2); + columnBuilder.writeBinary(new Binary(bytes)); + } + } + /** Appends null when count is 0 or m2X is 0; otherwise writes the slope (c2 / m2X) or the intercept (meanY - slope * meanX) depending on regressionType. */ + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + if (count == 0 || m2X == 0) { + columnBuilder.appendNull(); + return; + } + double slope = c2 / m2X; + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + columnBuilder.writeDouble(meanY - slope * meanX); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + /** Not supported: always throws. */ + @Override + public void removeInput(Column[] arguments) { + throw new UnsupportedOperationException(); + } + /** Always false. */ + @Override + public boolean hasFinalResult() { + return false; + } + /** Not supported: always throws. */ + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException(); + } + /** Clears count, both means, m2X and c2 back to zero. */ + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + c2 = 0; + } + /** Always false, matching the unsupported removeInput. */ + @Override + public boolean removable() { + return
false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java new file mode 100644 index 0000000000000..d5cd7f9ce4307 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + /** GroupedAccumulator implementation backing REGR_SLOPE and REGR_INTERCEPT: keeps one count, meanX, meanY, m2X and c2 slot per group id and derives the slope of a group as c2 / m2X. */ +public class GroupedRegressionAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedRegressionAccumulator.class); + + private final TSDataType yDataType; + private final TSDataType xDataType; + private final RegressionAccumulator.RegressionType regressionType; + /* Per-group running state, indexed by group id. */ + private final LongBigArray counts = new LongBigArray(); + private final DoubleBigArray meanXs = new DoubleBigArray(); + private final DoubleBigArray meanYs = new DoubleBigArray(); + private final DoubleBigArray m2Xs = new DoubleBigArray(); + private final DoubleBigArray c2s = new DoubleBigArray(); + /** Creates an empty accumulator; yDataType describes the first argument (Y) and xDataType the second (X). */ + public GroupedRegressionAccumulator( + TSDataType yDataType, + TSDataType xDataType, + RegressionAccumulator.RegressionType regressionType) { + this.yDataType = yDataType; + this.xDataType =
xDataType; + this.regressionType = regressionType; + } + /** Returns the shallow instance size plus the size reported by each of the five per-group arrays. */ + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE + + counts.sizeOf() + + meanXs.sizeOf() + + meanYs.sizeOf() + + m2Xs.sizeOf() + + c2s.sizeOf(); + } + /** Calls ensureCapacity(groupCount) on each of the five per-group arrays. */ + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + meanXs.ensureCapacity(groupCount); + meanYs.ensureCapacity(groupCount); + m2Xs.ensureCapacity(groupCount); + c2s.ensureCapacity(groupCount); + } + /** Feeds every selected row whose Y and X are both non-null into the state of its group; groupIds is indexed by the row index in the select-all branch and by the selected position otherwise. */ + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + // arguments[0] -> Y, arguments[1] -> X + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double y = getDoubleValue(arguments[0], i, yDataType); + double x = getDoubleValue(arguments[1], i, xDataType); + update(groupIds[i], x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double y = getDoubleValue(arguments[0], position, yDataType); + double x = getDoubleValue(arguments[1], position, xDataType); + update(groupIds[position], x, y); + } + } + } + /** Reads the value at position as a double: INT32 and DATE via getInt, INT64 and TIMESTAMP via getLong, FLOAT via getFloat, DOUBLE via getDouble; any other type throws UnSupportedDataTypeException. */ + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Regression Aggregation: %s", dataType)); + } + } + /** Folds one (x, y) sample into the state of groupId; the means are updated first so that c2 and m2X use the delta taken before and the residual taken after the mean update. */ + private void update(int groupId, double x, double y) { + long newCount = counts.get(groupId) + 1; + double deltaX = x -
meanXs.get(groupId); + double deltaY = y - meanYs.get(groupId); + + meanXs.add(groupId, deltaX / newCount); + meanYs.add(groupId, deltaY / newCount); + + // Welford's algorithm for covariance and variance of X + c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); + m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId))); + + counts.set(groupId, newCount); + } + /** Merges the non-null serialized partial state at each row i into group groupIds[i]; each entry is read as a long count followed by meanX, meanY, m2X and c2 as doubles. */ + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherM2X, otherC2); + } + } + /** Combines another partial state into the state of groupId: an empty other state is a no-op, an empty group adopts the other values, otherwise the partial sums are added together with a correction term for the gap between the two means. */ + private void merge( + int groupId, + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherC2) { + if (otherCount == 0) { + return; + } + if (counts.get(groupId) == 0) { + counts.set(groupId, otherCount); + meanXs.set(groupId, otherMeanX); + meanYs.set(groupId, otherMeanY); + m2Xs.set(groupId, otherM2X); + c2s.set(groupId, otherC2); + } else { + long newCount = counts.get(groupId) + otherCount; + double deltaX = otherMeanX - meanXs.get(groupId); + double deltaY = otherMeanY - meanYs.get(groupId); + + c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount); + m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount); + + meanXs.add(groupId, deltaX * otherCount / newCount); + meanYs.add(groupId, deltaY *
otherCount / newCount); + counts.set(groupId, newCount); + } + } + /** Requires a BinaryColumnBuilder; appends null for a group with count 0, otherwise writes the group state as an 8-byte count followed by four 8-byte doubles (meanX, meanY, m2X, c2). */ + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); + buffer.putLong(counts.get(groupId)); + buffer.putDouble(meanXs.get(groupId)); + buffer.putDouble(meanYs.get(groupId)); + buffer.putDouble(m2Xs.get(groupId)); + buffer.putDouble(c2s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + /** Appends null when the group has count 0 or m2X equal to 0; otherwise writes the slope (c2 / m2X) or the intercept (meanY - slope * meanX) depending on regressionType. */ + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + if (counts.get(groupId) == 0 || m2Xs.get(groupId) == 0) { + columnBuilder.appendNull(); + return; + } + double slope = c2s.get(groupId) / m2Xs.get(groupId); + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + columnBuilder.writeDouble(meanYs.get(groupId) - slope * meanXs.get(groupId)); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + /** Intentionally empty: nothing is prepared before the final evaluation. */ + @Override + public void prepareFinal() {} + /** Calls reset() on each of the five per-group arrays. */ + @Override + public void reset() { + counts.reset(); + meanXs.reset(); + meanYs.reset(); + m2Xs.reset(); + c2s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index 796a5265c0b3f..c2f1de149779f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -365,7 +365,9 @@ public TSDataType visitFunctionExpression( String funcName =
functionExpression.getFunctionName().toLowerCase(); if (funcName.equals(SqlConstant.CORR) || funcName.equals(SqlConstant.COVAR_POP) - || funcName.equals(SqlConstant.COVAR_SAMP)) { + || funcName.equals(SqlConstant.COVAR_SAMP) + || funcName.equals(SqlConstant.REGR_SLOPE) + || funcName.equals(SqlConstant.REGR_INTERCEPT)) { // Check both input parameters are numeric or timestamp if (inputExpressions.size() >= 1) { TSDataType firstInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(0))); @@ -575,6 +577,8 @@ private TSDataType getInputExpressionTypeForAggregation( case SqlConstant.CORR: case SqlConstant.COVAR_POP: case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java index 86bcdd5c803d9..a469df7a22fe5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java @@ -3197,6 +3197,8 @@ private void checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.CORR: case SqlConstant.COVAR_POP: case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java index e3737e7cbeb6c..306109eaddae4 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java @@ -196,6 +196,12 @@ public List getActualAggregationNames(boolean isPartial) { case COVAR_SAMP: outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_SAMP)); break; + case REGR_SLOPE: + outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_SLOPE)); + break; + case REGR_INTERCEPT: + outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); + break; case MAX_BY: outputAggregationNames.add(addPartialSuffix(SqlConstant.MAX_BY)); break; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java index 1df371c9a2141..6ea01b136f802 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java @@ -1072,7 +1072,9 @@ protected Type visitFunctionCall( String lowerFuncName = functionName.toLowerCase(); if (lowerFuncName.equals("corr") || lowerFuncName.equals("covar_pop") - || lowerFuncName.equals("covar_samp")) { + || lowerFuncName.equals("covar_samp") + || lowerFuncName.equals("regr_slope") + || lowerFuncName.equals("regr_intercept")) { if (argumentTypes.size() != 2) { throw new SemanticException( String.format( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 6d481dd6662b2..ddee5f5f39a1b 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -663,7 +663,21 @@ && isIntegerNumber(argumentTypes.get(2)))) { "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", functionName.toUpperCase())); } - + break; + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + // Argument count is already checked in ExpressionAnalyzer + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } else if (!isSupportedMathNumericType(argumentTypes.get(1))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } break; case SqlConstant.APPROX_COUNT_DISTINCT: if (argumentTypes.size() != 1 && argumentTypes.size() != 2) { @@ -721,6 +735,8 @@ && isIntegerNumber(argumentTypes.get(2)))) { case SqlConstant.CORR: case SqlConstant.COVAR_POP: case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: return DOUBLE; case SqlConstant.APPROX_MOST_FREQUENT: return STRING; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java index 27ab21580fa8a..6ebc44a5c41b0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java @@ -91,6 +91,8 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.CORR: case SqlConstant.COVAR_POP: case 
SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: return TSDataType.DOUBLE; // Partial aggregation names case SqlConstant.STDDEV + "_partial": @@ -102,6 +104,8 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.CORR + "_partial": case SqlConstant.COVAR_POP + "_partial": case SqlConstant.COVAR_SAMP + "_partial": + case SqlConstant.REGR_SLOPE + "_partial": + case SqlConstant.REGR_INTERCEPT + "_partial": case SqlConstant.MAX_BY + "_partial": case SqlConstant.MIN_BY + "_partial": return TSDataType.TEXT; @@ -175,6 +179,10 @@ public static String getBuiltinAggregationName(TAggregationType aggregationType) return SqlConstant.COVAR_POP; case COVAR_SAMP: return SqlConstant.COVAR_SAMP; + case REGR_SLOPE: + return SqlConstant.REGR_SLOPE; + case REGR_INTERCEPT: + return SqlConstant.REGR_INTERCEPT; default: return null; } @@ -213,6 +221,8 @@ public static boolean isConsistentWithScanOrder( case CORR: case COVAR_POP: case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: case UDAF: return true; default: @@ -253,6 +263,10 @@ public static List splitPartialBuiltinAggregation(TAggregationType aggre return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_POP)); case COVAR_SAMP: return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_SAMP)); + case REGR_SLOPE: + return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_SLOPE)); + case REGR_INTERCEPT: + return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); case MAX_BY: return Collections.singletonList(addPartialSuffix(SqlConstant.MAX_BY)); case MIN_BY: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index 18bba010e8ad8..b5b21126b8a79 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -157,6 +157,8 @@ public static TSDataType getBuiltinAggregationDataType( case SqlConstant.CORR: case SqlConstant.COVAR_POP: case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: return TSDataType.DOUBLE; default: throw new IllegalArgumentException( @@ -205,6 +207,14 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa throw new SemanticException( "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP] only support " + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { + return; + } + throw new SemanticException( + "Aggregate functions [REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); case SqlConstant.COUNT: case SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: @@ -262,6 +272,8 @@ public static void bindTypeForBuiltinAggregationNonSeriesInputExpressions( case SqlConstant.CORR: case SqlConstant.COVAR_POP: case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java index 49180bb13e2c9..1e2a90b41a661 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java @@ -77,6 +77,8 @@ protected SqlConstant() { public static final String CORR = "corr"; public static final String COVAR_POP = "covar_pop"; public static final String COVAR_SAMP = "covar_samp"; + public static final String REGR_SLOPE = "regr_slope"; 
+ public static final String REGR_INTERCEPT = "regr_intercept"; public static final String COUNT_TIME = "count_time"; public static final String COUNT_TIME_HEADER = "count_time(*)"; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java index 0f25a1ad6e282..0cef02dd1dafc 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java @@ -49,7 +49,9 @@ public enum BuiltinAggregationFunction { MIN_BY("min_by"), CORR("corr"), COVAR_POP("covar_pop"), - COVAR_SAMP("covar_samp"); + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"); private final String functionName; @@ -103,6 +105,8 @@ public static boolean canUseStatistics(String name) { case "corr": case "covar_pop": case "covar_samp": + case "regr_slope": + case "regr_intercept": return false; default: throw new IllegalArgumentException("Invalid Aggregation function: " + name); @@ -140,6 +144,8 @@ public static boolean canSplitToMultiPhases(String name) { case "corr": case "covar_pop": case "covar_samp": + case "regr_slope": + case "regr_intercept": return true; case "count_if": case "count_time": diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 69320e2aa00e5..71bba543449b6 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -61,7 +61,9 @@ public enum TableBuiltinAggregationFunction { APPROX_MOST_FREQUENT("approx_most_frequent"), CORR("corr"), COVAR_POP("covar_pop"), - COVAR_SAMP("covar_samp"); + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"); private final String functionName; @@ -109,6 +111,8 @@ public static Type getIntermediateType(String name, List originalArgumentT case "corr": case "covar_pop": case "covar_samp": + case "regr_slope": + case "regr_intercept": case "approx_count_distinct": return RowType.anonymous(Collections.emptyList()); case "extreme": diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 214d4a6dce4ec..659ed7938486f 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -296,7 +296,9 @@ enum TAggregationType { APPROX_MOST_FREQUENT, CORR, COVAR_POP, - COVAR_SAMP + COVAR_SAMP, + REGR_SLOPE, + REGR_INTERCEPT } struct TShowConfigurationTemplateResp { From e71fbfb2afec5ba40df5b2b938dd93762e285adc Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Wed, 11 Mar 2026 16:09:11 +0800 Subject: [PATCH 4/7] Add skewness and kurtosis aggregate functions with input validation --- .../BuiltinAggregationFunctionEnum.java | 4 +- .../aggregation/AccumulatorFactory.java | 6 + .../aggregation/CentralMomentAccumulator.java | 249 ++++++++++++++++ .../SlidingWindowAggregatorFactory.java | 2 + .../aggregation/AccumulatorFactory.java | 14 + .../TableCentralMomentAccumulator.java | 254 ++++++++++++++++ .../GroupedCentralMomentAccumulator.java | 278 ++++++++++++++++++ .../plan/analyze/ExpressionTypeAnalyzer.java | 15 + .../queryengine/plan/parser/ASTVisitor.java | 2 + .../plan/parameter/AggregationDescriptor.java | 6 + 
.../analyzer/ExpressionAnalyzer.java | 8 + .../metadata/TableMetadataImpl.java | 12 + .../apache/iotdb/db/utils/SchemaUtils.java | 14 + .../iotdb/db/utils/TypeInferenceUtils.java | 12 + .../iotdb/db/utils/constant/SqlConstant.java | 2 + .../builtin/BuiltinAggregationFunction.java | 8 +- .../TableBuiltinAggregationFunction.java | 6 +- .../src/main/thrift/common.thrift | 4 +- 18 files changed, 892 insertions(+), 4 deletions(-) create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java index 0e758b39f9b53..5ebb8e12f9be0 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java @@ -47,7 +47,9 @@ public enum BuiltinAggregationFunctionEnum { COVAR_POP("covar_pop"), COVAR_SAMP("covar_samp"), REGR_SLOPE("regr_slope"), - REGR_INTERCEPT("regr_intercept"); + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java index 12f45494ed057..90ec90f81acf4 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java @@ -170,6 +170,12 @@ private static Accumulator createBuiltinSingleInputAccumulator( return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_SAMP); case VAR_POP: return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_POP); + case SKEWNESS: + return new CentralMomentAccumulator( + tsDataType, CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new CentralMomentAccumulator( + tsDataType, CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java new file mode 100644 index 0000000000000..bcb732db918f7 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.iotdb.db.queryengine.execution.aggregation;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.file.metadata.statistics.Statistics;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.BitMap;
+
+import java.nio.ByteBuffer;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+/**
+ * Accumulator computing sample skewness or sample excess kurtosis of a single numeric series.
+ *
+ * <p>The state is the running count, the mean and the second, third and fourth central moment
+ * sums (M2, M3, M4). Values are folded in one at a time by {@link #update(double)}; partial
+ * states are combined by {@link #merge(long, double, double, double, double)}. The intermediate
+ * result is a 40-byte binary: count (long) followed by mean, M2, M3 and M4 (double).
+ */
+public class CentralMomentAccumulator implements Accumulator {
+
+  /** Statistic written by {@link #outputFinal(ColumnBuilder)}. */
+  public enum MomentType {
+    SKEWNESS,
+    KURTOSIS
+  }
+
+  /** Size of the serialized state: one long plus four doubles. */
+  private static final int INTERMEDIATE_SIZE = Long.BYTES + 4 * Double.BYTES;
+
+  private final TSDataType seriesDataType;
+  private final MomentType momentType;
+
+  // State variables: count, mean, M2, M3, M4
+  private long count;
+  private double mean;
+  private double m2;
+  private double m3;
+  private double m4;
+
+  public CentralMomentAccumulator(TSDataType seriesDataType, MomentType momentType) {
+    this.seriesDataType = seriesDataType;
+    this.momentType = momentType;
+  }
+
+  /**
+   * Folds every selected, non-null value of the data column into the state.
+   *
+   * @param columns tree model layout: columns[0] is Time, columns[1] is data
+   * @param bitMap when non-null, only positions marked in it are consumed
+   */
+  @Override
+  public void addInput(Column[] columns, BitMap bitMap) {
+    Column data = columns[1];
+    int size = data.getPositionCount();
+    for (int i = 0; i < size; i++) {
+      if (bitMap != null && !bitMap.isMarked(i)) {
+        continue;
+      }
+      if (data.isNull(i)) {
+        continue;
+      }
+      update(getDoubleValue(data, i));
+    }
+  }
+
+  /** Reads the value at {@code position} using the accessor matching the series data type. */
+  private double getDoubleValue(Column column, int position) {
+    switch (seriesDataType) {
+      case INT32:
+      case DATE:
+        return column.getInt(position);
+      case INT64:
+      case TIMESTAMP:
+        return column.getLong(position);
+      case FLOAT:
+        return column.getFloat(position);
+      case DOUBLE:
+        return column.getDouble(position);
+      default:
+        throw new UnsupportedOperationException(
+            "Unsupported data type in CentralMoment Aggregation: " + seriesDataType);
+    }
+  }
+
+  /** Single-pass update of count, mean, M2, M3 and M4 with one new value. */
+  private void update(double value) {
+    long n1 = count;
+    count++;
+    // Work on a double copy of the count: long products such as count * count wrap around
+    // silently for very large inputs.
+    double n = count;
+
+    double delta = value - mean;
+    double deltaN = delta / n;
+    double deltaN2 = deltaN * deltaN;
+    double term1 = delta * deltaN * n1;
+
+    mean += deltaN;
+
+    // Order matters: M4 must be updated before M3 and M2 because it reads their previous
+    // values, and M3 must be updated before M2 for the same reason.
+    m4 += term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * m2 - 4 * deltaN * m3;
+    m3 += term1 * deltaN * (n - 2) - 3 * deltaN * m2;
+    m2 += term1;
+  }
+
+  /**
+   * Merges one serialized partial state (position 0 of the single intermediate column) into this
+   * accumulator. A null partial state is ignored.
+   */
+  @Override
+  public void addIntermediate(Column[] partialResult) {
+    checkArgument(partialResult.length == 1, "partialResult of CentralMoment should be 1");
+    if (partialResult[0].isNull(0)) {
+      return;
+    }
+    byte[] bytes = partialResult[0].getBinary(0).getValues();
+    ByteBuffer buffer = ByteBuffer.wrap(bytes);
+
+    // Field order must match outputIntermediate.
+    long otherCount = buffer.getLong();
+    double otherMean = buffer.getDouble();
+    double otherM2 = buffer.getDouble();
+    double otherM3 = buffer.getDouble();
+    double otherM4 = buffer.getDouble();
+
+    merge(otherCount, otherMean, otherM2, otherM3, otherM4);
+  }
+
+  /** Combines the partial state (nB, meanB, m2B, m3B, m4B) into this one (Chan et al.). */
+  private void merge(long nB, double meanB, double m2B, double m3B, double m4B) {
+    if (nB == 0) {
+      return;
+    }
+    if (count == 0) {
+      // Nothing accumulated yet: adopt the other state as is.
+      count = nB;
+      mean = meanB;
+      m2 = m2B;
+      m3 = m3B;
+      m4 = m4B;
+      return;
+    }
+
+    // Widen the counts to double before multiplying. In long arithmetic nTotal^3 wraps around
+    // as soon as nTotal reaches 2^21 (2,097,152), and the squares wrap at about 3e9.
+    double a = count;
+    double b = nB;
+    double n = a + b;
+    double delta = meanB - mean;
+    double delta2 = delta * delta;
+    double delta3 = delta * delta2;
+    double delta4 = delta2 * delta2;
+
+    // M4 first, then M3, then M2: each formula reads the not-yet-merged lower moments.
+    m4 +=
+        m4B
+            + delta4 * a * b * (a * a - a * b + b * b) / (n * n * n)
+            + 6.0 * delta2 * (a * a * m2B + b * b * m2) / (n * n)
+            + 4.0 * delta * (a * m3B - b * m3) / n;
+
+    m3 += m3B + delta3 * a * b * (a - b) / (n * n) + 3.0 * delta * (a * m2B - b * m2) / n;
+
+    m2 += m2B + delta2 * a * b / n;
+
+    mean += delta * b / n;
+    count += nB;
+  }
+
+  /** Writes the serialized state, or null when no value has been accumulated. */
+  @Override
+  public void outputIntermediate(ColumnBuilder[] columnBuilders) {
+    checkArgument(columnBuilders.length == 1, "partialResult should be 1");
+    if (count == 0) {
+      columnBuilders[0].appendNull();
+      return;
+    }
+    // Serialized layout: long + 4 * double = 40 bytes
+    byte[] bytes = new byte[INTERMEDIATE_SIZE];
+    ByteBuffer buffer = ByteBuffer.wrap(bytes);
+    buffer.putLong(count);
+    buffer.putDouble(mean);
+    buffer.putDouble(m2);
+    buffer.putDouble(m3);
+    buffer.putDouble(m4);
+    columnBuilders[0].writeBinary(new Binary(bytes));
+  }
+
+  /**
+   * Writes the final statistic. Null is written when there is no data, when M2 is zero (zero
+   * variance), or when there are fewer than 3 values for skewness or 4 values for kurtosis.
+   */
+  @Override
+  public void outputFinal(ColumnBuilder columnBuilder) {
+    if (count == 0 || m2 == 0) {
+      columnBuilder.appendNull();
+      return;
+    }
+
+    // Double copy of the count: (count - 1) * (count - 2) * (count - 3) wraps around in long
+    // arithmetic once count exceeds roughly 2.1 million.
+    double n = count;
+    double variance = m2 / (n - 1);
+
+    if (momentType == MomentType.SKEWNESS) {
+      if (count < 3) { // skewness requires N >= 3
+        columnBuilder.appendNull();
+        return;
+      }
+      // Bias-corrected estimator: (N * M3) / ((N-1) * (N-2) * sigma^3),
+      // with sigma = sqrt(M2 / (N-1))
+      double stdev = Math.sqrt(variance);
+      columnBuilder.writeDouble((n * m3) / ((n - 1) * (n - 2) * stdev * stdev * stdev));
+      return;
+    }
+
+    // KURTOSIS
+    if (count < 4) { // kurtosis requires N >= 4
+      columnBuilder.appendNull();
+      return;
+    }
+    // Bias-corrected estimator of the excess kurtosis
+    double term1 = (n * (n + 1) * m4) / ((n - 1) * (n - 2) * (n - 3) * variance * variance);
+    double term2 = (3 * Math.pow(n - 1, 2)) / ((n - 2) * (n - 3));
+    columnBuilder.writeDouble(term1 - term2);
+  }
+
+  // Default implementations
+  @Override
+  public void removeIntermediate(Column[] input) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void addStatistics(Statistics statistics) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void setFinal(Column finalResult) {}
+
+  @Override
+  public void reset() {
+    count = 0;
+    mean = 0;
+    m2 = 0;
+    m3 = 0;
+    m4 = 0;
+  }
+
+  @Override
+  public boolean hasFinalResult() {
+    return false;
+  }
+
+  @Override
+  public TSDataType[] getIntermediateType() {
+    return new TSDataType[] {TSDataType.TEXT};
+  }
+
+  @Override
+  public TSDataType getFinalType() {
+    return TSDataType.DOUBLE;
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java
index 84419a11bedf7..a3ca7212c8bbc 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java
@@ -205,6 +205,8 @@ public static SlidingWindowAggregator createSlidingWindowAggregator(
       case COVAR_SAMP:
       case REGR_SLOPE:
       case REGR_INTERCEPT:
+      case SKEWNESS:
+      case KURTOSIS:
       case UDAF: // Currently UDAF belongs to SmoothQueueSlidingWindowAggregator
         return new SmoothQueueSlidingWindowAggregator(accumulator, inputLocationList, step);
       case MAX_VALUE:
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
index eeedb97a28f43..b84f948f0ccc3 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -21,6 +21,7 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; +import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; @@ -32,6 +33,7 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedApproxCountDistinctAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCentralMomentAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCorrelationAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAllAccumulator; @@ -285,6 +287,12 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( inputDataTypes.get(0), inputDataTypes.get(1), RegressionAccumulator.RegressionType.REGR_INTERCEPT); + case SKEWNESS: + return new GroupedCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: 
+ return new GroupedCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -379,6 +387,12 @@ public static TableAccumulator createBuiltinAccumulator( inputDataTypes.get(0), inputDataTypes.get(1), RegressionAccumulator.RegressionType.REGR_INTERCEPT); + case SKEWNESS: + return new TableCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new TableCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java new file mode 100644 index 0000000000000..97ff93fce45a3 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java @@ -0,0 +1,254 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.file.metadata.statistics.Statistics;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+
+import java.nio.ByteBuffer;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+/**
+ * Table-model accumulator computing sample skewness or sample excess kurtosis of one column.
+ *
+ * <p>The state is the running count, the mean and the second, third and fourth central moment
+ * sums (M2, M3, M4). The intermediate result is a 40-byte binary: count (long) followed by mean,
+ * M2, M3 and M4 (double).
+ */
+public class TableCentralMomentAccumulator implements TableAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(TableCentralMomentAccumulator.class);
+
+  private final TSDataType seriesDataType;
+  private final CentralMomentAccumulator.MomentType momentType;
+
+  // State
+  private long count;
+  private double mean;
+  private double m2;
+  private double m3;
+  private double m4;
+
+  public TableCentralMomentAccumulator(
+      TSDataType seriesDataType, CentralMomentAccumulator.MomentType momentType) {
+    this.seriesDataType = seriesDataType;
+    this.momentType = momentType;
+  }
+
+  /** Folds every selected, non-null value of {@code arguments[0]} into the state. */
+  @Override
+  public void addInput(Column[] arguments, AggregationMask mask) {
+    Column input = arguments[0];
+    int positionCount = mask.getSelectedPositionCount();
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        if (!input.isNull(i)) {
+          update(getDoubleValue(input, i));
+        }
+      }
+      return;
+    }
+    int[] selectedPositions = mask.getSelectedPositions();
+    for (int i = 0; i < positionCount; i++) {
+      int position = selectedPositions[i];
+      if (!input.isNull(position)) {
+        update(getDoubleValue(input, position));
+      }
+    }
+  }
+
+  /** Reads the value at {@code position} using the accessor matching the series data type. */
+  private double getDoubleValue(Column column, int position) {
+    switch (seriesDataType) {
+      case INT32:
+      case DATE:
+        return column.getInt(position);
+      case INT64:
+      case TIMESTAMP:
+        return column.getLong(position);
+      case FLOAT:
+        return column.getFloat(position);
+      case DOUBLE:
+        return column.getDouble(position);
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format(
+                "Unsupported data type in CentralMoment Aggregation: %s", seriesDataType));
+    }
+  }
+
+  /** Single-pass update of count, mean, M2, M3 and M4 with one new value. */
+  private void update(double value) {
+    long n1 = count;
+    count++;
+    // Work on a double copy of the count: long products such as count * count wrap around
+    // silently for very large inputs.
+    double n = count;
+    double delta = value - mean;
+    double deltaN = delta / n;
+    double deltaN2 = deltaN * deltaN;
+    double term1 = delta * deltaN * n1;
+    mean += deltaN;
+    // M4 reads the previous M3 and M2, and M3 reads the previous M2, so update in this order.
+    m4 += term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * m2 - 4 * deltaN * m3;
+    m3 += term1 * deltaN * (n - 2) - 3 * deltaN * m2;
+    m2 += term1;
+  }
+
+  /** Merges every non-null serialized partial state of {@code argument} into this accumulator. */
+  @Override
+  public void addIntermediate(Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn
+            || (argument instanceof RunLengthEncodedColumn
+                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
+        "intermediate input and output should be BinaryColumn");
+
+    for (int i = 0; i < argument.getPositionCount(); i++) {
+      if (argument.isNull(i)) {
+        continue;
+      }
+      byte[] bytes = argument.getBinary(i).getValues();
+      ByteBuffer buffer = ByteBuffer.wrap(bytes);
+
+      // Field order must match evaluateIntermediate.
+      long otherCount = buffer.getLong();
+      double otherMean = buffer.getDouble();
+      double otherM2 = buffer.getDouble();
+      double otherM3 = buffer.getDouble();
+      double otherM4 = buffer.getDouble();
+
+      merge(otherCount, otherMean, otherM2, otherM3, otherM4);
+    }
+  }
+
+  /** Combines the partial state (nB, meanB, m2B, m3B, m4B) into this one. */
+  private void merge(long nB, double meanB, double m2B, double m3B, double m4B) {
+    if (nB == 0) {
+      return;
+    }
+    if (count == 0) {
+      // Nothing accumulated yet: adopt the other state as is.
+      count = nB;
+      mean = meanB;
+      m2 = m2B;
+      m3 = m3B;
+      m4 = m4B;
+      return;
+    }
+
+    // Widen the counts to double before multiplying. In long arithmetic nTotal^3 wraps around
+    // as soon as nTotal reaches 2^21 (2,097,152), and the squares wrap at about 3e9.
+    double a = count;
+    double b = nB;
+    double n = a + b;
+    double delta = meanB - mean;
+    double delta2 = delta * delta;
+    double delta3 = delta * delta2;
+    double delta4 = delta2 * delta2;
+
+    // M4 first, then M3, then M2: each formula reads the not-yet-merged lower moments.
+    m4 +=
+        m4B
+            + delta4 * a * b * (a * a - a * b + b * b) / (n * n * n)
+            + 6.0 * delta2 * (a * a * m2B + b * b * m2) / (n * n)
+            + 4.0 * delta * (a * m3B - b * m3) / n;
+
+    m3 += m3B + delta3 * a * b * (a - b) / (n * n) + 3.0 * delta * (a * m2B - b * m2) / n;
+
+    m2 += m2B + delta2 * a * b / n;
+
+    mean += delta * b / n;
+    count += nB;
+  }
+
+  /** Writes the serialized state, or null when no value has been accumulated. */
+  @Override
+  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output should be BinaryColumn");
+
+    if (count == 0) {
+      columnBuilder.appendNull();
+      return;
+    }
+    // Serialize: long(8) + 4*double(32) = 40 bytes
+    ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4);
+    buffer.putLong(count);
+    buffer.putDouble(mean);
+    buffer.putDouble(m2);
+    buffer.putDouble(m3);
+    buffer.putDouble(m4);
+    columnBuilder.writeBinary(new Binary(buffer.array()));
+  }
+
+  /**
+   * Writes the final statistic. Null is written when there is no data, when M2 is zero (zero
+   * variance), or when there are fewer than 3 values for skewness or 4 values for kurtosis.
+   */
+  @Override
+  public void evaluateFinal(ColumnBuilder columnBuilder) {
+    if (count == 0 || m2 == 0) {
+      columnBuilder.appendNull();
+      return;
+    }
+
+    // Double copy of the count: (count - 1) * (count - 2) * (count - 3) wraps around in long
+    // arithmetic once count exceeds roughly 2.1 million.
+    double n = count;
+    double variance = m2 / (n - 1);
+
+    if (momentType == CentralMomentAccumulator.MomentType.SKEWNESS) {
+      if (count < 3) {
+        columnBuilder.appendNull();
+        return;
+      }
+      // (N * M3) / ((N-1) * (N-2) * sigma^3), with sigma = sqrt(M2 / (N-1))
+      double stdev = Math.sqrt(variance);
+      columnBuilder.writeDouble((n * m3) / ((n - 1) * (n - 2) * stdev * stdev * stdev));
+      return;
+    }
+
+    // KURTOSIS
+    if (count < 4) {
+      columnBuilder.appendNull();
+      return;
+    }
+    double term1 = (n * (n + 1) * m4) / ((n - 1) * (n - 2) * (n - 3) * variance * variance);
+    double term2 = (3 * Math.pow(n - 1, 2)) / ((n - 2) * (n - 3));
+    columnBuilder.writeDouble(term1 - term2);
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public TableAccumulator copy() {
+    return new TableCentralMomentAccumulator(seriesDataType, momentType);
+  }
+
+  @Override
+  public void removeInput(Column[] arguments) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean hasFinalResult() {
+    return false;
+  }
+
+  @Override
+  public void addStatistics(Statistics[] statistics) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void reset() {
+    count = 0;
+    mean = 0;
+    m2 = 0;
+    m3 = 0;
+    m4 = 0;
+  }
+
+  @Override
+  public boolean removable() {
+    return false;
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java
new file mode 100644
index 0000000000000..5674ca4510e2c
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
+
+import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator;
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray;
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+
+import java.nio.ByteBuffer;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+/**
+ * Grouped accumulator computing sample skewness or sample excess kurtosis per group.
+ *
+ * <p>For every group id the count, the mean and the second, third and fourth central moment sums
+ * (M2, M3, M4) are kept in parallel big arrays. The intermediate result of a group is a 40-byte
+ * binary: count (long) followed by mean, M2, M3 and M4 (double).
+ */
+public class GroupedCentralMomentAccumulator implements GroupedAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(GroupedCentralMomentAccumulator.class);
+
+  private final TSDataType seriesDataType;
+  private final CentralMomentAccumulator.MomentType momentType;
+
+  // Per-group state, indexed by group id.
+  private final LongBigArray counts = new LongBigArray();
+  private final DoubleBigArray means = new DoubleBigArray();
+  private final DoubleBigArray m2s = new DoubleBigArray();
+  private final DoubleBigArray m3s = new DoubleBigArray();
+  private final DoubleBigArray m4s = new DoubleBigArray();
+
+  public GroupedCentralMomentAccumulator(
+      TSDataType seriesDataType, CentralMomentAccumulator.MomentType momentType) {
+    this.seriesDataType = seriesDataType;
+    this.momentType = momentType;
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE
+        + counts.sizeOf()
+        + means.sizeOf()
+        + m2s.sizeOf()
+        + m3s.sizeOf()
+        + m4s.sizeOf();
+  }
+
+  @Override
+  public void setGroupCount(long groupCount) {
+    counts.ensureCapacity(groupCount);
+    means.ensureCapacity(groupCount);
+    m2s.ensureCapacity(groupCount);
+    m3s.ensureCapacity(groupCount);
+    m4s.ensureCapacity(groupCount);
+  }
+
+  /** Folds every selected, non-null value of {@code arguments[0]} into the state of its group. */
+  @Override
+  public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
+    Column input = arguments[0];
+    int positionCount = mask.getSelectedPositionCount();
+    if (mask.isSelectAll()) {
+      for (int i = 0; i < positionCount; i++) {
+        if (!input.isNull(i)) {
+          update(groupIds[i], getDoubleValue(input, i));
+        }
+      }
+      return;
+    }
+    int[] selectedPositions = mask.getSelectedPositions();
+    for (int i = 0; i < positionCount; i++) {
+      int position = selectedPositions[i];
+      if (!input.isNull(position)) {
+        update(groupIds[position], getDoubleValue(input, position));
+      }
+    }
+  }
+
+  /** Reads the value at {@code position} using the accessor matching the series data type. */
+  private double getDoubleValue(Column column, int position) {
+    switch (seriesDataType) {
+      case INT32:
+      case DATE:
+        return column.getInt(position);
+      case INT64:
+      case TIMESTAMP:
+        return column.getLong(position);
+      case FLOAT:
+        return column.getFloat(position);
+      case DOUBLE:
+        return column.getDouble(position);
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format(
+                "Unsupported data type in CentralMoment Aggregation: %s", seriesDataType));
+    }
+  }
+
+  /** Single-pass update of one group's count, mean, M2, M3 and M4 with one new value. */
+  private void update(int groupId, double value) {
+    long n1 = counts.get(groupId);
+    long newCount = n1 + 1;
+    // Work on a double copy of the count: long products such as newCount * newCount wrap
+    // around silently for very large inputs.
+    double n = newCount;
+    double mean = means.get(groupId);
+    double m2 = m2s.get(groupId);
+    double m3 = m3s.get(groupId);
+    double m4 = m4s.get(groupId);
+
+    double delta = value - mean;
+    double deltaN = delta / n;
+    double deltaN2 = deltaN * deltaN;
+    double term1 = delta * deltaN * n1;
+
+    mean += deltaN;
+    // M4 reads the previous M3 and M2, and M3 reads the previous M2, so update in this order.
+    m4 += term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * m2 - 4 * deltaN * m3;
+    m3 += term1 * deltaN * (n - 2) - 3 * deltaN * m2;
+    m2 += term1;
+
+    counts.set(groupId, newCount);
+    means.set(groupId, mean);
+    m2s.set(groupId, m2);
+    m3s.set(groupId, m3);
+    m4s.set(groupId, m4);
+  }
+
+  /**
+   * Merges each non-null serialized partial state at position {@code i} of {@code argument} into
+   * the group {@code groupIds[i]}.
+   */
+  @Override
+  public void addIntermediate(int[] groupIds, Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn
+            || (argument instanceof RunLengthEncodedColumn
+                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
+        "intermediate input and output should be BinaryColumn");
+
+    for (int i = 0; i < argument.getPositionCount(); i++) {
+      if (argument.isNull(i)) {
+        continue;
+      }
+      byte[] bytes = argument.getBinary(i).getValues();
+      ByteBuffer buffer = ByteBuffer.wrap(bytes);
+
+      // Field order must match evaluateIntermediate.
+      long otherCount = buffer.getLong();
+      double otherMean = buffer.getDouble();
+      double otherM2 = buffer.getDouble();
+      double otherM3 = buffer.getDouble();
+      double otherM4 = buffer.getDouble();
+
+      merge(groupIds[i], otherCount, otherMean, otherM2, otherM3, otherM4);
+    }
+  }
+
+  /** Combines the partial state (nB, meanB, m2B, m3B, m4B) into the state of {@code groupId}. */
+  private void merge(int groupId, long nB, double meanB, double m2B, double m3B, double m4B) {
+    if (nB == 0) {
+      return;
+    }
+    long nA = counts.get(groupId);
+    if (nA == 0) {
+      // Nothing accumulated for this group yet: adopt the other state as is.
+      counts.set(groupId, nB);
+      means.set(groupId, meanB);
+      m2s.set(groupId, m2B);
+      m3s.set(groupId, m3B);
+      m4s.set(groupId, m4B);
+      return;
+    }
+
+    // Widen the counts to double before multiplying. In long arithmetic nTotal^3 wraps around
+    // as soon as nTotal reaches 2^21 (2,097,152), and the squares wrap at about 3e9.
+    double a = nA;
+    double b = nB;
+    double n = a + b;
+    double delta = meanB - means.get(groupId);
+    double delta2 = delta * delta;
+    double delta3 = delta * delta2;
+    double delta4 = delta2 * delta2;
+
+    double m2 = m2s.get(groupId);
+    double m3 = m3s.get(groupId);
+    double m4 = m4s.get(groupId);
+
+    // M4 first, then M3, then M2: each formula reads the not-yet-merged lower moments.
+    m4 +=
+        m4B
+            + delta4 * a * b * (a * a - a * b + b * b) / (n * n * n)
+            + 6.0 * delta2 * (a * a * m2B + b * b * m2) / (n * n)
+            + 4.0 * delta * (a * m3B - b * m3) / n;
+
+    m3 += m3B + delta3 * a * b * (a - b) / (n * n) + 3.0 * delta * (a * m2B - b * m2) / n;
+
+    m2 += m2B + delta2 * a * b / n;
+
+    means.add(groupId, delta * b / n);
+    counts.set(groupId, nA + nB);
+    m2s.set(groupId, m2);
+    m3s.set(groupId, m3);
+    m4s.set(groupId, m4);
+  }
+
+  /** Writes the serialized state of the group, or null when the group has no values. */
+  @Override
+  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output should be BinaryColumn");
+
+    if (counts.get(groupId) == 0) {
+      columnBuilder.appendNull();
+      return;
+    }
+    // Serialize: long(8) + 4*double(32) = 40 bytes
+    ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4);
+    buffer.putLong(counts.get(groupId));
+    buffer.putDouble(means.get(groupId));
+    buffer.putDouble(m2s.get(groupId));
+    buffer.putDouble(m3s.get(groupId));
+    buffer.putDouble(m4s.get(groupId));
+    columnBuilder.writeBinary(new Binary(buffer.array()));
+  }
+
+  /**
+   * Writes the final statistic of the group. Null is written when the group has no data, when
+   * its M2 is zero (zero variance), or when it has fewer than 3 values for skewness or 4 values
+   * for kurtosis.
+   */
+  @Override
+  public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
+    long count = counts.get(groupId);
+    double m2 = m2s.get(groupId);
+
+    if (count == 0 || m2 == 0) {
+      columnBuilder.appendNull();
+      return;
+    }
+
+    // Double copy of the count: (count - 1) * (count - 2) * (count - 3) wraps around in long
+    // arithmetic once count exceeds roughly 2.1 million.
+    double n = count;
+    double variance = m2 / (n - 1);
+
+    if (momentType == CentralMomentAccumulator.MomentType.SKEWNESS) {
+      if (count < 3) {
+        columnBuilder.appendNull();
+        return;
+      }
+      // (N * M3) / ((N-1) * (N-2) * sigma^3), with sigma = sqrt(M2 / (N-1))
+      double m3 = m3s.get(groupId);
+      double stdev = Math.sqrt(variance);
+      columnBuilder.writeDouble((n * m3) / ((n - 1) * (n - 2) * stdev * stdev * stdev));
+      return;
+    }
+
+    // KURTOSIS
+    if (count < 4) {
+      columnBuilder.appendNull();
+      return;
+    }
+    double m4 = m4s.get(groupId);
+    double term1 = (n * (n + 1) * m4) / ((n - 1) * (n - 2) * (n - 3) * variance * variance);
+    double term2 = (3 * Math.pow(n - 1, 2)) / ((n - 2) * (n - 3));
+    columnBuilder.writeDouble(term1 - term2);
+  }
+
+  @Override
+  public void prepareFinal() {}
+
+  @Override
+  public void reset() {
+    counts.reset();
+    means.reset();
+    m2s.reset();
+    m3s.reset();
+    m4s.reset();
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
index c2f1de149779f..663d7645f6214 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
@@ -392,6 +392,19 @@ public TSDataType visitFunctionExpression(
           }
         }
       }
+      if (funcName.equals(SqlConstant.SKEWNESS) || funcName.equals(SqlConstant.KURTOSIS)) {
+        if (!inputExpressions.isEmpty()) {
+          TSDataType firstInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(0)));
+          if (firstInputType != null
+              && !firstInputType.isNumeric()
+              && firstInputType != TSDataType.TIMESTAMP) {
+            throw new SemanticException(
+                String.format(
+                    "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]",
+                    functionExpression.getFunctionName().toUpperCase()));
+          }
+        }
+      }
 
       return setExpressionType(
           functionExpression,
@@ -579,6 +592,8 @@ private TSDataType getInputExpressionTypeForAggregation(
       case SqlConstant.COVAR_SAMP:
       case SqlConstant.REGR_SLOPE:
       case SqlConstant.REGR_INTERCEPT:
+      case SqlConstant.SKEWNESS:
+      case SqlConstant.KURTOSIS:
       case SqlConstant.MAX_BY:
       case SqlConstant.MIN_BY:
         return expressionTypes.get(NodeRef.of(inputExpressions.get(0)));
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
index a469df7a22fe5..b8b2117e976a0 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
@@ -3186,6 +3186,8 @@ private void
checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java index 306109eaddae4..36c6470078853 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java @@ -202,6 +202,12 @@ public List getActualAggregationNames(boolean isPartial) { case REGR_INTERCEPT: outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); break; + case SKEWNESS: + outputAggregationNames.add(addPartialSuffix(SqlConstant.SKEWNESS)); + break; + case KURTOSIS: + outputAggregationNames.add(addPartialSuffix(SqlConstant.KURTOSIS)); + break; case MAX_BY: outputAggregationNames.add(addPartialSuffix(SqlConstant.MAX_BY)); break; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java index 6ea01b136f802..0c43398ba1f70 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java @@ -1082,6 +1082,14 @@ protected Type visitFunctionCall( node, argumentTypes.size())); } } + if (lowerFuncName.equals("skewness") 
|| lowerFuncName.equals("kurtosis")) { + if (argumentTypes.size() != 1) { + throw new SemanticException( + String.format( + "Error size of input expressions. expression: %s, actual size: %s, expected size: [1].", + node, argumentTypes.size())); + } + } } Type type = metadata.getFunctionReturnType(functionName, argumentTypes); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index ddee5f5f39a1b..a670ce67e079d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -679,6 +679,16 @@ && isIntegerNumber(argumentTypes.get(2)))) { functionName.toUpperCase())); } break; + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + // Argument count is already checked in ExpressionAnalyzer + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; case SqlConstant.APPROX_COUNT_DISTINCT: if (argumentTypes.size() != 1 && argumentTypes.size() != 2) { throw new SemanticException( @@ -737,6 +747,8 @@ && isIntegerNumber(argumentTypes.get(2)))) { case SqlConstant.COVAR_SAMP: case SqlConstant.REGR_SLOPE: case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return DOUBLE; case SqlConstant.APPROX_MOST_FREQUENT: return STRING; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java index 6ebc44a5c41b0..71a24bc23678f 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java @@ -93,6 +93,8 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.COVAR_SAMP: case SqlConstant.REGR_SLOPE: case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return TSDataType.DOUBLE; // Partial aggregation names case SqlConstant.STDDEV + "_partial": @@ -106,6 +108,8 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.COVAR_SAMP + "_partial": case SqlConstant.REGR_SLOPE + "_partial": case SqlConstant.REGR_INTERCEPT + "_partial": + case SqlConstant.SKEWNESS + "_partial": + case SqlConstant.KURTOSIS + "_partial": case SqlConstant.MAX_BY + "_partial": case SqlConstant.MIN_BY + "_partial": return TSDataType.TEXT; @@ -183,6 +187,10 @@ public static String getBuiltinAggregationName(TAggregationType aggregationType) return SqlConstant.REGR_SLOPE; case REGR_INTERCEPT: return SqlConstant.REGR_INTERCEPT; + case SKEWNESS: + return SqlConstant.SKEWNESS; + case KURTOSIS: + return SqlConstant.KURTOSIS; default: return null; } @@ -223,6 +231,8 @@ public static boolean isConsistentWithScanOrder( case COVAR_SAMP: case REGR_SLOPE: case REGR_INTERCEPT: + case SKEWNESS: + case KURTOSIS: case UDAF: return true; default: @@ -267,6 +277,10 @@ public static List splitPartialBuiltinAggregation(TAggregationType aggre return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_SLOPE)); case REGR_INTERCEPT: return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); + case SKEWNESS: + return Collections.singletonList(addPartialSuffix(SqlConstant.SKEWNESS)); + case KURTOSIS: + return Collections.singletonList(addPartialSuffix(SqlConstant.KURTOSIS)); case MAX_BY: return Collections.singletonList(addPartialSuffix(SqlConstant.MAX_BY)); case MIN_BY: diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index b5b21126b8a79..be9d72cdcddd8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -159,6 +159,8 @@ public static TSDataType getBuiltinAggregationDataType( case SqlConstant.COVAR_SAMP: case SqlConstant.REGR_SLOPE: case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return TSDataType.DOUBLE; default: throw new IllegalArgumentException( @@ -215,6 +217,14 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa throw new SemanticException( "Aggregate functions [REGR_SLOPE, REGR_INTERCEPT] only support " + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { + return; + } + throw new SemanticException( + "Aggregate functions [SKEWNESS, KURTOSIS] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); case SqlConstant.COUNT: case SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: @@ -274,6 +284,8 @@ public static void bindTypeForBuiltinAggregationNonSeriesInputExpressions( case SqlConstant.COVAR_SAMP: case SqlConstant.REGR_SLOPE: case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java index 1e2a90b41a661..32d43cedd5da3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java @@ -79,6 +79,8 @@ protected SqlConstant() { public static final String COVAR_SAMP = "covar_samp"; public static final String REGR_SLOPE = "regr_slope"; public static final String REGR_INTERCEPT = "regr_intercept"; + public static final String SKEWNESS = "skewness"; + public static final String KURTOSIS = "kurtosis"; public static final String COUNT_TIME = "count_time"; public static final String COUNT_TIME_HEADER = "count_time(*)"; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java index 0cef02dd1dafc..c8274b94d1c44 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java @@ -51,7 +51,9 @@ public enum BuiltinAggregationFunction { COVAR_POP("covar_pop"), COVAR_SAMP("covar_samp"), REGR_SLOPE("regr_slope"), - REGR_INTERCEPT("regr_intercept"); + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; @@ -107,6 +109,8 @@ public static boolean canUseStatistics(String name) { case "covar_samp": case "regr_slope": case "regr_intercept": + case "skewness": + case "kurtosis": return false; default: throw new IllegalArgumentException("Invalid Aggregation function: " + name); @@ -146,6 +150,8 @@ public static boolean canSplitToMultiPhases(String name) { case "covar_samp": case "regr_slope": case "regr_intercept": + case "skewness": + case "kurtosis": return true; case "count_if": case "count_time": diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 71bba543449b6..151df10988d65 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -63,7 +63,9 @@ public enum TableBuiltinAggregationFunction { COVAR_POP("covar_pop"), COVAR_SAMP("covar_samp"), REGR_SLOPE("regr_slope"), - REGR_INTERCEPT("regr_intercept"); + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; @@ -113,6 +115,8 @@ public static Type getIntermediateType(String name, List originalArgumentT case "covar_samp": case "regr_slope": case "regr_intercept": + case "skewness": + case "kurtosis": case "approx_count_distinct": return RowType.anonymous(Collections.emptyList()); case "extreme": diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 659ed7938486f..fa4ea97fef110 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -298,7 +298,9 @@ enum TAggregationType { COVAR_POP, COVAR_SAMP, REGR_SLOPE, - REGR_INTERCEPT + REGR_INTERCEPT, + SKEWNESS, + KURTOSIS } struct TShowConfigurationTemplateResp { From c31dda45cb7f06239861abfb6c6fc2f194bcbff7 Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Thu, 12 Mar 2026 19:22:46 +0800 Subject: [PATCH 5/7] Refactor input validation for aggregate functions to streamline argument checks --- .../analyzer/ExpressionAnalyzer.java | 26 ------ .../metadata/TableMetadataImpl.java | 80 +++++++++---------- 2 files changed, 39 insertions(+), 67 deletions(-) diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java index 0c43398ba1f70..6157612b42382 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java @@ -1066,32 +1066,6 @@ protected Type visitFunctionCall( } } - // Check argument count for specific aggregation functions before calling - // getFunctionReturnType - if (isAggregation) { - String lowerFuncName = functionName.toLowerCase(); - if (lowerFuncName.equals("corr") - || lowerFuncName.equals("covar_pop") - || lowerFuncName.equals("covar_samp") - || lowerFuncName.equals("regr_slope") - || lowerFuncName.equals("regr_intercept")) { - if (argumentTypes.size() != 2) { - throw new SemanticException( - String.format( - "Error size of input expressions. expression: %s, actual size: %s, expected size: [2].", - node, argumentTypes.size())); - } - } - if (lowerFuncName.equals("skewness") || lowerFuncName.equals("kurtosis")) { - if (argumentTypes.size() != 1) { - throw new SemanticException( - String.format( - "Error size of input expressions. 
expression: %s, actual size: %s, expected size: [1].", - node, argumentTypes.size())); - } - } - } - Type type = metadata.getFunctionReturnType(functionName, argumentTypes); FunctionKind functionKind = FunctionKind.SCALAR; if (isAggregation) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index a670ce67e079d..02e6a1bae440d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -597,6 +597,45 @@ && isIntegerNumber(argumentTypes.get(2)))) { functionName)); } break; + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if (argumentTypes.size() != 2) { + throw new SemanticException( + String.format( + "Error size of input expressions. expression: %s, actual size: %s, expected size: [2].", + functionName.toUpperCase(), argumentTypes.size())); + } + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + if (!isSupportedMathNumericType(argumentTypes.get(1))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + if (argumentTypes.size() != 1) { + throw new SemanticException( + String.format( + "Error size of input expressions. 
expression: %s, actual size: %s, expected size: [1].", + functionName.toUpperCase(), argumentTypes.size())); + } + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; case SqlConstant.MIN: case SqlConstant.MAX: case SqlConstant.MODE: @@ -647,47 +686,6 @@ && isIntegerNumber(argumentTypes.get(2)))) { "Second argument of Aggregate functions [%s] should be orderable", functionName)); } - break; - case SqlConstant.CORR: - case SqlConstant.COVAR_POP: - case SqlConstant.COVAR_SAMP: - // Argument count is already checked in ExpressionAnalyzer - if (!isSupportedMathNumericType(argumentTypes.get(0))) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionName.toUpperCase())); - } else if (!isSupportedMathNumericType(argumentTypes.get(1))) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionName.toUpperCase())); - } - break; - case SqlConstant.REGR_SLOPE: - case SqlConstant.REGR_INTERCEPT: - // Argument count is already checked in ExpressionAnalyzer - if (!isSupportedMathNumericType(argumentTypes.get(0))) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionName.toUpperCase())); - } else if (!isSupportedMathNumericType(argumentTypes.get(1))) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionName.toUpperCase())); - } - break; - case SqlConstant.SKEWNESS: - case SqlConstant.KURTOSIS: - // Argument count is already checked in ExpressionAnalyzer - 
if (!isSupportedMathNumericType(argumentTypes.get(0))) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionName.toUpperCase())); - } break; case SqlConstant.APPROX_COUNT_DISTINCT: if (argumentTypes.size() != 1 && argumentTypes.size() != 2) { From dedfbc56dcb5eff5f2093fbb5852d006290758b6 Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Thu, 12 Mar 2026 21:15:19 +0800 Subject: [PATCH 6/7] Refactor aggregation function type checks to improve input validation --- .../plan/analyze/ExpressionTypeAnalyzer.java | 57 +++---------------- .../iotdb/db/utils/TypeInferenceUtils.java | 46 +++++++++------ 2 files changed, 38 insertions(+), 65 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index 663d7645f6214..7bc866fdd2bf5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -361,50 +361,6 @@ public TSDataType visitFunctionExpression( } if (functionExpression.isBuiltInAggregationFunctionExpression()) { - // Additional type check for multi-input aggregation functions - String funcName = functionExpression.getFunctionName().toLowerCase(); - if (funcName.equals(SqlConstant.CORR) - || funcName.equals(SqlConstant.COVAR_POP) - || funcName.equals(SqlConstant.COVAR_SAMP) - || funcName.equals(SqlConstant.REGR_SLOPE) - || funcName.equals(SqlConstant.REGR_INTERCEPT)) { - // Check both input parameters are numeric or timestamp - if (inputExpressions.size() >= 1) { - TSDataType firstInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(0))); - if (firstInputType != null - && 
!firstInputType.isNumeric() - && firstInputType != TSDataType.TIMESTAMP) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionExpression.getFunctionName().toUpperCase())); - } - } - if (inputExpressions.size() >= 2) { - TSDataType secondInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(1))); - if (secondInputType != null - && !secondInputType.isNumeric() - && secondInputType != TSDataType.TIMESTAMP) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionExpression.getFunctionName().toUpperCase())); - } - } - } - if (funcName.equals(SqlConstant.SKEWNESS) || funcName.equals(SqlConstant.KURTOSIS)) { - if (!inputExpressions.isEmpty()) { - TSDataType firstInputType = expressionTypes.get(NodeRef.of(inputExpressions.get(0))); - if (firstInputType != null - && !firstInputType.isNumeric() - && firstInputType != TSDataType.TIMESTAMP) { - throw new SemanticException( - String.format( - "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", - functionExpression.getFunctionName().toUpperCase())); - } - } - } return setExpressionType( functionExpression, @@ -587,15 +543,20 @@ private TSDataType getInputExpressionTypeForAggregation( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + case SqlConstant.MAX_BY: + case SqlConstant.MIN_BY: + return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); case SqlConstant.CORR: case SqlConstant.COVAR_POP: case SqlConstant.COVAR_SAMP: case SqlConstant.REGR_SLOPE: case SqlConstant.REGR_INTERCEPT: - case SqlConstant.SKEWNESS: - case SqlConstant.KURTOSIS: - case SqlConstant.MAX_BY: - case SqlConstant.MIN_BY: + TypeInferenceUtils.verifyIsAggregationDataTypeMatchedForBothInputs( + 
aggregateFunctionName, + expressionTypes.get(NodeRef.of(inputExpressions.get(0))), + expressionTypes.get(NodeRef.of(inputExpressions.get(1)))); return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); default: throw new IllegalArgumentException( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index be9d72cdcddd8..10413c6f2537e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -200,23 +200,6 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, " + "VARIANCE, VAR_POP, VAR_SAMP] only support " + "numeric data types [INT32, INT64, FLOAT, DOUBLE]"); - case SqlConstant.CORR: - case SqlConstant.COVAR_POP: - case SqlConstant.COVAR_SAMP: - if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { - return; - } - throw new SemanticException( - "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP] only support " - + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); - case SqlConstant.REGR_SLOPE: - case SqlConstant.REGR_INTERCEPT: - if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { - return; - } - throw new SemanticException( - "Aggregate functions [REGR_SLOPE, REGR_INTERCEPT] only support " - + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); case SqlConstant.SKEWNESS: case SqlConstant.KURTOSIS: if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { @@ -225,6 +208,11 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa throw new SemanticException( "Aggregate functions [SKEWNESS, KURTOSIS] only support " + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + case SqlConstant.CORR: + 
case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: case SqlConstant.COUNT: case SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: @@ -249,6 +237,30 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa } } + public static void verifyIsAggregationDataTypeMatchedForBothInputs( + String aggrFuncName, TSDataType firstDataType, TSDataType secondDataType) { + switch (aggrFuncName.toLowerCase()) { + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if ((firstDataType != null + && !firstDataType.isNumeric() + && !TSDataType.TIMESTAMP.equals(firstDataType)) + || (secondDataType != null + && !secondDataType.isNumeric() + && !TSDataType.TIMESTAMP.equals(secondDataType))) { + throw new SemanticException( + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP, REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + } + return; + default: + break; + } + } + /** * Bind Type for non-series input Expressions of AggregationFunction and check Semantic * From 2523cebf3b45755c3a58d6b1fa4cbe2af8a6d6c6 Mon Sep 17 00:00:00 2001 From: Cool6689 <3322351820@qq.com> Date: Fri, 13 Mar 2026 10:13:10 +0800 Subject: [PATCH 7/7] Refactor correlation and regression accumulators to improve code clarity by removing unnecessary comments and whitespace --- .../aggregation/CentralMomentAccumulator.java | 27 ++++++------------- .../aggregation/CorrelationAccumulator.java | 22 ++++++--------- .../aggregation/RegressionAccumulator.java | 17 +++--------- .../TableCentralMomentAccumulator.java | 5 ++-- .../TableCorrelationAccumulator.java | 8 +++--- .../TableRegressionAccumulator.java | 2 +- .../GroupedCentralMomentAccumulator.java | 2 +- .../GroupedCorrelationAccumulator.java | 2 -- .../grouped/GroupedRegressionAccumulator.java | 3 +-- 9 files 
changed, 28 insertions(+), 60 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java index bcb732db918f7..6433757df92ba 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java @@ -35,7 +35,6 @@ public enum MomentType { private final TSDataType seriesDataType; private final MomentType momentType; - // State variables: count, mean, M2, M3, M4 private long count; private double mean; private double m2; @@ -49,7 +48,7 @@ public CentralMomentAccumulator(TSDataType seriesDataType, MomentType momentType @Override public void addInput(Column[] columns, BitMap bitMap) { - // Tree model: columns[0] is Time, columns[1] is data + int size = columns[1].getPositionCount(); for (int i = 0; i < size; i++) { if (bitMap != null && !bitMap.isMarked(i)) { @@ -91,13 +90,10 @@ private void update(double value) { mean += delta_n; - // 更新 M4 (顺序很重要,必须在更新 M3, M2 之前) m4 += term1 * delta_n2 * (count * count - 3 * count + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3; - // 更新 M3 m3 += term1 * delta_n * (count - 2) - 3 * delta_n * m2; - // 更新 M2 m2 += term1; } @@ -135,24 +131,19 @@ private void merge(long nB, double meanB, double m2B, double m3B, double m4B) { double delta3 = delta * delta2; double delta4 = delta2 * delta2; - // 合并公式 (Chan et al.) 
- // M4 合并 m4 += m4B + delta4 * nA * nB * (nA * nA - nA * nB + nB * nB) / (nTotal * nTotal * nTotal) + 6.0 * delta2 * (nA * nA * m2B + nB * nB * m2) / (nTotal * nTotal) + 4.0 * delta * (nA * m3B - nB * m3) / nTotal; - // M3 合并 m3 += m3B + delta3 * nA * nB * (nA - nB) / (nTotal * nTotal) + 3.0 * delta * (nA * m2B - nB * m2) / nTotal; - // M2 合并 m2 += m2B + delta2 * nA * nB / nTotal; - // Mean 合并 mean += delta * nB / nTotal; count = nTotal; } @@ -164,7 +155,7 @@ public void outputIntermediate(ColumnBuilder[] columnBuilders) { if (count == 0) { columnBuilders[0].appendNull(); } else { - // 序列化: long + 4 * double = 40 bytes + byte[] bytes = new byte[40]; ByteBuffer buffer = ByteBuffer.wrap(bytes); buffer.putLong(count); @@ -178,27 +169,26 @@ public void outputIntermediate(ColumnBuilder[] columnBuilders) { @Override public void outputFinal(ColumnBuilder columnBuilder) { - if (count == 0 || m2 == 0) { // 方差为0或无数据 + if (count == 0 || m2 == 0) { columnBuilder.appendNull(); return; } if (momentType == MomentType.SKEWNESS) { - if (count < 3) { // 偏度要求 N >= 3 + if (count < 3) { columnBuilder.appendNull(); } else { - // 无偏估计公式: (N * M3) / ((N-1)*(N-2) * sigma^3) - // sigma = sqrt(M2 / (N-1)) + double variance = m2 / (count - 1); double stdev = Math.sqrt(variance); double result = (count * m3) / ((count - 1) * (count - 2) * stdev * stdev * stdev); columnBuilder.writeDouble(result); } - } else { // KURTOSIS - if (count < 4) { // 峰度要求 N >= 4 + } else { + if (count < 4) { columnBuilder.appendNull(); } else { - // 无偏估计公式 (超额峰度 Excess Kurtosis) + double variance = m2 / (count - 1); double term1 = (count * (count + 1) * m4) @@ -209,7 +199,6 @@ public void outputFinal(ColumnBuilder columnBuilder) { } } - // 默认实现 @Override public void removeIntermediate(Column[] input) { throw new UnsupportedOperationException(); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java index 722659c6581c8..c304bd2bd673b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java @@ -44,9 +44,9 @@ public enum CorrelationType { private long count; private double meanX; private double meanY; - private double m2X; // sum((x - meanX)^2) - private double m2Y; // sum((y - meanY)^2) - private double c2; // sum((x - meanX) * (y - meanY)) + private double m2X; + private double m2Y; + private double c2; public CorrelationAccumulator(TSDataType[] seriesDataTypes, CorrelationType correlationType) { this.seriesDataTypes = seriesDataTypes; @@ -55,8 +55,7 @@ public CorrelationAccumulator(TSDataType[] seriesDataTypes, CorrelationType corr @Override public void addInput(Column[] columns, BitMap bitMap) { - // columns[0] is time column - // columns[1] and columns[2] are the two data columns + int size = columns[0].getPositionCount(); for (int i = 0; i < size; i++) { if (bitMap != null && !bitMap.isMarked(i)) { @@ -97,8 +96,6 @@ private void update(double x, double y) { meanX += deltaX / newCount; meanY += deltaY / newCount; - // Welford's algorithm for covariance and variance - // C2_new = C2_old + (x - meanX_old) * (y - meanY_new) c2 += deltaX * (y - meanY); m2X += deltaX * (x - meanX); m2Y += deltaY * (y - meanY); @@ -147,7 +144,6 @@ private void merge( double deltaX = otherMeanX - meanX; double deltaY = otherMeanY - meanY; - // Merge formulas c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; @@ -180,10 +176,10 @@ public void outputFinal(ColumnBuilder columnBuilder) { switch (correlationType) { case CORR: if 
(count < 2) { - // Not enough data to calculate correlation + columnBuilder.appendNull(); } else if (m2X == 0 || m2Y == 0) { - // If either variable has zero variance (all values the same), correlation is 0 + columnBuilder.writeDouble(0.0); } else { columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y)); @@ -210,7 +206,7 @@ public void outputFinal(ColumnBuilder columnBuilder) { @Override public void removeIntermediate(Column[] input) { - // Optional: sliding window logic implementation if needed, otherwise throw exception + throw new UnsupportedOperationException("Remove not implemented for Correlation"); } @@ -220,9 +216,7 @@ public void addStatistics(Statistics statistics) { } @Override - public void setFinal(Column finalResult) { - // No-op for this accumulator typically - } + public void setFinal(Column finalResult) {} @Override public void reset() { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java index 70db4ac33e357..c6e39746fc954 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java @@ -35,12 +35,11 @@ public enum RegressionType { private final TSDataType[] seriesDataTypes; private final RegressionType regressionType; - // 状态变量 (不需要 m2Y) private long count; private double meanX; private double meanY; - private double m2X; // Sum((x - meanX)^2) - private double c2; // Sum((x - meanX) * (y - meanY)) + private double m2X; + private double c2; public RegressionAccumulator(TSDataType[] seriesDataTypes, RegressionType regressionType) { this.seriesDataTypes = seriesDataTypes; @@ -49,8 +48,6 @@ public RegressionAccumulator(TSDataType[] seriesDataTypes, RegressionType regres @Override 
public void addInput(Column[] columns, BitMap bitMap) { - // Tree 模型: columns[0] 是 Time - // REGR_SLOPE(y, x) -> columns[1] 是 y, columns[2] 是 x int size = columns[1].getPositionCount(); for (int i = 0; i < size; i++) { @@ -61,8 +58,8 @@ public void addInput(Column[] columns, BitMap bitMap) { continue; } - double y = getDoubleValue(columns[1], i, seriesDataTypes[0]); // Arg1: Y (因变量) - double x = getDoubleValue(columns[2], i, seriesDataTypes[1]); // Arg2: X (自变量) + double y = getDoubleValue(columns[1], i, seriesDataTypes[0]); + double x = getDoubleValue(columns[2], i, seriesDataTypes[1]); update(x, y); } @@ -92,7 +89,6 @@ private void update(double x, double y) { meanX += deltaX / newCount; meanY += deltaY / newCount; - // Welford Covariance & Variance c2 += deltaX * (y - meanY); m2X += deltaX * (x - meanX); @@ -133,7 +129,6 @@ private void merge( double deltaX = otherMeanX - meanX; double deltaY = otherMeanY - meanY; - // Merge Logic c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; @@ -149,7 +144,6 @@ public void outputIntermediate(ColumnBuilder[] columnBuilders) { if (count == 0) { columnBuilders[0].appendNull(); } else { - // 序列化 5 个变量: long(8) + 4 * double(8) = 40 bytes byte[] bytes = new byte[40]; ByteBuffer buffer = ByteBuffer.wrap(bytes); buffer.putLong(count); @@ -168,7 +162,6 @@ public void outputFinal(ColumnBuilder columnBuilder) { return; } - // 如果 X 没有波动 (m2X=0), 斜率无法计算 (除以0), 返回 NULL if (m2X == 0) { columnBuilder.appendNull(); return; @@ -181,7 +174,6 @@ public void outputFinal(ColumnBuilder columnBuilder) { columnBuilder.writeDouble(slope); break; case REGR_INTERCEPT: - // Intercept = MeanY - Slope * MeanX columnBuilder.writeDouble(meanY - slope * meanX); break; default: @@ -189,7 +181,6 @@ public void outputFinal(ColumnBuilder columnBuilder) { } } - // 其他必须实现的接口方法 @Override public void removeIntermediate(Column[] input) { throw new UnsupportedOperationException(); diff 
--git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java index 97ff93fce45a3..bddef8ee41e6e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java @@ -39,7 +39,6 @@ public class TableCentralMomentAccumulator implements TableAccumulator { private final TSDataType seriesDataType; private final CentralMomentAccumulator.MomentType momentType; - // State private long count; private double mean; private double m2; @@ -172,7 +171,7 @@ public void evaluateIntermediate(ColumnBuilder columnBuilder) { if (count == 0) { columnBuilder.appendNull(); } else { - // Serialize: long(8) + 4*double(32) = 40 bytes + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); buffer.putLong(count); buffer.putDouble(mean); @@ -199,7 +198,7 @@ public void evaluateFinal(ColumnBuilder columnBuilder) { double result = (count * m3) / ((count - 1) * (count - 2) * stdev * stdev * stdev); columnBuilder.writeDouble(result); } - } else { // KURTOSIS + } else { if (count < 4) { columnBuilder.appendNull(); } else { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java index 9ab44df27a512..ad8407e492b3c 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java @@ -42,9 +42,9 @@ public class TableCorrelationAccumulator implements TableAccumulator { private long count; private double meanX; private double meanY; - private double m2X; // sum((x - meanX)^2) - private double m2Y; // sum((y - meanY)^2) - private double c2; // sum((x - meanX) * (y - meanY)) + private double m2X; + private double m2Y; + private double c2; public TableCorrelationAccumulator( TSDataType xDataType, @@ -118,7 +118,6 @@ private void update(double x, double y) { meanX += deltaX / newCount; meanY += deltaY / newCount; - // Welford's algorithm for covariance and variance c2 += deltaX * (y - meanY); m2X += deltaX * (x - meanX); m2Y += deltaY * (y - meanY); @@ -180,7 +179,6 @@ private void merge( double deltaX = otherMeanX - meanX; double deltaY = otherMeanY - meanY; - // Merge formulas c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java index e0dd15b37438b..5b655a67eb23c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java @@ -66,7 +66,7 @@ public 
TableAccumulator copy() { @Override public void addInput(Column[] arguments, AggregationMask mask) { - // arguments[0] -> Y, arguments[1] -> X + int positionCount = mask.getSelectedPositionCount(); if (mask.isSelectAll()) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java index 5674ca4510e2c..6e84edd9b94be 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java @@ -249,7 +249,7 @@ public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { double result = (count * m3) / ((count - 1) * (count - 2) * stdev * stdev * stdev); columnBuilder.writeDouble(result); } - } else { // KURTOSIS + } else { if (count < 4) { columnBuilder.appendNull(); } else { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java index beb474b97c7b9..6b6a21f00844c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java @@ -136,7 +136,6 @@ private void update(int groupId, double x, 
double y) { meanXs.add(groupId, deltaX / newCount); meanYs.add(groupId, deltaY / newCount); - // Welford's algorithm for covariance and variance c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId))); m2Ys.add(groupId, deltaY * (y - meanYs.get(groupId))); @@ -194,7 +193,6 @@ private void merge( double deltaX = otherMeanX - meanXs.get(groupId); double deltaY = otherMeanY - meanYs.get(groupId); - // Merge formulas c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount); m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount); m2Ys.add(groupId, otherM2Y + deltaY * deltaY * counts.get(groupId) * otherCount / newCount); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java index d5cd7f9ce4307..97aabf8c96a19 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java @@ -83,7 +83,7 @@ public void setGroupCount(long groupCount) { @Override public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { - // arguments[0] -> Y, arguments[1] -> X + int positionCount = mask.getSelectedPositionCount(); if (mask.isSelectAll()) { @@ -135,7 +135,6 @@ private void update(int groupId, double x, double y) { meanXs.add(groupId, deltaX / newCount); meanYs.add(groupId, deltaY / newCount); - // Welford's algorithm for covariance and variance of X c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); 
m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId)));