diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java index 7c2c283d30e3a..5ebb8e12f9be0 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java @@ -42,7 +42,14 @@ public enum BuiltinAggregationFunctionEnum { AVG("avg"), SUM("sum"), MAX_BY("max_by"), - MIN_BY("min_by"); + MIN_BY("min_by"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java index 24a998f54a917..90ec90f81acf4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java @@ -69,6 +69,11 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) switch (aggregationType) { case MAX_BY: case MIN_BY: + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: return true; default: return false; @@ -84,6 +89,31 @@ public static Accumulator createBuiltinMultiInputAccumulator( case MIN_BY: checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); return new MinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1)); + case CORR: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.COVAR_POP); + case COVAR_SAMP: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.COVAR_SAMP); + case REGR_SLOPE: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new RegressionAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new RegressionAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + RegressionAccumulator.RegressionType.REGR_INTERCEPT); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -140,6 +170,12 @@ private static Accumulator createBuiltinSingleInputAccumulator( return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_SAMP); case VAR_POP: return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_POP); + case SKEWNESS: + return new CentralMomentAccumulator( + tsDataType, CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new CentralMomentAccumulator( + tsDataType, CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java new file mode 100644 index 0000000000000..6433757df92ba --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java @@ -0,0 +1,238 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class CentralMomentAccumulator implements Accumulator { + + public enum MomentType { + SKEWNESS, + KURTOSIS + } + + private final TSDataType seriesDataType; + private final MomentType momentType; + + private long count; + private double mean; + private double m2; + private double m3; + private double m4; + + public CentralMomentAccumulator(TSDataType seriesDataType, MomentType momentType) { + this.seriesDataType = seriesDataType; + this.momentType = momentType; + } + + @Override + public void addInput(Column[] columns, BitMap bitMap) { + + int size = columns[1].getPositionCount(); + for (int i = 0; i < size; i++) { + if (bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i)) { + continue; + } + update(getDoubleValue(columns[1], i)); + } + } + + private double getDoubleValue(Column column, int position) { + switch (seriesDataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnsupportedOperationException( + "Unsupported data type in CentralMoment Aggregation: " + seriesDataType); + } + } + + private void update(double value) { + long n1 = count; + count++; + + double delta = value - mean; + double delta_n = delta / count; + double delta_n2 = delta_n * delta_n; + double term1 = delta * delta_n * n1; + + mean += delta_n; + + m4 += term1 * delta_n2 * (count * count - 3 * count + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3; + + m3 += term1 * delta_n * (count - 2) - 3 * delta_n * m2; + + m2 += term1; + } + + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of CentralMoment should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMean = buffer.getDouble(); + double otherM2 = buffer.getDouble(); + double otherM3 = buffer.getDouble(); + double otherM4 = buffer.getDouble(); + + merge(otherCount, otherMean, otherM2, otherM3, otherM4); + } + + private void merge(long nB, double meanB, double m2B, double m3B, double m4B) { + if (nB == 0) return; + if (count == 0) { + count = nB; + mean = meanB; + m2 = m2B; + m3 = m3B; + m4 = m4B; + } else { + long nA = count; + long nTotal = nA + nB; + double delta = meanB - mean; + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + m4 += + m4B + + delta4 * nA * nB * (nA * nA - nA * nB + nB * nB) / (nTotal * nTotal * nTotal) + + 6.0 * delta2 * (nA * nA * m2B + nB * nB * m2) / (nTotal * nTotal) + + 4.0 * delta * (nA * m3B - nB * m3) / nTotal; + + m3 += + m3B + + delta3 * nA * nB * (nA - nB) / (nTotal * nTotal) + + 3.0 * delta * (nA * m2B - nB * m2) / nTotal; + + m2 += m2B + delta2 * nA * nB / nTotal; + + mean += delta * nB / nTotal; + count = nTotal; + } + } + + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, "partialResult should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + + byte[] bytes = new byte[40]; + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.putLong(count); + buffer.putDouble(mean); + buffer.putDouble(m2); + buffer.putDouble(m3); + buffer.putDouble(m4); + columnBuilders[0].writeBinary(new Binary(bytes)); + } + } + + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + if (count == 0 || m2 == 0) { + columnBuilder.appendNull(); + return; + } + + if (momentType == MomentType.SKEWNESS) { + if (count < 3) { + columnBuilder.appendNull(); + } else { + + double variance = m2 / (count - 1); + double stdev = Math.sqrt(variance); + double result = (count * m3) / ((count - 1) * (count - 2) * stdev * stdev * stdev); + columnBuilder.writeDouble(result); + } + } else { + if (count < 4) { + columnBuilder.appendNull(); + } else { + + double variance = m2 / (count - 1); + double term1 = + (count * (count + 1) * m4) + / ((count - 1) * (count - 2) * (count - 3) * variance * variance); + double term2 = (3 * Math.pow(count - 1, 2)) / ((count - 2) * (count - 3)); + columnBuilder.writeDouble(term1 - term2); + } + } + } + + @Override + public void removeIntermediate(Column[] input) { + throw new UnsupportedOperationException(); + } + + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(); + } + + @Override + public void setFinal(Column finalResult) {} + + @Override + public void reset() { + count = 0; + mean = 0; + m2 = 0; + m3 = 0; + m4 = 0; + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + + @Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java new file mode 100644 index 0000000000000..c304bd2bd673b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class CorrelationAccumulator implements Accumulator { + + public enum CorrelationType { + CORR, + COVAR_POP, + COVAR_SAMP + } + + private final TSDataType[] seriesDataTypes; + private final CorrelationType correlationType; + + private long count; + private double meanX; + private double meanY; + private double m2X; + private double m2Y; + private double c2; + + public CorrelationAccumulator(TSDataType[] seriesDataTypes, CorrelationType correlationType) { + this.seriesDataTypes = seriesDataTypes; + this.correlationType = correlationType; + } + + @Override + public void addInput(Column[] columns, BitMap bitMap) { + + int size = columns[0].getPositionCount(); + for (int i = 0; i < size; i++) { + if (bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i) || columns[2].isNull(i)) { + continue; + } + + double x = getDoubleValue(columns[1], i, seriesDataTypes[0]); + double y = getDoubleValue(columns[2], i, seriesDataTypes[1]); + + update(x, y); + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + + meanX += deltaX / newCount; + meanY += deltaY / newCount; + + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + m2Y += deltaY * (y - meanY); + + count = newCount; + } + + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of Correlation should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + + private void merge( + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + m2Y = otherM2Y; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, "partialResult of Correlation should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(m2Y); + buffer.putDouble(c2); + columnBuilders[0].writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + switch (correlationType) { + case CORR: + if (count < 2) { + + columnBuilder.appendNull(); + } else if (m2X == 0 || m2Y == 0) { + + columnBuilder.writeDouble(0.0); + } else { + columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y)); + } + break; + case COVAR_POP: + if (count == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / count); + } + break; + case COVAR_SAMP: + if (count < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / (count - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + } + + @Override + public void removeIntermediate(Column[] input) { + + throw new UnsupportedOperationException("Remove not implemented for Correlation"); + } + + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void setFinal(Column finalResult) {} + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + m2Y = 0; + c2 = 0; + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + + @Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java new file mode 100644 index 0000000000000..c6e39746fc954 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/RegressionAccumulator.java @@ -0,0 +1,220 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class RegressionAccumulator implements Accumulator { + + public enum RegressionType { + REGR_SLOPE, + REGR_INTERCEPT + } + + private final TSDataType[] seriesDataTypes; + private final RegressionType regressionType; + + private long count; + private double meanX; + private double meanY; + private double m2X; + private double c2; + + public RegressionAccumulator(TSDataType[] seriesDataTypes, RegressionType regressionType) { + this.seriesDataTypes = seriesDataTypes; + this.regressionType = regressionType; + } + + @Override + public void addInput(Column[] columns, BitMap bitMap) { + + int size = columns[1].getPositionCount(); + for (int i = 0; i < size; i++) { + if (bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i) || columns[2].isNull(i)) { + continue; + } + + double y = getDoubleValue(columns[1], i, seriesDataTypes[0]); + double x = getDoubleValue(columns[2], i, seriesDataTypes[1]); + + update(x, y); + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + + meanX += deltaX / newCount; + meanY += deltaY / newCount; + + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + + count = newCount; + } + + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of Regression should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherC2); + } + + private void merge( + long otherCount, double otherMeanX, double otherMeanY, double otherM2X, double otherC2) { + if (otherCount == 0) { + return; + } + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, "partialResult of Regression should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + byte[] bytes = new byte[40]; + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(c2); + columnBuilders[0].writeBinary(new Binary(bytes)); + } + } + + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + if (count == 0) { + columnBuilder.appendNull(); + return; + } + + if (m2X == 0) { + columnBuilder.appendNull(); + return; + } + + double slope = c2 / m2X; + + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + columnBuilder.writeDouble(meanY - slope * meanX); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + + @Override + public void removeIntermediate(Column[] input) { + throw new UnsupportedOperationException(); + } + + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void setFinal(Column finalResult) {} + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + c2 = 0; + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + + @Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java index 572d41d518486..a3ca7212c8bbc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java @@ -200,6 +200,13 @@ public static SlidingWindowAggregator createSlidingWindowAggregator( case VARIANCE: case VAR_POP: case VAR_SAMP: + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: + case SKEWNESS: + case KURTOSIS: case UDAF: // Currently UDAF belongs to SmoothQueueSlidingWindowAggregator return new SmoothQueueSlidingWindowAggregator(accumulator, inputLocationList, step); case MAX_VALUE: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java index 3ff20974168be..b84f948f0ccc3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -21,6 +21,9 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; +import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator; +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BinaryGroupedApproxMostFrequentAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BlobGroupedApproxMostFrequentAccumulator; @@ -30,6 +33,8 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedApproxCountDistinctAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCentralMomentAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCorrelationAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAllAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountIfAccumulator; @@ -43,6 +48,7 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedRegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator; @@ -256,6 +262,37 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( return new GroupedApproxCountDistinctAccumulator(inputDataTypes.get(0)); case APPROX_MOST_FREQUENT: return getGroupedApproxMostFrequentAccumulator(inputDataTypes.get(0)); + case CORR: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_POP); + case COVAR_SAMP: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_SAMP); + case REGR_SLOPE: + return new GroupedRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + return new GroupedRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_INTERCEPT); + case SKEWNESS: + return new GroupedCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new GroupedCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -325,6 +362,37 @@ public static TableAccumulator createBuiltinAccumulator( return new ApproxCountDistinctAccumulator(inputDataTypes.get(0)); case APPROX_MOST_FREQUENT: return getApproxMostFrequentAccumulator(inputDataTypes.get(0)); + case CORR: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_POP); + case COVAR_SAMP: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.COVAR_SAMP); + case REGR_SLOPE: + return new TableRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + return new TableRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_INTERCEPT); + case SKEWNESS: + return new TableCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new TableCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -385,6 +453,12 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) case MAX_BY: case MIN_BY: return true; + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: + return true; default: return false; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java new file mode 100644 index 0000000000000..bddef8ee41e6e --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java @@ -0,0 +1,253 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class TableCentralMomentAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TableCentralMomentAccumulator.class); + + private final TSDataType seriesDataType; + private final CentralMomentAccumulator.MomentType momentType; + + private long count; + private double mean; + private double m2; + private double m3; + private double m4; + + public TableCentralMomentAccumulator( + TSDataType seriesDataType, CentralMomentAccumulator.MomentType momentType) { + this.seriesDataType = seriesDataType; + this.momentType = momentType; + } + + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (!arguments[0].isNull(i)) { + update(getDoubleValue(arguments[0], i)); + } + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (!arguments[0].isNull(position)) { + update(getDoubleValue(arguments[0], position)); + } + } + } + } + + private double getDoubleValue(Column column, int position) { + switch (seriesDataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format( + "Unsupported data type in CentralMoment Aggregation: %s", seriesDataType)); + } + } + + private void update(double value) { + long n1 = count; + count++; + double delta = value - mean; + double delta_n = delta / count; + double delta_n2 = delta_n * delta_n; + double term1 = delta * delta_n * n1; + mean += delta_n; + m4 += term1 * delta_n2 * (count * count - 3 * count + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3; + m3 += term1 * delta_n * (count - 2) - 3 * delta_n * m2; + m2 += term1; + } + + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMean = buffer.getDouble(); + double otherM2 = buffer.getDouble(); + double otherM3 = buffer.getDouble(); + double otherM4 = buffer.getDouble(); + + merge(otherCount, otherMean, otherM2, otherM3, otherM4); + } + } + + private void merge(long nB, double meanB, double m2B, double m3B, double m4B) { + if (nB == 0) return; + if (count == 0) { + count = nB; + mean = meanB; + m2 = m2B; + m3 = m3B; + m4 = m4B; + } else { + long nA = count; + long nTotal = nA + nB; + double delta = meanB - mean; + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + m4 += + m4B + + delta4 * nA * nB * (nA * nA - nA * nB + nB * nB) / (nTotal * nTotal * nTotal) + + 6.0 * delta2 * (nA * nA * m2B + nB * nB * m2) / (nTotal * nTotal) + + 4.0 * delta * (nA * m3B - nB * m3) / nTotal; + + m3 += + m3B + + delta3 * nA * nB * (nA - nB) / (nTotal * nTotal) + + 3.0 * delta * (nA * m2B - nB * m2) / nTotal; + + m2 += m2B + delta2 * nA * nB / nTotal; + + mean += delta * nB / nTotal; + count = nTotal; + } + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (count == 0) { + columnBuilder.appendNull(); + } else { + + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); + buffer.putLong(count); + buffer.putDouble(mean); + buffer.putDouble(m2); + buffer.putDouble(m3); + buffer.putDouble(m4); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + if (count == 0 || m2 == 0) { + columnBuilder.appendNull(); + return; + } + + if (momentType == CentralMomentAccumulator.MomentType.SKEWNESS) { + if (count < 3) { + columnBuilder.appendNull(); + } else { + double variance = m2 / (count - 1); + double stdev = Math.sqrt(variance); + double result = (count * m3) / ((count - 1) * (count - 2) * stdev * stdev * stdev); + columnBuilder.writeDouble(result); + } + } else { + if (count < 4) { + columnBuilder.appendNull(); + } else { + double variance = m2 / (count - 1); + double term1 = + (count * (count + 1) * m4) + / ((count - 1) * (count - 2) * (count - 3) * variance * variance); + double term2 = (3 * Math.pow(count - 1, 2)) / ((count - 2) * (count - 3)); + columnBuilder.writeDouble(term1 - term2); + } + } + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new TableCentralMomentAccumulator(seriesDataType, momentType); + } + + @Override + public void removeInput(Column[] arguments) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() { + count = 0; + mean = 0; + m2 = 0; + m3 = 0; + m4 = 0; + } + + @Override + public boolean removable() { + return false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java new file mode 100644 index 0000000000000..ad8407e492b3c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java @@ -0,0 +1,267 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class TableCorrelationAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TableCorrelationAccumulator.class); + private final TSDataType xDataType; + private final TSDataType yDataType; + private final CorrelationAccumulator.CorrelationType correlationType; + + private long count; + private double meanX; + private double meanY; + private double m2X; + private double m2Y; + private double c2; + + public TableCorrelationAccumulator( + TSDataType xDataType, + TSDataType yDataType, + CorrelationAccumulator.CorrelationType correlationType) { + this.xDataType = xDataType; + this.yDataType = yDataType; + this.correlationType = correlationType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new TableCorrelationAccumulator(xDataType, yDataType, correlationType); + } + + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double x = getDoubleValue(arguments[0], i, xDataType); + double y = getDoubleValue(arguments[1], i, yDataType); + update(x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double x = getDoubleValue(arguments[0], position, xDataType); + double y = getDoubleValue(arguments[1], position, yDataType); + update(x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Correlation Aggregation: %s", dataType)); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + + meanX += deltaX / newCount; + meanY += deltaY / newCount; + + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + m2Y += deltaY * (y - meanY); + + count = newCount; + } + + @Override + public void removeInput(Column[] arguments) { + throw new UnsupportedOperationException("Remove not implemented for Correlation Accumulator"); + } + + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + } + + private void merge( + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + m2Y = otherM2Y; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (count == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(m2Y); + buffer.putDouble(c2); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + switch (correlationType) { + case CORR: + if (count < 2) { + columnBuilder.appendNull(); + } else if (m2X == 0 || m2Y == 0) { + columnBuilder.writeDouble(0.0); + } else { + columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y)); + } + break; + case COVAR_POP: + if (count == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / count); + } + break; + case COVAR_SAMP: + if (count < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / (count - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + m2Y = 0; + c2 = 0; + } + + @Override + public boolean removable() { + return false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java new file mode 100644 index 0000000000000..5b655a67eb23c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableRegressionAccumulator.java @@ -0,0 +1,230 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class TableRegressionAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TableRegressionAccumulator.class); + + private final TSDataType yDataType; + private final TSDataType xDataType; + private final RegressionAccumulator.RegressionType regressionType; + + private long count; + private double meanX; + private double meanY; + private double m2X; + private double c2; + + public TableRegressionAccumulator( + TSDataType yDataType, + TSDataType xDataType, + RegressionAccumulator.RegressionType regressionType) { + this.yDataType = yDataType; + this.xDataType = xDataType; + this.regressionType = regressionType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new TableRegressionAccumulator(yDataType, xDataType, regressionType); + } + + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double y = getDoubleValue(arguments[0], i, yDataType); + double x = getDoubleValue(arguments[1], i, xDataType); + update(x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double y = getDoubleValue(arguments[0], position, yDataType); + double x = getDoubleValue(arguments[1], position, xDataType); + update(x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Regression Aggregation: %s", dataType)); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + double deltaX = x - meanX; + double deltaY = y - meanY; + meanX += deltaX / newCount; + meanY += deltaY / newCount; + c2 += deltaX * (y - meanY); + m2X += deltaX * (x - meanX); + count = newCount; + } + + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn)); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) continue; + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherC2); + } + } + + private void merge( + long otherCount, double otherMeanX, double otherMeanY, double otherM2X, double otherC2) { + if (otherCount == 0) return; + if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + if (count == 0) { + columnBuilder.appendNull(); + } else { + byte[] bytes = new byte[40]; + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(c2); + columnBuilder.writeBinary(new Binary(bytes)); + } + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + if (count == 0 || m2X == 0) { + columnBuilder.appendNull(); + return; + } + double slope = c2 / m2X; + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + columnBuilder.writeDouble(meanY - slope * meanX); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + + @Override + public void removeInput(Column[] arguments) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + c2 = 0; + } + + @Override + public boolean removable() { + return false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java new file mode 100644 index 0000000000000..6e84edd9b94be --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class GroupedCentralMomentAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedCentralMomentAccumulator.class); + + private final TSDataType seriesDataType; + private final CentralMomentAccumulator.MomentType momentType; + + private final LongBigArray counts = new LongBigArray(); + private final DoubleBigArray means = new DoubleBigArray(); + private final DoubleBigArray m2s = new DoubleBigArray(); + private final DoubleBigArray m3s = new DoubleBigArray(); + private final DoubleBigArray m4s = new DoubleBigArray(); + + public GroupedCentralMomentAccumulator( + TSDataType seriesDataType, CentralMomentAccumulator.MomentType momentType) { + this.seriesDataType = seriesDataType; + this.momentType = momentType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE + + counts.sizeOf() + + means.sizeOf() + + m2s.sizeOf() + + m3s.sizeOf() + + m4s.sizeOf(); + } + + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + means.ensureCapacity(groupCount); + m2s.ensureCapacity(groupCount); + m3s.ensureCapacity(groupCount); + m4s.ensureCapacity(groupCount); + } + + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (!arguments[0].isNull(i)) { + update(groupIds[i], getDoubleValue(arguments[0], i)); + } + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (!arguments[0].isNull(position)) { + update(groupIds[position], getDoubleValue(arguments[0], position)); + } + } + } + } + + private double getDoubleValue(Column column, int position) { + switch (seriesDataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format( + "Unsupported data type in CentralMoment Aggregation: %s", seriesDataType)); + } + } + + private void update(int groupId, double value) { + long n1 = counts.get(groupId); + long newCount = n1 + 1; + double mean = means.get(groupId); + double m2 = m2s.get(groupId); + double m3 = m3s.get(groupId); + double m4 = m4s.get(groupId); + + double delta = value - mean; + double delta_n = delta / newCount; + double delta_n2 = delta_n * delta_n; + double term1 = delta * delta_n * n1; + + mean += delta_n; + m4 += + term1 * delta_n2 * (newCount * newCount - 3 * newCount + 3) + + 6 * delta_n2 * m2 + - 4 * delta_n * m3; + m3 += term1 * delta_n * (newCount - 2) - 3 * delta_n * m2; + m2 += term1; + + counts.set(groupId, newCount); + means.set(groupId, mean); + m2s.set(groupId, m2); + m3s.set(groupId, m3); + m4s.set(groupId, m4); + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMean = buffer.getDouble(); + double otherM2 = buffer.getDouble(); + double otherM3 = buffer.getDouble(); + double otherM4 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMean, otherM2, otherM3, otherM4); + } + } + + private void merge(int groupId, long nB, double meanB, double m2B, double m3B, double m4B) { + if (nB == 0) return; + long nA = counts.get(groupId); + if (nA == 0) { + counts.set(groupId, nB); + means.set(groupId, meanB); + m2s.set(groupId, m2B); + m3s.set(groupId, m3B); + m4s.set(groupId, m4B); + } else { + long nTotal = nA + nB; + double delta = meanB - means.get(groupId); + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + double m2 = m2s.get(groupId); + double m3 = m3s.get(groupId); + double m4 = m4s.get(groupId); + + m4 += + m4B + + delta4 * nA * nB * (nA * nA - nA * nB + nB * nB) / (nTotal * nTotal * nTotal) + + 6.0 * delta2 * (nA * nA * m2B + nB * nB * m2) / (nTotal * nTotal) + + 4.0 * delta * (nA * m3B - nB * m3) / nTotal; + + m3 += + m3B + + delta3 * nA * nB * (nA - nB) / (nTotal * nTotal) + + 3.0 * delta * (nA * m2B - nB * m2) / nTotal; + + m2 += m2B + delta2 * nA * nB / nTotal; + + means.add(groupId, delta * nB / nTotal); + counts.set(groupId, nTotal); + m2s.set(groupId, m2); + m3s.set(groupId, m3); + m4s.set(groupId, m4); + } + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); + buffer.putLong(counts.get(groupId)); + buffer.putDouble(means.get(groupId)); + buffer.putDouble(m2s.get(groupId)); + buffer.putDouble(m3s.get(groupId)); + buffer.putDouble(m4s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + long count = counts.get(groupId); + double m2 = m2s.get(groupId); + + if (count == 0 || m2 == 0) { + columnBuilder.appendNull(); + return; + } + + if (momentType == CentralMomentAccumulator.MomentType.SKEWNESS) { + if (count < 3) { + columnBuilder.appendNull(); + } else { + double m3 = m3s.get(groupId); + double variance = m2 / (count - 1); + double stdev = Math.sqrt(variance); + double result = (count * m3) / ((count - 1) * (count - 2) * stdev * stdev * stdev); + columnBuilder.writeDouble(result); + } + } else { + if (count < 4) { + columnBuilder.appendNull(); + } else { + double m4 = m4s.get(groupId); + double variance = m2 / (count - 1); + double term1 = + (count * (count + 1) * m4) + / ((count - 1) * (count - 2) * (count - 3) * variance * variance); + double term2 = (3 * Math.pow(count - 1, 2)) / ((count - 2) * (count - 3)); + columnBuilder.writeDouble(term1 - term2); + } + } + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + counts.reset(); + means.reset(); + m2s.reset(); + m3s.reset(); + m4s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java new file mode 100644 index 0000000000000..6b6a21f00844c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class GroupedCorrelationAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedCorrelationAccumulator.class); + private final TSDataType xDataType; + private final TSDataType yDataType; + private final CorrelationAccumulator.CorrelationType correlationType; + + private final LongBigArray counts = new LongBigArray(); + private final DoubleBigArray meanXs = new DoubleBigArray(); + private final DoubleBigArray meanYs = new DoubleBigArray(); + private final DoubleBigArray m2Xs = new DoubleBigArray(); + private final DoubleBigArray m2Ys = new DoubleBigArray(); + private final DoubleBigArray c2s = new DoubleBigArray(); + + public GroupedCorrelationAccumulator( + TSDataType xDataType, + TSDataType yDataType, + CorrelationAccumulator.CorrelationType correlationType) { + this.xDataType = xDataType; + this.yDataType = yDataType; + this.correlationType = correlationType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE + + counts.sizeOf() + + meanXs.sizeOf() + + meanYs.sizeOf() + + m2Xs.sizeOf() + + m2Ys.sizeOf() + + c2s.sizeOf(); + } + + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + meanXs.ensureCapacity(groupCount); + meanYs.ensureCapacity(groupCount); + m2Xs.ensureCapacity(groupCount); + m2Ys.ensureCapacity(groupCount); + c2s.ensureCapacity(groupCount); + } + + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double x = getDoubleValue(arguments[0], i, xDataType); + double y = getDoubleValue(arguments[1], i, yDataType); + update(groupIds[i], x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double x = getDoubleValue(arguments[0], position, xDataType); + double y = getDoubleValue(arguments[1], position, yDataType); + update(groupIds[position], x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Correlation Aggregation: %s", dataType)); + } + } + + private void update(int groupId, double x, double y) { + long newCount = counts.get(groupId) + 1; + double deltaX = x - meanXs.get(groupId); + double deltaY = y - meanYs.get(groupId); + + meanXs.add(groupId, deltaX / newCount); + meanYs.add(groupId, deltaY / newCount); + + c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); + m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId))); + m2Ys.add(groupId, deltaY * (y - meanYs.get(groupId))); + + counts.set(groupId, newCount); + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + } + + private void merge( + int groupId, + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } + if (counts.get(groupId) == 0) { + counts.set(groupId, otherCount); + meanXs.set(groupId, otherMeanX); + meanYs.set(groupId, otherMeanY); + m2Xs.set(groupId, otherM2X); + m2Ys.set(groupId, otherM2Y); + c2s.set(groupId, otherC2); + } else { + long newCount = counts.get(groupId) + otherCount; + double deltaX = otherMeanX - meanXs.get(groupId); + double deltaY = otherMeanY - meanYs.get(groupId); + + c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount); + m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount); + m2Ys.add(groupId, otherM2Y + deltaY * deltaY * counts.get(groupId) * otherCount / newCount); + + meanXs.add(groupId, deltaX * otherCount / newCount); + meanYs.add(groupId, deltaY * otherCount / newCount); + counts.set(groupId, newCount); + } + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(counts.get(groupId)); + buffer.putDouble(meanXs.get(groupId)); + buffer.putDouble(meanYs.get(groupId)); + buffer.putDouble(m2Xs.get(groupId)); + buffer.putDouble(m2Ys.get(groupId)); + buffer.putDouble(c2s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + switch (correlationType) { + case CORR: + if (counts.get(groupId) < 2) { + columnBuilder.appendNull(); + } else if (m2Xs.get(groupId) == 0 || m2Ys.get(groupId) == 0) { + columnBuilder.writeDouble(0.0); + } else { + columnBuilder.writeDouble( + c2s.get(groupId) / Math.sqrt(m2Xs.get(groupId) * m2Ys.get(groupId))); + } + break; + case COVAR_POP: + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2s.get(groupId) / counts.get(groupId)); + } + break; + case COVAR_SAMP: + if (counts.get(groupId) < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2s.get(groupId) / (counts.get(groupId) - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + counts.reset(); + meanXs.reset(); + meanYs.reset(); + m2Xs.reset(); + m2Ys.reset(); + c2s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java new file mode 100644 index 0000000000000..97aabf8c96a19 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class GroupedRegressionAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedRegressionAccumulator.class); + + private final TSDataType yDataType; + private final TSDataType xDataType; + private final RegressionAccumulator.RegressionType regressionType; + + private final LongBigArray counts = new LongBigArray(); + private final DoubleBigArray meanXs = new DoubleBigArray(); + private final DoubleBigArray meanYs = new DoubleBigArray(); + private final DoubleBigArray m2Xs = new DoubleBigArray(); + private final DoubleBigArray c2s = new DoubleBigArray(); + + public GroupedRegressionAccumulator( + TSDataType yDataType, + TSDataType xDataType, + RegressionAccumulator.RegressionType regressionType) { + this.yDataType = yDataType; + this.xDataType = xDataType; + this.regressionType = regressionType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE + + counts.sizeOf() + + meanXs.sizeOf() + + meanYs.sizeOf() + + m2Xs.sizeOf() + + c2s.sizeOf(); + } + + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + meanXs.ensureCapacity(groupCount); + meanYs.ensureCapacity(groupCount); + m2Xs.ensureCapacity(groupCount); + c2s.ensureCapacity(groupCount); + } + + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double y = getDoubleValue(arguments[0], i, yDataType); + double x = getDoubleValue(arguments[1], i, xDataType); + update(groupIds[i], x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double y = getDoubleValue(arguments[0], position, yDataType); + double x = getDoubleValue(arguments[1], position, xDataType); + update(groupIds[position], x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Regression Aggregation: %s", dataType)); + } + } + + private void update(int groupId, double x, double y) { + long newCount = counts.get(groupId) + 1; + double deltaX = x - meanXs.get(groupId); + double deltaY = y - meanYs.get(groupId); + + meanXs.add(groupId, deltaX / newCount); + meanYs.add(groupId, deltaY / newCount); + + c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); + m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId))); + + counts.set(groupId, newCount); + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherM2X, otherC2); + } + } + + private void merge( + int groupId, + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherC2) { + if (otherCount == 0) { + return; + } + if (counts.get(groupId) == 0) { + counts.set(groupId, otherCount); + meanXs.set(groupId, otherMeanX); + meanYs.set(groupId, otherMeanY); + m2Xs.set(groupId, otherM2X); + c2s.set(groupId, otherC2); + } else { + long newCount = counts.get(groupId) + otherCount; + double deltaX = otherMeanX - meanXs.get(groupId); + double deltaY = otherMeanY - meanYs.get(groupId); + + c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount); + m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount); + + meanXs.add(groupId, deltaX * otherCount / newCount); + meanYs.add(groupId, deltaY * otherCount / newCount); + counts.set(groupId, newCount); + } + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); + buffer.putLong(counts.get(groupId)); + buffer.putDouble(meanXs.get(groupId)); + buffer.putDouble(meanYs.get(groupId)); + buffer.putDouble(m2Xs.get(groupId)); + buffer.putDouble(c2s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + if (counts.get(groupId) == 0 || m2Xs.get(groupId) == 0) { + columnBuilder.appendNull(); + return; + } + double slope = c2s.get(groupId) / m2Xs.get(groupId); + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + columnBuilder.writeDouble(meanYs.get(groupId) - slope * meanXs.get(groupId)); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + counts.reset(); + meanXs.reset(); + meanYs.reset(); + m2Xs.reset(); + c2s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index adc80c7bb1522..7bc866fdd2bf5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -361,6 +361,7 @@ public TSDataType visitFunctionExpression( } if (functionExpression.isBuiltInAggregationFunctionExpression()) { + return setExpressionType( functionExpression, TypeInferenceUtils.getBuiltinAggregationDataType( @@ -542,9 +543,21 @@ private TSDataType getInputExpressionTypeForAggregation( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + TypeInferenceUtils.verifyIsAggregationDataTypeMatchedForBothInputs( + aggregateFunctionName, + expressionTypes.get(NodeRef.of(inputExpressions.get(0))), + expressionTypes.get(NodeRef.of(inputExpressions.get(1)))); + return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); default: throw new IllegalArgumentException( "Invalid Aggregation function: " + aggregateFunctionName); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java index 706d14f052cd2..b8b2117e976a0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java @@ -3186,6 +3186,8 @@ private void checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), @@ -3194,6 +3196,11 @@ private void checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.COUNT_IF: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java index ac30dcf505afd..36c6470078853 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java @@ -187,6 +187,27 @@ public List getActualAggregationNames(boolean isPartial) { case VAR_SAMP: outputAggregationNames.add(addPartialSuffix(SqlConstant.VAR_SAMP)); break; + case CORR: + outputAggregationNames.add(addPartialSuffix(SqlConstant.CORR)); + break; + case COVAR_POP: + outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_POP)); + break; + case COVAR_SAMP: + outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_SAMP)); + break; + case REGR_SLOPE: + outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_SLOPE)); + break; + case REGR_INTERCEPT: + outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); + break; + case SKEWNESS: + outputAggregationNames.add(addPartialSuffix(SqlConstant.SKEWNESS)); + break; + case KURTOSIS: + outputAggregationNames.add(addPartialSuffix(SqlConstant.KURTOSIS)); + break; case MAX_BY: outputAggregationNames.add(addPartialSuffix(SqlConstant.MAX_BY)); break; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 8934de172e9c5..02e6a1bae440d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -597,6 +597,45 @@ && isIntegerNumber(argumentTypes.get(2)))) { functionName)); } break; + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if (argumentTypes.size() != 2) { + throw new SemanticException( + String.format( + "Error size of input expressions. expression: %s, actual size: %s, expected size: [2].", + functionName.toUpperCase(), argumentTypes.size())); + } + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + if (!isSupportedMathNumericType(argumentTypes.get(1))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + if (argumentTypes.size() != 1) { + throw new SemanticException( + String.format( + "Error size of input expressions. expression: %s, actual size: %s, expected size: [1].", + functionName.toUpperCase(), argumentTypes.size())); + } + if (!isSupportedMathNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; case SqlConstant.MIN: case SqlConstant.MAX: case SqlConstant.MODE: @@ -701,6 +740,13 @@ && isIntegerNumber(argumentTypes.get(2)))) { case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return DOUBLE; case SqlConstant.APPROX_MOST_FREQUENT: return STRING; @@ -971,7 +1017,11 @@ public static boolean isBool(Type type) { } public static boolean isSupportedMathNumericType(Type type) { - return DOUBLE.equals(type) || FLOAT.equals(type) || INT32.equals(type) || INT64.equals(type); + return DOUBLE.equals(type) + || FLOAT.equals(type) + || INT32.equals(type) + || INT64.equals(type) + || TIMESTAMP.equals(type); } public static boolean isNumericType(Type type) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java index 773de36a067fd..71a24bc23678f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java @@ -88,6 +88,13 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return TSDataType.DOUBLE; // Partial aggregation names case SqlConstant.STDDEV + "_partial": @@ -96,6 +103,13 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.VARIANCE + "_partial": case SqlConstant.VAR_POP + "_partial": case SqlConstant.VAR_SAMP + "_partial": + case SqlConstant.CORR + "_partial": + case SqlConstant.COVAR_POP + "_partial": + case SqlConstant.COVAR_SAMP + "_partial": + case SqlConstant.REGR_SLOPE + "_partial": + case SqlConstant.REGR_INTERCEPT + "_partial": + case SqlConstant.SKEWNESS + "_partial": + case SqlConstant.KURTOSIS + "_partial": case SqlConstant.MAX_BY + "_partial": case SqlConstant.MIN_BY + "_partial": return TSDataType.TEXT; @@ -163,6 +177,20 @@ public static String getBuiltinAggregationName(TAggregationType aggregationType) return SqlConstant.VAR_POP; case VAR_SAMP: return SqlConstant.VAR_SAMP; + case CORR: + return SqlConstant.CORR; + case COVAR_POP: + return SqlConstant.COVAR_POP; + case COVAR_SAMP: + return SqlConstant.COVAR_SAMP; + case REGR_SLOPE: + return SqlConstant.REGR_SLOPE; + case REGR_INTERCEPT: + return SqlConstant.REGR_INTERCEPT; + case SKEWNESS: + return SqlConstant.SKEWNESS; + case KURTOSIS: + return SqlConstant.KURTOSIS; default: return null; } @@ -198,6 +226,13 @@ public static boolean isConsistentWithScanOrder( case VAR_SAMP: case MAX_BY: case MIN_BY: + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: + case SKEWNESS: + case KURTOSIS: case UDAF: return true; default: @@ -232,6 +267,20 @@ public static List splitPartialBuiltinAggregation(TAggregationType aggre return Collections.singletonList(addPartialSuffix(SqlConstant.VAR_POP)); case VAR_SAMP: return Collections.singletonList(addPartialSuffix(SqlConstant.VAR_SAMP)); + case CORR: + return Collections.singletonList(addPartialSuffix(SqlConstant.CORR)); + case COVAR_POP: + return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_POP)); + case COVAR_SAMP: + return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_SAMP)); + case REGR_SLOPE: + return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_SLOPE)); + case REGR_INTERCEPT: + return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); + case SKEWNESS: + return Collections.singletonList(addPartialSuffix(SqlConstant.SKEWNESS)); + case KURTOSIS: + return Collections.singletonList(addPartialSuffix(SqlConstant.KURTOSIS)); case MAX_BY: return Collections.singletonList(addPartialSuffix(SqlConstant.MAX_BY)); case MIN_BY: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index 8fc1d647dc36e..10413c6f2537e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -154,6 +154,13 @@ public static TSDataType getBuiltinAggregationDataType( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return TSDataType.DOUBLE; default: throw new IllegalArgumentException( @@ -190,7 +197,22 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa return; } throw new SemanticException( - "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, VARIANCE, VAR_POP, VAR_SAMP] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]"); + "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, " + + "VARIANCE, VAR_POP, VAR_SAMP] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE]"); + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { + return; + } + throw new SemanticException( + "Aggregate functions [SKEWNESS, KURTOSIS] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: case SqlConstant.COUNT: case SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: @@ -215,6 +237,30 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa } } + public static void verifyIsAggregationDataTypeMatchedForBothInputs( + String aggrFuncName, TSDataType firstDataType, TSDataType secondDataType) { + switch (aggrFuncName.toLowerCase()) { + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if ((firstDataType != null + && !firstDataType.isNumeric() + && !TSDataType.TIMESTAMP.equals(firstDataType)) + || (secondDataType != null + && !secondDataType.isNumeric() + && !TSDataType.TIMESTAMP.equals(secondDataType))) { + throw new SemanticException( + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP, REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + } + return; + default: + break; + } + } + /** * Bind Type for non-series input Expressions of AggregationFunction and check Semantic * @@ -245,6 +291,13 @@ public static void bindTypeForBuiltinAggregationNonSeriesInputExpressions( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java index 8120aff6059ba..32d43cedd5da3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java @@ -74,6 +74,13 @@ protected SqlConstant() { public static final String VARIANCE = "variance"; public static final String VAR_POP = "var_pop"; public static final String VAR_SAMP = "var_samp"; + public static final String CORR = "corr"; + public static final String COVAR_POP = "covar_pop"; + public static final String COVAR_SAMP = "covar_samp"; + public static final String REGR_SLOPE = "regr_slope"; + public static final String REGR_INTERCEPT = "regr_intercept"; + public static final String SKEWNESS = "skewness"; + public static final String KURTOSIS = "kurtosis"; public static final String COUNT_TIME = "count_time"; public static final String COUNT_TIME_HEADER = "count_time(*)"; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java index 1c6b25ef53aaf..c8274b94d1c44 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java @@ -46,7 +46,14 @@ public enum BuiltinAggregationFunction { VAR_POP("var_pop"), VAR_SAMP("var_samp"), MAX_BY("max_by"), - MIN_BY("min_by"); + MIN_BY("min_by"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; @@ -97,6 +104,13 @@ public static boolean canUseStatistics(String name) { case "var_samp": case "max_by": case "min_by": + case "corr": + case "covar_pop": + case "covar_samp": + case "regr_slope": + case "regr_intercept": + case "skewness": + case "kurtosis": return false; default: throw new IllegalArgumentException("Invalid Aggregation function: " + name); @@ -131,6 +145,13 @@ public static boolean canSplitToMultiPhases(String name) { case "var_samp": case "max_by": case "min_by": + case "corr": + case "covar_pop": + case "covar_samp": + case "regr_slope": + case "regr_intercept": + case "skewness": + case "kurtosis": return true; case "count_if": case "count_time": diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 39f7cde84c490..151df10988d65 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -58,7 +58,14 @@ public enum TableBuiltinAggregationFunction { VAR_POP("var_pop"), VAR_SAMP("var_samp"), APPROX_COUNT_DISTINCT("approx_count_distinct"), - APPROX_MOST_FREQUENT("approx_most_frequent"); + APPROX_MOST_FREQUENT("approx_most_frequent"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; @@ -103,6 +110,13 @@ public static Type getIntermediateType(String name, List originalArgumentT case "variance": case "var_pop": case "var_samp": + case "corr": + case "covar_pop": + case "covar_samp": + case "regr_slope": + case "regr_intercept": + case "skewness": + case "kurtosis": case "approx_count_distinct": return RowType.anonymous(Collections.emptyList()); case "extreme": diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 3287f35b9bb92..fa4ea97fef110 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -293,7 +293,14 @@ enum TAggregationType { MAX, COUNT_ALL, APPROX_COUNT_DISTINCT, - APPROX_MOST_FREQUENT + APPROX_MOST_FREQUENT, + CORR, + COVAR_POP, + COVAR_SAMP, + REGR_SLOPE, + REGR_INTERCEPT, + SKEWNESS, + KURTOSIS } struct TShowConfigurationTemplateResp {