From e4fb8bf3f4074c7267a1be44d79fe002f1b487c1 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Thu, 11 Dec 2025 21:46:48 -0500 Subject: [PATCH] [Draft] Flink 2.0 support --- gradle.properties | 2 +- runners/flink/2.0/build.gradle | 43 + .../2.0/job-server-container/build.gradle | 26 + runners/flink/2.0/job-server/build.gradle | 31 + .../flink/FlinkBatchTranslationContext.java | 116 ++ .../flink/FlinkExecutionEnvironments.java | 499 +++++ .../flink/FlinkMiniClusterEntryPoint.java | 97 + .../FlinkPipelineExecutionEnvironment.java | 217 ++ .../runners/flink/FlinkPipelineOptions.java | 382 ++++ .../runners/flink/FlinkPipelineRunner.java | 213 ++ .../flink/FlinkPipelineTranslator.java | 52 + ...nkStreamingPortablePipelineTranslator.java | 1151 ++++++++++ .../FlinkStreamingTransformTranslators.java | 1440 +++++++++++++ .../flink/FlinkTransformOverrides.java | 63 + .../functions/FlinkDoFnFunction.java | 264 +++ .../FlinkExecutableStageContextFactory.java | 63 + .../FlinkExecutableStageFunction.java | 416 ++++ .../FlinkExecutableStagePruningFunction.java | 61 + .../FlinkMergingNonShuffleReduceFunction.java | 100 + .../FlinkMultiOutputPruningFunction.java | 64 + .../functions/FlinkPartialReduceFunction.java | 116 ++ .../functions/FlinkReduceFunction.java | 116 ++ .../functions/FlinkStatefulDoFnFunction.java | 276 +++ .../functions/ImpulseSourceFunction.java | 117 ++ .../types/CoderTypeInformation.java | 142 ++ .../types/EncodedValueSerializer.java | 124 ++ .../types/EncodedValueTypeInformation.java | 98 + .../UnversionedTypeSerializerSnapshot.java | 85 + .../wrappers/streaming/DoFnOperator.java | 1785 ++++++++++++++++ .../streaming/io/StreamingImpulseSource.java | 85 + .../streaming/io/TestStreamSource.java | 82 + .../streaming/io/UnboundedSourceWrapper.java | 556 +++++ .../state/FlinkBroadcastStateInternals.java | 697 +++++++ .../streaming/state/FlinkStateInternals.java | 1851 +++++++++++++++++ .../flink/EncodedValueComparatorTest.java | 69 + .../flink/FlinkExecutionEnvironmentsTest.java | 582 ++++++ ...FlinkPipelineExecutionEnvironmentTest.java | 421 ++++ .../flink/FlinkPipelineOptionsTest.java | 191 ++ .../flink/FlinkRequiresStableInputTest.java | 288 +++ .../beam/runners/flink/FlinkRunnerTest.java | 94 + .../runners/flink/FlinkSavepointTest.java | 432 ++++ .../runners/flink/FlinkSubmissionTest.java | 251 +++ .../beam/runners/flink/ReadSourceTest.java | 92 + .../BeamFlinkDataStreamAdapterTest.java | 217 ++ .../streaming/BoundedSourceRestoreTest.java | 199 ++ .../streaming/FlinkStateInternalsTest.java | 218 ++ .../streaming/MemoryStateBackendWrapper.java | 80 + .../flink/streaming/StreamSources.java | 61 + .../functions/FlinkDoFnFunctionTest.java | 109 + .../FlinkExecutableStageFunctionTest.java | 347 +++ .../FlinkStatefulDoFnFunctionTest.java | 109 + .../functions/ImpulseSourceFunctionTest.java | 208 ++ .../io/UnboundedSourceWrapperTest.java | 1027 +++++++++ .../stableinput/BufferingDoFnRunnerTest.java | 179 ++ .../src/test/resources/flink-test-config.yaml | 27 + runners/flink/flink_runner.gradle | 53 +- .../flink_job_server_container.gradle | 10 +- .../flink/job-server/flink_job_server.gradle | 7 +- .../flink/FlinkBatchPipelineTranslator.java | 2 +- .../runners/flink/FlinkPipelineOptions.java | 4 +- .../runners/flink/FlinkPipelineRunner.java | 4 +- .../FlinkStreamingTransformTranslators.java | 2 +- .../metrics/FlinkMetricContainerBase.java | 4 +- .../wrappers/streaming/DoFnOperator.java | 0 .../streaming/io/UnboundedSourceWrapper.java | 2 +- .../streaming/MemoryStateBackendWrapper.java | 0 
.../flink/streaming/StreamSources.java | 0 settings.gradle.kts | 12 +- .../flink_java_pipeline_options.html | 19 +- .../flink_python_pipeline_options.html | 19 +- 70 files changed, 16704 insertions(+), 65 deletions(-) create mode 100644 runners/flink/2.0/build.gradle create mode 100644 runners/flink/2.0/job-server-container/build.gradle create mode 100644 runners/flink/2.0/job-server/build.gradle create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkMiniClusterEntryPoint.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineTranslator.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageContextFactory.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunction.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueSerializer.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueTypeInformation.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/UnversionedTypeSerializerSnapshot.java create mode 100644 
runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/StreamingImpulseSource.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/TestStreamSource.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java create mode 100644 runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/EncodedValueComparatorTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkExecutionEnvironmentsTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSavepointTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/ReadSourceTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapterTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/BoundedSourceRestoreTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunctionTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunctionTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunctionTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapperTest.java create mode 100644 runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunnerTest.java create mode 100644 runners/flink/2.0/src/test/resources/flink-test-config.yaml rename runners/flink/{1.17 => }/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java (100%) rename runners/flink/{1.17 => 
}/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java (100%) rename runners/flink/{1.17 => }/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java (100%) diff --git a/gradle.properties b/gradle.properties index f2ea56da8b19..404953c9c31f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.17,1.18,1.19,1.20 +flink_versions=1.17,1.18,1.19,1.20,2.0 # supported python versions python_versions=3.10,3.11,3.12,3.13 diff --git a/runners/flink/2.0/build.gradle b/runners/flink/2.0/build.gradle new file mode 100644 index 000000000000..490bc593f40c --- /dev/null +++ b/runners/flink/2.0/build.gradle @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +project.ext { + flink_major = '2.0' + flink_version = '2.0.1' + excluded_files = [ + 'main': [ + // Used by DataSet API only + "org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapter.java", + "org/apache/beam/runners/flink/FlinkBatchPipelineTranslator.java", + "org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java", + "org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java", + "org/apache/beam/runners/flink/translation/functions/FlinkNonMergingReduceFunction.java", + // Moved to org.apache.flink.runtime.state.StateBackendFactory + "org/apache/beam/runners/flink/FlinkStateBackendFactory.java", + ], + 'test': [ + // Used by DataSet API only + "org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapterTest.java", + "org/apache/beam/runners/flink/batch/NonMergingGroupByKeyTest.java", + "org/apache/beam/runners/flink/batch/ReshuffleTest.java", + ] + ] +} + +// Load the main build script which contains all build logic. +apply from: "../flink_runner.gradle" diff --git a/runners/flink/2.0/job-server-container/build.gradle b/runners/flink/2.0/job-server-container/build.gradle new file mode 100644 index 000000000000..afdb68a0fc91 --- /dev/null +++ b/runners/flink/2.0/job-server-container/build.gradle @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '../../job-server-container' + +project.ext { + resource_path = basePath +} + +// Load the main build script which contains all build logic. +apply from: "$basePath/flink_job_server_container.gradle" diff --git a/runners/flink/2.0/job-server/build.gradle b/runners/flink/2.0/job-server/build.gradle new file mode 100644 index 000000000000..6d068f839491 --- /dev/null +++ b/runners/flink/2.0/job-server/build.gradle @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '../../job-server' + +project.ext { + // Look for the source code in the parent module + main_source_dirs = ["$basePath/src/main/java"] + test_source_dirs = ["$basePath/src/test/java"] + main_resources_dirs = ["$basePath/src/main/resources"] + test_resources_dirs = ["$basePath/src/test/resources"] + archives_base_name = 'beam-runners-flink-2.0-job-server' +} + +// Load the main build script which contains all build logic. +apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java new file mode 100644 index 000000000000..0bfe06a38329 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import java.util.Map; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.utils.CountingPipelineVisitor; +import org.apache.beam.runners.flink.translation.utils.LookupPipelineVisitor; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.typeinfo.TypeInformation; + +/** + * Helper for {@link FlinkBatchPipelineTranslator} and translators in {@link + * FlinkBatchTransformTranslators}. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +class FlinkBatchTranslationContext { + private final PipelineOptions options; + + private AppliedPTransform currentTransform; + + private final CountingPipelineVisitor countingPipelineVisitor = new CountingPipelineVisitor(); + private final LookupPipelineVisitor lookupPipelineVisitor = new LookupPipelineVisitor(); + + // ------------------------------------------------------------------------ + + FlinkBatchTranslationContext(PipelineOptions options) { + this.options = options; + } + + void init(Pipeline pipeline) { + pipeline.traverseTopologically(countingPipelineVisitor); + pipeline.traverseTopologically(lookupPipelineVisitor); + } + + public PipelineOptions getPipelineOptions() { + return options; + } + + /** + * Sets the AppliedPTransform which carries input/output. + * + * @param currentTransform Current transformation. + */ + void setCurrentTransform(AppliedPTransform currentTransform) { + this.currentTransform = currentTransform; + } + + AppliedPTransform getCurrentTransform() { + return currentTransform; + } + + Map, Coder> getOutputCoders(PTransform transform) { + return lookupPipelineVisitor.getOutputCoders(transform); + } + + TypeInformation> getTypeInfo(PCollection collection) { + return getTypeInfo(collection.getCoder(), collection.getWindowingStrategy()); + } + + TypeInformation> getTypeInfo( + Coder coder, WindowingStrategy windowingStrategy) { + WindowedValues.FullWindowedValueCoder windowedValueCoder = + WindowedValues.getFullCoder(coder, windowingStrategy.getWindowFn().windowCoder()); + + return new CoderTypeInformation<>(windowedValueCoder, options); + } + + Map, PCollection> getInputs(PTransform transform) { + return lookupPipelineVisitor.getInputs(transform); + } + + T getInput(PTransform transform) { + return lookupPipelineVisitor.getInput(transform); + } + + Map, PCollection> getOutputs(PTransform transform) { + return lookupPipelineVisitor.getOutputs(transform); + } + + T getOutput(PTransform transform) { + return lookupPipelineVisitor.getOutput(transform); + } + + /** {@link CountingPipelineVisitor#getNumConsumers(PValue)}. 
*/ + int getNumConsumers(PValue value) { + return countingPipelineVisitor.getNumConsumers(value); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java new file mode 100644 index 000000000000..8b3b2ed9c960 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import static org.apache.flink.streaming.api.environment.StreamExecutionEnvironment.getDefaultLocalParallelism; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.configuration.CheckpointingOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.CoreOptions; +import org.apache.flink.configuration.DeploymentOptions; +import org.apache.flink.configuration.ExternalizedCheckpointRetention; +import org.apache.flink.configuration.GlobalConfiguration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.RestartStrategyOptions; +import org.apache.flink.configuration.StateBackendOptions; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.util.EnvironmentInformation; +import org.apache.flink.streaming.api.CheckpointingMode; +import org.apache.flink.streaming.api.environment.LocalStreamEnvironment; +import org.apache.flink.streaming.api.environment.RemoteStreamEnvironment; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import 
org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Utilities for Flink execution environments. */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkExecutionEnvironments { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkExecutionEnvironments.class); + + private static final ObjectMapper mapper = new ObjectMapper(); + + /** + * If the submitted job is a batch processing job, this method creates the adequate Flink {@link + * org.apache.flink.streaming.api.environment.StreamExecutionEnvironment} depending on the + * user-specified options. + */ + public static StreamExecutionEnvironment createBatchExecutionEnvironment( + FlinkPipelineOptions options) { + return createBatchExecutionEnvironment( + options, + MoreObjects.firstNonNull(options.getFilesToStage(), Collections.emptyList()), + options.getFlinkConfDir()); + } + + static StreamExecutionEnvironment createBatchExecutionEnvironment( + FlinkPipelineOptions options, List filesToStage, @Nullable String confDir) { + + LOG.info("Creating a Batch Execution Environment."); + + // Although Flink uses Rest, it expects the address not to contain a http scheme + String flinkMasterHostPort = stripHttpSchema(options.getFlinkMaster()); + Configuration flinkConfiguration = getFlinkConfiguration(confDir); + StreamExecutionEnvironment flinkBatchEnv; + + // depending on the master, create the right environment. + if ("[local]".equals(flinkMasterHostPort)) { + setManagedMemoryByFraction(flinkConfiguration); + disableClassLoaderLeakCheck(flinkConfiguration); + flinkBatchEnv = StreamExecutionEnvironment.createLocalEnvironment(flinkConfiguration); + if (!options.getAttachedMode()) { + LOG.warn("Detached mode is only supported in RemoteStreamEnvironment"); + } + } else if ("[collection]".equals(flinkMasterHostPort)) { + throw new UnsupportedOperationException( + "CollectionEnvironment has been removed in Flink 2. Use [local] instead."); + } else if ("[auto]".equals(flinkMasterHostPort)) { + flinkBatchEnv = StreamExecutionEnvironment.getExecutionEnvironment(); + if (flinkBatchEnv instanceof LocalStreamEnvironment) { + disableClassLoaderLeakCheck(flinkConfiguration); + flinkBatchEnv = StreamExecutionEnvironment.createLocalEnvironment(flinkConfiguration); + flinkBatchEnv.setParallelism(getDefaultLocalParallelism()); + } + if (!options.getAttachedMode()) { + LOG.warn("Detached mode is not supported in [auto]."); + } + } else { + int defaultPort = flinkConfiguration.get(RestOptions.PORT); + HostAndPort hostAndPort = + HostAndPort.fromString(flinkMasterHostPort).withDefaultPort(defaultPort); + flinkConfiguration.set(RestOptions.PORT, hostAndPort.getPort()); + if (!options.getAttachedMode()) { + flinkConfiguration.set(DeploymentOptions.ATTACHED, options.getAttachedMode()); + } + flinkBatchEnv = + StreamExecutionEnvironment.createRemoteEnvironment( + hostAndPort.getHost(), + hostAndPort.getPort(), + flinkConfiguration, + filesToStage.toArray(new String[filesToStage.size()])); + LOG.info("Using Flink Master URL {}:{}.", hostAndPort.getHost(), hostAndPort.getPort()); + } + + // Set the execution mode for data exchange. + flinkBatchEnv.setRuntimeMode(RuntimeExecutionMode.BATCH); + + // set the correct parallelism. + if (options.getParallelism() != -1) { + flinkBatchEnv.setParallelism(options.getParallelism()); + } + + // Set the correct parallelism, required by UnboundedSourceWrapper to generate consistent + // splits. 
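+ // Parallelism is resolved below in this order: an explicit --parallelism pipeline option, then the + // environment's own parallelism, then parallelism.default from the Flink configuration, then 1.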
+ final int parallelism = + determineParallelism( + options.getParallelism(), flinkBatchEnv.getParallelism(), flinkConfiguration); + + flinkBatchEnv.setParallelism(parallelism); + // set parallelism in the options (required by some execution code) + options.setParallelism(parallelism); + + if (options.getObjectReuse()) { + flinkBatchEnv.getConfig().enableObjectReuse(); + } else { + flinkBatchEnv.getConfig().disableObjectReuse(); + } + + applyLatencyTrackingInterval(flinkBatchEnv.getConfig(), options); + + configureWebUIOptions(flinkBatchEnv.getConfig(), options.as(PipelineOptions.class)); + + return flinkBatchEnv; + } + + @VisibleForTesting + static StreamExecutionEnvironment createStreamExecutionEnvironment(FlinkPipelineOptions options) { + return createStreamExecutionEnvironment( + options, + MoreObjects.firstNonNull(options.getFilesToStage(), Collections.emptyList()), + options.getFlinkConfDir()); + } + + /** + * If the submitted job is a stream processing job, this method creates the adequate Flink {@link + * org.apache.flink.streaming.api.environment.StreamExecutionEnvironment} depending on the + * user-specified options. + */ + public static StreamExecutionEnvironment createStreamExecutionEnvironment( + FlinkPipelineOptions options, List filesToStage, @Nullable String confDir) { + + LOG.info("Creating a Streaming Environment."); + + // Although Flink uses Rest, it expects the address not to contain a http scheme + String masterUrl = stripHttpSchema(options.getFlinkMaster()); + Configuration flinkConfiguration = getFlinkConfiguration(confDir); + configureRestartStrategy(options, flinkConfiguration); + configureStateBackend(options, flinkConfiguration); + StreamExecutionEnvironment flinkStreamEnv; + + // depending on the master, create the right environment. 
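+ // "[local]" starts an in-JVM LocalStreamEnvironment, "[auto]" uses whatever environment the context + // provides (falling back to a local one), and anything else (e.g. localhost:8081) is treated as the + // address of a remote JobManager.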
+ if ("[local]".equals(masterUrl)) { + setManagedMemoryByFraction(flinkConfiguration); + disableClassLoaderLeakCheck(flinkConfiguration); + flinkStreamEnv = + StreamExecutionEnvironment.createLocalEnvironment( + getDefaultLocalParallelism(), flinkConfiguration); + if (!options.getAttachedMode()) { + LOG.warn("Detached mode is only supported in RemoteStreamEnvironment"); + } + } else if ("[auto]".equals(masterUrl)) { + + flinkStreamEnv = StreamExecutionEnvironment.getExecutionEnvironment(flinkConfiguration); + if (flinkStreamEnv instanceof LocalStreamEnvironment) { + disableClassLoaderLeakCheck(flinkConfiguration); + flinkStreamEnv = + StreamExecutionEnvironment.createLocalEnvironment( + getDefaultLocalParallelism(), flinkConfiguration); + } + if (!options.getAttachedMode()) { + LOG.warn("Detached mode is not supported in [auto]."); + } + } else { + int defaultPort = flinkConfiguration.get(RestOptions.PORT); + HostAndPort hostAndPort = HostAndPort.fromString(masterUrl).withDefaultPort(defaultPort); + flinkConfiguration.set(RestOptions.PORT, hostAndPort.getPort()); + final SavepointRestoreSettings savepointRestoreSettings; + if (options.getSavepointPath() != null) { + savepointRestoreSettings = + SavepointRestoreSettings.forPath( + options.getSavepointPath(), options.getAllowNonRestoredState()); + } else { + savepointRestoreSettings = SavepointRestoreSettings.none(); + } + if (!options.getAttachedMode()) { + flinkConfiguration.set(DeploymentOptions.ATTACHED, options.getAttachedMode()); + } + flinkStreamEnv = + new RemoteStreamEnvironment( + hostAndPort.getHost(), + hostAndPort.getPort(), + flinkConfiguration, + filesToStage.toArray(new String[filesToStage.size()]), + null, + savepointRestoreSettings); + LOG.info("Using Flink Master URL {}:{}.", hostAndPort.getHost(), hostAndPort.getPort()); + } + + // Set the parallelism, required by UnboundedSourceWrapper to generate consistent splits. + final int parallelism = + determineParallelism( + options.getParallelism(), flinkStreamEnv.getParallelism(), flinkConfiguration); + flinkStreamEnv.setParallelism(parallelism); + if (options.getMaxParallelism() > 0) { + flinkStreamEnv.setMaxParallelism(options.getMaxParallelism()); + } else if (!options.isStreaming()) { + // In Flink maxParallelism defines the number of keyGroups. + // (see + // https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L76) + // The default value (parallelism * 1.5) + // (see + // https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L137-L147) + // creates a lot of skew so we force maxParallelism = parallelism in Batch mode.
+ LOG.info("Setting maxParallelism to {}", parallelism); + flinkStreamEnv.setMaxParallelism(parallelism); + } + // set parallelism in the options (required by some execution code) + options.setParallelism(parallelism); + + if (options.getObjectReuse()) { + flinkStreamEnv.getConfig().enableObjectReuse(); + } else { + flinkStreamEnv.getConfig().disableObjectReuse(); + } + + if (!options.getOperatorChaining()) { + flinkStreamEnv.disableOperatorChaining(); + } + + configureCheckpointing(options, flinkStreamEnv); + + applyLatencyTrackingInterval(flinkStreamEnv.getConfig(), options); + + if (options.getAutoWatermarkInterval() != null) { + flinkStreamEnv.getConfig().setAutoWatermarkInterval(options.getAutoWatermarkInterval()); + } + configureWebUIOptions(flinkStreamEnv.getConfig(), options.as(PipelineOptions.class)); + + return flinkStreamEnv; + } + + private static void configureWebUIOptions( + ExecutionConfig config, org.apache.beam.sdk.options.PipelineOptions options) { + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(options); + String optionsAsString = serializablePipelineOptions.toString(); + + try { + JsonNode node = mapper.readTree(optionsAsString); + JsonNode optionsNode = node.get("options"); + Map output = + Streams.stream(optionsNode.fields()) + .filter(entry -> !entry.getValue().isNull()) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue().asText())); + + config.setGlobalJobParameters(new GlobalJobParametersImpl(output)); + } catch (Exception e) { + LOG.warn("Unable to configure web ui options", e); + } + } + + private static class GlobalJobParametersImpl extends ExecutionConfig.GlobalJobParameters { + private final Map jobOptions; + + private GlobalJobParametersImpl(Map jobOptions) { + this.jobOptions = jobOptions; + } + + @Override + public Map toMap() { + return jobOptions; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || this.getClass() != obj.getClass()) { + return false; + } + + ExecutionConfig.GlobalJobParameters jobParams = (ExecutionConfig.GlobalJobParameters) obj; + return Maps.difference(jobParams.toMap(), this.jobOptions).areEqual(); + } + + @Override + public int hashCode() { + return Objects.hashCode(jobOptions); + } + } + + private static void configureCheckpointing( + FlinkPipelineOptions options, StreamExecutionEnvironment flinkStreamEnv) { + // A value of -1 corresponds to disabled checkpointing (see CheckpointConfig in Flink). + // If the value is not -1, then the validity checks are applied. + // By default, checkpointing is disabled. + long checkpointInterval = options.getCheckpointingInterval(); + if (checkpointInterval != -1) { + if (checkpointInterval < 1) { + throw new IllegalArgumentException("The checkpoint interval must be positive"); + } + flinkStreamEnv.enableCheckpointing( + checkpointInterval, CheckpointingMode.valueOf(options.getCheckpointingMode())); + + if (options.getShutdownSourcesAfterIdleMs() == -1) { + // If not explicitly configured, we never shutdown sources when checkpointing is enabled. 
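+ // Long.MAX_VALUE effectively disables the idle-source shutdown, since shutting a source down would + // make further checkpoints impossible (see the shutdownSourcesAfterIdleMs option description).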
+ options.setShutdownSourcesAfterIdleMs(Long.MAX_VALUE); + } + + if (options.getCheckpointTimeoutMillis() != -1) { + flinkStreamEnv + .getCheckpointConfig() + .setCheckpointTimeout(options.getCheckpointTimeoutMillis()); + } + + boolean externalizedCheckpoint = options.isExternalizedCheckpointsEnabled(); + boolean retainOnCancellation = options.getRetainExternalizedCheckpointsOnCancellation(); + if (externalizedCheckpoint) { + flinkStreamEnv + .getCheckpointConfig() + .setExternalizedCheckpointRetention( + retainOnCancellation + ? ExternalizedCheckpointRetention.RETAIN_ON_CANCELLATION + : ExternalizedCheckpointRetention.DELETE_ON_CANCELLATION); + } + + if (options.getUnalignedCheckpointEnabled()) { + flinkStreamEnv.getCheckpointConfig().enableUnalignedCheckpoints(); + } + flinkStreamEnv + .getCheckpointConfig() + .setForceUnalignedCheckpoints(options.getForceUnalignedCheckpointEnabled()); + + long minPauseBetweenCheckpoints = options.getMinPauseBetweenCheckpoints(); + if (minPauseBetweenCheckpoints != -1) { + flinkStreamEnv + .getCheckpointConfig() + .setMinPauseBetweenCheckpoints(minPauseBetweenCheckpoints); + } + if (options.getTolerableCheckpointFailureNumber() != null + && options.getTolerableCheckpointFailureNumber() > 0) { + flinkStreamEnv + .getCheckpointConfig() + .setTolerableCheckpointFailureNumber(options.getTolerableCheckpointFailureNumber()); + } + + flinkStreamEnv + .getCheckpointConfig() + .setMaxConcurrentCheckpoints(options.getNumConcurrentCheckpoints()); + } else { + if (options.getShutdownSourcesAfterIdleMs() == -1) { + // If not explicitly configured, shut down idle sources immediately when checkpointing is disabled. + options.setShutdownSourcesAfterIdleMs(0L); + } + } + } + + private static void configureStateBackend(FlinkPipelineOptions options, Configuration config) { + final StateBackend stateBackend; + if (options.getStateBackend() != null) { + final String storagePath = options.getStateBackendStoragePath(); + Preconditions.checkArgument( + storagePath != null, + "State backend was set to '%s' but no storage path was provided.", + options.getStateBackend()); + + if (options.getStateBackend().equalsIgnoreCase("rocksdb")) { + config.set(StateBackendOptions.STATE_BACKEND, "rocksdb"); + } else if (options.getStateBackend().equalsIgnoreCase("filesystem") + || options.getStateBackend().equalsIgnoreCase("hashmap")) { + config.set(StateBackendOptions.STATE_BACKEND, "hashmap"); + } else { + throw new IllegalArgumentException( + String.format( + "Unknown state backend '%s'. Use 'rocksdb' or 'filesystem' or configure via Flink config file.", + options.getStateBackend())); + } + config.set(CheckpointingOptions.CHECKPOINTS_DIRECTORY, storagePath); + } else if (options.getStateBackendFactory() != null) { + // Legacy way of setting the state backend + config.set(StateBackendOptions.STATE_BACKEND, options.getStateBackendFactory().getName()); + } + } + + private static void configureRestartStrategy(FlinkPipelineOptions options, Configuration config) { + // for the following 2 parameters, a value of -1 means that Flink will use + // the default values as specified in the configuration.
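+ // As a sketch of the mapping (flag names follow the option getters; exact config keys per + // RestartStrategyOptions): --numberOfExecutionRetries=3 --executionRetryDelay=10000 roughly yields + // restart-strategy.type=fixed-delay, restart-strategy.fixed-delay.attempts=3 and a 10 s fixed delay.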
+ int numRetries = options.getNumberOfExecutionRetries(); + if (numRetries != -1) { + // setNumberOfExecutionRetries + config.set(RestartStrategyOptions.RESTART_STRATEGY, "fixed-delay"); + config.set(RestartStrategyOptions.RESTART_STRATEGY_FIXED_DELAY_ATTEMPTS, numRetries); + } + long retryDelay = options.getExecutionRetryDelay(); + if (retryDelay != -1) { + config.set( + RestartStrategyOptions.RESTART_STRATEGY_FIXED_DELAY_DELAY, + java.time.Duration.ofMillis(retryDelay)); + } + } + + /** + * Removes the http:// or https:// schema from a url string. This is commonly used with the + * flink_master address which is expected to be of form host:port but users may specify a URL; + * Python code also assumes a URL which may be passed here. + */ + private static String stripHttpSchema(String url) { + return url.trim().replaceFirst("^http[s]?://", ""); + } + + private static int determineParallelism( + final int pipelineOptionsParallelism, + final int envParallelism, + final Configuration configuration) { + if (pipelineOptionsParallelism > 0) { + return pipelineOptionsParallelism; + } + if (envParallelism > 0) { + // If the user supplies a parallelism on the command-line, this is set on the execution + // environment during creation + return envParallelism; + } + + final int flinkConfigParallelism = + configuration.getOptional(CoreOptions.DEFAULT_PARALLELISM).orElse(-1); + if (flinkConfigParallelism > 0) { + return flinkConfigParallelism; + } + LOG.warn( + "No default parallelism could be found. Defaulting to parallelism 1. " + + "Please set an explicit parallelism with --parallelism"); + return 1; + } + + private static Configuration getFlinkConfiguration(@Nullable String flinkConfDir) { + return flinkConfDir == null || flinkConfDir.isEmpty() + ? GlobalConfiguration.loadConfiguration() + : GlobalConfiguration.loadConfiguration(flinkConfDir); + } + + private static void applyLatencyTrackingInterval( + ExecutionConfig config, FlinkPipelineOptions options) { + long latencyTrackingInterval = options.getLatencyTrackingInterval(); + config.setLatencyTrackingInterval(latencyTrackingInterval); + } + + private static void setManagedMemoryByFraction(final Configuration config) { + if (!config.containsKey("taskmanager.memory.managed.size")) { + float managedMemoryFraction = config.get(TaskManagerOptions.MANAGED_MEMORY_FRACTION); + long freeHeapMemory = EnvironmentInformation.getSizeOfFreeHeapMemoryWithDefrag(); + long managedMemorySize = (long) (freeHeapMemory * managedMemoryFraction); + config.setString("taskmanager.memory.managed.size", String.valueOf(managedMemorySize)); + } + } + + /** + * Disables classloader.check-leaked-classloader unless set by the user. See + * https://github.com/apache/beam/issues/20783. + */ + private static void disableClassLoaderLeakCheck(final Configuration config) { + if (!config.containsKey(CoreOptions.CHECK_LEAKED_CLASSLOADER.key())) { + config.set(CoreOptions.CHECK_LEAKED_CLASSLOADER, false); + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkMiniClusterEntryPoint.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkMiniClusterEntryPoint.java new file mode 100644 index 000000000000..ead10741be5b --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkMiniClusterEntryPoint.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Entry point for starting an embedded Flink cluster. */ +public class FlinkMiniClusterEntryPoint { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkMiniClusterEntryPoint.class); + + static class MiniClusterArgs { + @Option(name = "--rest-port") + int restPort = 0; + + @Option(name = "--rest-bind-address") + String restBindAddress = ""; + + @Option(name = "--num-task-managers") + int numTaskManagers = 1; + + @Option(name = "--num-task-slots-per-taskmanager") + int numSlotsPerTaskManager = 1; + } + + public static void main(String[] args) throws Exception { + MiniClusterArgs miniClusterArgs = parseArgs(args); + + Configuration flinkConfig = new Configuration(); + flinkConfig.set(RestOptions.PORT, miniClusterArgs.restPort); + if (!miniClusterArgs.restBindAddress.isEmpty()) { + flinkConfig.set(RestOptions.BIND_ADDRESS, miniClusterArgs.restBindAddress); + } + + MiniClusterConfiguration clusterConfig = + new MiniClusterConfiguration.Builder() + .setConfiguration(flinkConfig) + .setNumTaskManagers(miniClusterArgs.numTaskManagers) + .setNumSlotsPerTaskManager(miniClusterArgs.numSlotsPerTaskManager) + .build(); + + try (MiniCluster miniCluster = new MiniCluster(clusterConfig)) { + miniCluster.start(); + System.out.println( + String.format( + "Started Flink mini cluster (%s TaskManagers with %s task slots) with Rest API at %s", + miniClusterArgs.numTaskManagers, + miniClusterArgs.numSlotsPerTaskManager, + miniCluster.getRestAddress())); + Thread.sleep(Long.MAX_VALUE); + } + } + + private static MiniClusterArgs parseArgs(String[] args) { + MiniClusterArgs configuration = new MiniClusterArgs(); + CmdLineParser parser = new CmdLineParser(configuration); + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + LOG.error("Unable to parse command line arguments.", e); + printUsage(parser); + throw new IllegalArgumentException("Unable to parse command line arguments.", e); + } + return configuration; + } + + private static void printUsage(CmdLineParser parser) { + System.err.println( + String.format( + "Usage: java %s arguments...", FlinkMiniClusterEntryPoint.class.getSimpleName())); + parser.printUsage(System.err); + System.err.println(); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java 
b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java new file mode 100644 index 000000000000..758ded42aff5 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.runners.core.metrics.MetricsPusher; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.metrics.MetricsOptions; +import org.apache.beam.sdk.util.construction.resources.PipelineResources; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.core.execution.JobClient; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.streaming.api.environment.LocalStreamEnvironment; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The class that instantiates and manages the execution of a given job. Depending on if the job is + * a Streaming or Batch processing one, it creates a {@link StreamExecutionEnvironment}), the + * necessary {@link FlinkPipelineTranslator} or {@link FlinkStreamingPipelineTranslator}) to + * transform the Beam job into a Flink one, and executes the (translated) job. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +class FlinkPipelineExecutionEnvironment { + + private static final Logger LOG = + LoggerFactory.getLogger(FlinkPipelineExecutionEnvironment.class); + + private static final Set protectedThreadGroups = ConcurrentHashMap.newKeySet(); + + private final FlinkPipelineOptions options; + + /** + * The Flink DataStream execution environment. This is instantiated to either a {@link + * org.apache.flink.streaming.api.environment.LocalStreamEnvironment} or a {@link + * org.apache.flink.streaming.api.environment.RemoteStreamEnvironment}, depending on the + * configuration options, and more specifically, the url of the master. + */ + private StreamExecutionEnvironment flinkStreamEnv; + + /** + * Creates a {@link FlinkPipelineExecutionEnvironment} with the user-specified parameters in the + * provided {@link FlinkPipelineOptions}. + * + * @param options the user-defined pipeline options. 
+ */ + FlinkPipelineExecutionEnvironment(FlinkPipelineOptions options) { + this.options = Preconditions.checkNotNull(options); + } + + /** + * Depending on if the job is a Streaming or a Batch one, this method creates the necessary + * execution environment and pipeline translator, and translates the {@link + * org.apache.beam.sdk.values.PCollection} program into a + * org.apache.flink.streaming.api.datastream.DataStream}. + */ + public void translate(Pipeline pipeline) { + this.flinkStreamEnv = null; + + final boolean hasUnboundedOutput = + PipelineTranslationModeOptimizer.hasUnboundedOutput(pipeline); + if (hasUnboundedOutput) { + LOG.info("Found unbounded PCollection. Switching to streaming execution."); + options.setStreaming(true); + } + + // Staged files need to be set before initializing the execution environments + prepareFilesToStageForRemoteClusterExecution(options); + + FlinkPipelineTranslator translator; + this.flinkStreamEnv = FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + if (hasUnboundedOutput && !flinkStreamEnv.getCheckpointConfig().isCheckpointingEnabled()) { + LOG.warn( + "UnboundedSources present which rely on checkpointing, but checkpointing is disabled."); + } + translator = + new FlinkStreamingPipelineTranslator(flinkStreamEnv, options, options.isStreaming()); + if (!options.isStreaming()) { + flinkStreamEnv.setRuntimeMode(RuntimeExecutionMode.BATCH); + } + + // Transform replacements need to receive the finalized PipelineOptions + // including execution mode (batch/streaming) and parallelism. + pipeline.replaceAll(FlinkTransformOverrides.getDefaultOverrides(options)); + + translator.translate(pipeline); + } + + /** + * Local configurations work in the same JVM and have no problems with improperly formatted files + * on classpath (eg. directories with .class files or empty directories). Prepare files for + * staging only when using remote cluster (passing the master address explicitly). + */ + private static void prepareFilesToStageForRemoteClusterExecution(FlinkPipelineOptions options) { + if (!options.getFlinkMaster().matches("\\[auto\\]|\\[collection\\]|\\[local\\]")) { + PipelineResources.prepareFilesForStaging(options); + } + } + + /** Launches the program execution. */ + public PipelineResult executePipeline() throws Exception { + final String jobName = options.getJobName(); + Preconditions.checkNotNull(flinkStreamEnv, "The Pipeline has not yet been translated."); + if (options.getAttachedMode()) { + JobExecutionResult jobExecutionResult = flinkStreamEnv.execute(jobName); + ensureFlinkCleanupComplete(flinkStreamEnv); + return createAttachedPipelineResult(jobExecutionResult); + } else { + JobClient jobClient = flinkStreamEnv.executeAsync(jobName); + return createDetachedPipelineResult(jobClient, options); + } + } + + /** Prevents ThreadGroup destruction while Flink cleanup threads are still running. 
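+ * Only relevant for Java 8 local executions: a short-lived daemon thread keeps a reference to the + * current ThreadGroup for roughly two seconds so Flink's cleanup threads can finish first.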
*/ + private void ensureFlinkCleanupComplete(Object executionEnv) { + String javaVersion = System.getProperty("java.version"); + if (javaVersion == null || !javaVersion.startsWith("1.8")) { + return; + } + + if (!(executionEnv instanceof LocalStreamEnvironment)) { + return; + } + + ThreadGroup currentThreadGroup = Thread.currentThread().getThreadGroup(); + if (currentThreadGroup == null) { + return; + } + + protectedThreadGroups.add(currentThreadGroup); + + Thread cleanupReleaser = + new Thread( + () -> { + try { + Thread.sleep(2000); // 2 seconds should be enough for Flink cleanup + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + protectedThreadGroups.remove(currentThreadGroup); + } + }, + "FlinkCleanupReleaser"); + cleanupReleaser.setDaemon(true); + cleanupReleaser.start(); + } + + private FlinkDetachedRunnerResult createDetachedPipelineResult( + JobClient jobClient, FlinkPipelineOptions options) { + LOG.info("Pipeline submitted in detached mode"); + return new FlinkDetachedRunnerResult(jobClient, options.getJobCheckIntervalInSecs()); + } + + private FlinkRunnerResult createAttachedPipelineResult(JobExecutionResult result) { + LOG.info("Execution finished in {} msecs", result.getNetRuntime()); + Map accumulators = result.getAllAccumulatorResults(); + if (accumulators != null && !accumulators.isEmpty()) { + LOG.info("Final accumulator values:"); + for (Map.Entry entry : result.getAllAccumulatorResults().entrySet()) { + LOG.info("{} : {}", entry.getKey(), entry.getValue()); + } + } + FlinkRunnerResult flinkRunnerResult = + new FlinkRunnerResult(accumulators, result.getNetRuntime()); + MetricsPusher metricsPusher = + new MetricsPusher( + flinkRunnerResult.getMetricsContainerStepMap(), + options.as(MetricsOptions.class), + flinkRunnerResult); + metricsPusher.start(); + return flinkRunnerResult; + } + + /** + * Retrieves the generated JobGraph which can be submitted against the cluster. For testing + * purposes. + */ + @VisibleForTesting + JobGraph getJobGraph(Pipeline p) { + translate(p); + StreamGraph streamGraph = flinkStreamEnv.getStreamGraph(); + // Normally the job name is set when we execute the job, and JobGraph is immutable, so we need + // to set the job name here. + streamGraph.setJobName(p.getOptions().getJobName()); + return streamGraph.getJobGraph(); + } + + @VisibleForTesting + StreamExecutionEnvironment getStreamExecutionEnvironment() { + return flinkStreamEnv; + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java new file mode 100644 index 000000000000..6a4bd77611fe --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.beam.sdk.options.ApplicationNameOptions; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.FileStagingOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.StreamingOptions; +import org.apache.flink.runtime.state.StateBackendFactory; + +/** + * Options which can be used to configure the Flink Runner. + * + *
<p>
Avoid using `org.apache.flink.*` members below. This allows including the flink runner without + * requiring flink on the classpath (e.g. to use with the direct runner). + */ +public interface FlinkPipelineOptions + extends PipelineOptions, + ApplicationNameOptions, + StreamingOptions, + FileStagingOptions, + VersionDependentFlinkPipelineOptions { + + String AUTO = "[auto]"; + String PIPELINED = "PIPELINED"; + String EXACTLY_ONCE = "EXACTLY_ONCE"; + + /** + * The url of the Flink JobManager on which to execute pipelines. This can either be the address + * of a cluster JobManager, in the form "host:port" or one of the special Strings "[local]", or + * "[auto]". "[local]" will start a local Flink Cluster in the JVM, while "[auto]" will let the + * system decide where to execute the pipeline based on the environment. + */ + @Description( + "Address of the Flink Master where the Pipeline should be executed. Can" + + " either be of the form \"host:port\" or one of the special values [local], " + + "[collection] or [auto].") + @Default.String(AUTO) + String getFlinkMaster(); + + void setFlinkMaster(String value); + + @Description( + "The degree of parallelism to be used when distributing operations onto workers. " + + "If the parallelism is not set, the configured Flink default is used, or 1 if none can be found.") + @Default.Integer(-1) + Integer getParallelism(); + + void setParallelism(Integer value); + + @Description( + "The pipeline wide maximum degree of parallelism to be used. The maximum parallelism specifies the upper limit " + + "for dynamic scaling and the number of key groups used for partitioned state.") + @Default.Integer(-1) + Integer getMaxParallelism(); + + void setMaxParallelism(Integer value); + + @Description( + "The interval in milliseconds at which to trigger checkpoints of the running pipeline. " + + "Default: No checkpointing.") + @Default.Long(-1L) + Long getCheckpointingInterval(); + + void setCheckpointingInterval(Long interval); + + @Description("The checkpointing mode that defines consistency guarantee.") + @Default.String(EXACTLY_ONCE) + String getCheckpointingMode(); + + void setCheckpointingMode(String mode); + + @Description( + "The maximum time in milliseconds that a checkpoint may take before being discarded.") + @Default.Long(-1L) + Long getCheckpointTimeoutMillis(); + + void setCheckpointTimeoutMillis(Long checkpointTimeoutMillis); + + @Description("The minimal pause in milliseconds before the next checkpoint is triggered.") + @Default.Long(-1L) + Long getMinPauseBetweenCheckpoints(); + + void setMinPauseBetweenCheckpoints(Long minPauseInterval); + + @Description( + "The maximum number of concurrent checkpoints. Defaults to 1 (=no concurrent checkpoints).") + @Default.Integer(1) + int getNumConcurrentCheckpoints(); + + void setNumConcurrentCheckpoints(int maxConcurrentCheckpoints); + + @Description( + "Sets the expected behaviour for tasks in case that they encounter an error in their " + + "checkpointing procedure. To tolerate a specific number of failures, set it to a positive number.") + @Default.Integer(0) + Integer getTolerableCheckpointFailureNumber(); + + void setTolerableCheckpointFailureNumber(Integer tolerableCheckpointFailureNumber); + + @Description( + "If set, finishes the current bundle and flushes all output before checkpointing the state of the operators. " + + "By default, starts checkpointing immediately and buffers any remaining bundle output as part of the checkpoint. 
" + + "The setting may affect the checkpoint alignment.") + @Default.Boolean(false) + boolean getFinishBundleBeforeCheckpointing(); + + void setFinishBundleBeforeCheckpointing(boolean finishBundleBeforeCheckpointing); + + @Description( + "If set, Unaligned checkpoints contain in-flight data (i.e., data stored in buffers) as part of the " + + "checkpoint state, allowing checkpoint barriers to overtake these buffers. Thus, the checkpoint duration " + + "becomes independent of the current throughput as checkpoint barriers are effectively not embedded into the " + + "stream of data anymore") + @Default.Boolean(false) + boolean getUnalignedCheckpointEnabled(); + + void setUnalignedCheckpointEnabled(boolean unalignedCheckpointEnabled); + + @Description("Forces unaligned checkpoints, particularly allowing them for iterative jobs.") + @Default.Boolean(false) + boolean getForceUnalignedCheckpointEnabled(); + + void setForceUnalignedCheckpointEnabled(boolean forceUnalignedCheckpointEnabled); + + @Description( + "Shuts down sources which have been idle for the configured time of milliseconds. Once a source has been " + + "shut down, checkpointing is not possible anymore. Shutting down the sources eventually leads to pipeline " + + "shutdown (=Flink job finishes) once all input has been processed. Unless explicitly set, this will " + + "default to Long.MAX_VALUE when checkpointing is enabled and to 0 when checkpointing is disabled. " + + "See https://issues.apache.org/jira/browse/FLINK-2491 for progress on this issue.") + @Default.Long(-1L) + Long getShutdownSourcesAfterIdleMs(); + + void setShutdownSourcesAfterIdleMs(Long timeoutMs); + + @Description( + "Sets the number of times that failed tasks are re-executed. " + + "A value of zero effectively disables fault tolerance. A value of -1 indicates " + + "that the system default value (as defined in the configuration) should be used.") + @Default.Integer(-1) + Integer getNumberOfExecutionRetries(); + + void setNumberOfExecutionRetries(Integer retries); + + @Description( + "Set job check interval in seconds under detached mode in method waitUntilFinish, " + + "by default it is 5 seconds") + @Default.Integer(5) + int getJobCheckIntervalInSecs(); + + void setJobCheckIntervalInSecs(int seconds); + + @Description("Specifies if the pipeline is submitted in attached or detached mode") + @Default.Boolean(true) + boolean getAttachedMode(); + + void setAttachedMode(boolean attachedMode); + + @Description( + "Sets the delay in milliseconds between executions. A value of {@code -1} " + + "indicates that the default value should be used.") + @Default.Long(-1L) + Long getExecutionRetryDelay(); + + void setExecutionRetryDelay(Long delay); + + @Description("Sets the behavior of reusing objects.") + @Default.Boolean(false) + Boolean getObjectReuse(); + + void setObjectReuse(Boolean reuse); + + @Description("Sets the behavior of operator chaining.") + @Default.Boolean(true) + Boolean getOperatorChaining(); + + void setOperatorChaining(Boolean chaining); + + /** State backend to store Beam's state during computation. */ + @Description( + "Sets the state backend factory to use in streaming mode. " + + "Defaults to the flink cluster's state.backend configuration.") + Class> getStateBackendFactory(); + + void setStateBackendFactory(Class> stateBackendFactory); + + void setStateBackend(String stateBackend); + + @Description( + "State backend to store Beam's state. 
Use 'rocksdb' or 'hashmap' (same as 'filesystem').") + String getStateBackend(); + + void setStateBackendStoragePath(String path); + + @Description( + "State backend path to persist state backend data. Used to initialize state backend.") + String getStateBackendStoragePath(); + + @Description("Disable Beam metrics in Flink Runner") + @Default.Boolean(false) + Boolean getDisableMetrics(); + + void setDisableMetrics(Boolean disableMetrics); + + /** Enables or disables externalized checkpoints. */ + @Description( + "Enables or disables externalized checkpoints. " + + "Works in conjunction with CheckpointingInterval") + @Default.Boolean(false) + Boolean isExternalizedCheckpointsEnabled(); + + void setExternalizedCheckpointsEnabled(Boolean externalCheckpoints); + + @Description("Sets the behavior of externalized checkpoints on cancellation.") + @Default.Boolean(false) + Boolean getRetainExternalizedCheckpointsOnCancellation(); + + void setRetainExternalizedCheckpointsOnCancellation(Boolean retainOnCancellation); + + @Description( + "The maximum number of elements in a bundle. Default values are 1000 for a streaming job and 1,000,000 for batch") + @Default.InstanceFactory(MaxBundleSizeFactory.class) + Long getMaxBundleSize(); + + void setMaxBundleSize(Long size); + + /** + * Maximum bundle size factory. For a streaming job it's desireable to keep bundle size small to + * optimize latency. In batch, we optimize for throughput and hence bundle size is kept large. + */ + class MaxBundleSizeFactory implements DefaultValueFactory { + @Override + public Long create(PipelineOptions options) { + if (options.as(StreamingOptions.class).isStreaming()) { + return 1000L; + } else { + return 5000L; + } + } + } + + @Description( + "The maximum time to wait before finalising a bundle (in milliseconds). Default values are 1000 for streaming and 10,000 for batch.") + @Default.InstanceFactory(MaxBundleTimeFactory.class) + Long getMaxBundleTimeMills(); + + void setMaxBundleTimeMills(Long time); + + /** + * Maximum bundle time factory. For a streaming job it's desireable to keep the value small to + * optimize latency. In batch, we optimize for throughput and hence bundle time size is kept + * larger. + */ + class MaxBundleTimeFactory implements DefaultValueFactory { + @Override + public Long create(PipelineOptions options) { + if (options.as(StreamingOptions.class).isStreaming()) { + return 1000L; + } else { + return 10000L; + } + } + } + + @Description( + "Interval in milliseconds for sending latency tracking marks from the sources to the sinks. " + + "Interval value <= 0 disables the feature.") + @Default.Long(0) + Long getLatencyTrackingInterval(); + + void setLatencyTrackingInterval(Long interval); + + @Description("The interval in milliseconds for automatic watermark emission.") + Long getAutoWatermarkInterval(); + + void setAutoWatermarkInterval(Long interval); + + /** ExecutionMode is only effective for DataSet API and has been removed in Flink 2.0. */ + @Deprecated() + @Description( + "Flink mode for data exchange of batch pipelines. " + + "Reference {@link org.apache.flink.api.common.ExecutionMode}. " + + "Set this to BATCH_FORCED if pipelines get blocked, see " + + "https://issues.apache.org/jira/browse/FLINK-10672.") + @Default.String(PIPELINED) + String getExecutionModeForBatch(); + + void setExecutionModeForBatch(String executionMode); + + @Description( + "Savepoint restore path. 
If specified, restores the streaming pipeline from the provided path.") + String getSavepointPath(); + + void setSavepointPath(String path); + + @Description( + "Flag indicating whether non-restored state is allowed if the savepoint " + + "contains state for an operator that is no longer part of the pipeline.") + @Default.Boolean(false) + Boolean getAllowNonRestoredState(); + + void setAllowNonRestoredState(Boolean allowNonRestoredState); + + @Description( + "Flag indicating whether auto-balance sharding for the WriteFiles transform should be enabled. " + + "This might prove useful in streaming use cases where the pipeline needs to write many events " + + "into files, typically divided into N shards. By default, Flink may assign more shards to some " + + "workers than to others, which causes those workers to fall out of balance in terms of processing " + + "backlog and memory usage. Enabling this feature spreads shards evenly among the available workers, " + + "improving throughput and memory usage stability.") + @Default.Boolean(false) + Boolean isAutoBalanceWriteFilesShardingEnabled(); + + void setAutoBalanceWriteFilesShardingEnabled(Boolean autoBalanceWriteFilesShardingEnabled); + + @Description( + "If not null, reports the checkpoint duration of each ParDo stage in the provided metric namespace.") + String getReportCheckpointDuration(); + + void setReportCheckpointDuration(String metricNamespace); + + @Description( + "Remove unneeded deep copy between operators. See https://issues.apache.org/jira/browse/BEAM-11146") + @Default.Boolean(false) + Boolean getFasterCopy(); + + void setFasterCopy(Boolean fasterCopy); + + @Description( + "Directory containing Flink YAML configuration files. " + + "These properties will be set to all jobs submitted to Flink and take precedence " + + "over configurations in FLINK_CONF_DIR.") + String getFlinkConfDir(); + + void setFlinkConfDir(String confDir); + + @Description( + "Set the maximum size (in MB) of an input split when data is read from a filesystem. 0 implies no max size.") + @Default.Long(0) + Long getFileInputSplitMaxSizeMB(); + + void setFileInputSplitMaxSizeMB(Long fileInputSplitMaxSizeMB); + + @Description( + "Allow the drain operation for Flink pipelines that contain a RequiresStableInput operator. Note that at the time of draining, " + + "the RequiresStableInput contract might be violated if there are any processing-related failures in the DoFn operator.") + @Default.Boolean(false) + Boolean getEnableStableInputDrain(); + + void setEnableStableInputDrain(Boolean enableStableInputDrain); + + @Description( + "Set a slot sharing group for all bounded sources. This is required when using the DataStream API to get the same scheduling behaviour as the DataSet API.") + @Default.Boolean(true) + Boolean getForceSlotSharingGroup(); + + void setForceSlotSharingGroup(Boolean forceSlotSharingGroup); + + static FlinkPipelineOptions defaults() { + return PipelineOptionsFactory.as(FlinkPipelineOptions.class); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java new file mode 100644 index 000000000000..460b8f3604c4 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import static org.apache.beam.sdk.util.construction.resources.PipelineResources.detectClassPathResourcesToStage; + +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline; +import org.apache.beam.runners.core.metrics.MetricsPusher; +import org.apache.beam.runners.fnexecution.provisioning.JobInfo; +import org.apache.beam.runners.jobsubmission.PortablePipelineJarUtils; +import org.apache.beam.runners.jobsubmission.PortablePipelineResult; +import org.apache.beam.runners.jobsubmission.PortablePipelineRunner; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.metrics.MetricsOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.SdkHarnessOptions; +import org.apache.beam.sdk.util.construction.PipelineOptionsTranslation; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.Struct; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.flink.api.common.JobExecutionResult; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Runs a Pipeline on Flink via {@link FlinkRunner}. */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkPipelineRunner implements PortablePipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(FlinkPipelineRunner.class); + + private final FlinkPipelineOptions pipelineOptions; + private final String confDir; + private final List filesToStage; + + /** + * Setup a flink pipeline runner. + * + * @param pipelineOptions pipeline options configuring the flink pipeline runner. + * @param confDir flink configuration directory. Note that pipeline option's flinkConfDir, If not + * null, takes precedence against this parameter. + * @param filesToStage a list of file names to stage. + */ + public FlinkPipelineRunner( + FlinkPipelineOptions pipelineOptions, @Nullable String confDir, List filesToStage) { + this.pipelineOptions = pipelineOptions; + // pipelineOptions.getFlinkConfDir takes precedence than confDir + this.confDir = + pipelineOptions.getFlinkConfDir() != null ? 
pipelineOptions.getFlinkConfDir() : confDir; + this.filesToStage = filesToStage; + } + + @Override + public PortablePipelineResult run(final Pipeline pipeline, JobInfo jobInfo) throws Exception { + MetricsEnvironment.setMetricsSupported(false); + + // Apply log levels settings at the beginning of pipeline run + SdkHarnessOptions.getConfiguredLoggerFromOptions(pipelineOptions.as(SdkHarnessOptions.class)); + + FlinkPortablePipelineTranslator translator = new FlinkStreamingPortablePipelineTranslator(); + return runPipelineWithTranslator(pipeline, jobInfo, translator); + } + + private + PortablePipelineResult runPipelineWithTranslator( + final Pipeline pipeline, JobInfo jobInfo, FlinkPortablePipelineTranslator translator) + throws Exception { + LOG.info("Translating pipeline to Flink program."); + + FlinkPortablePipelineTranslator.Executor executor = + translator.translate( + translator.createTranslationContext(jobInfo, pipelineOptions, confDir, filesToStage), + translator.prepareForTranslation(pipeline)); + final JobExecutionResult result = executor.execute(pipelineOptions.getJobName()); + + return createPortablePipelineResult(result, pipelineOptions); + } + + private PortablePipelineResult createPortablePipelineResult( + JobExecutionResult result, PipelineOptions options) { + String resultClassName = result.getClass().getCanonicalName(); + if (resultClassName.equals("org.apache.flink.core.execution.DetachedJobExecutionResult")) { + LOG.info("Pipeline submitted in Detached mode"); + // no metricsPusher because metrics are not supported in detached mode + return new FlinkPortableRunnerResult.Detached(); + } else { + LOG.info("Execution finished in {} msecs", result.getNetRuntime()); + Map accumulators = result.getAllAccumulatorResults(); + if (accumulators != null && !accumulators.isEmpty()) { + LOG.info("Final accumulator values:"); + for (Map.Entry entry : result.getAllAccumulatorResults().entrySet()) { + LOG.info("{} : {}", entry.getKey(), entry.getValue()); + } + } + FlinkPortableRunnerResult flinkRunnerResult = + new FlinkPortableRunnerResult(accumulators, result.getNetRuntime()); + MetricsPusher metricsPusher = + new MetricsPusher( + flinkRunnerResult.getMetricsContainerStepMap(), + options.as(MetricsOptions.class), + flinkRunnerResult); + metricsPusher.start(); + return flinkRunnerResult; + } + } + + /** + * Main method to be called only as the entry point to an executable jar with structure as defined + * in {@link PortablePipelineJarUtils}. + */ + public static void main(String[] args) throws Exception { + // Register standard file systems. + FileSystems.setDefaultPipelineOptions(PipelineOptionsFactory.create()); + + FlinkPipelineRunnerConfiguration configuration = parseArgs(args); + String baseJobName = + configuration.baseJobName == null + ? PortablePipelineJarUtils.getDefaultJobName() + : configuration.baseJobName; + Preconditions.checkArgument( + baseJobName != null, + "No default job name found. Job name must be set using --base-job-name."); + Pipeline pipeline = PortablePipelineJarUtils.getPipelineFromClasspath(baseJobName); + Struct originalOptions = PortablePipelineJarUtils.getPipelineOptionsFromClasspath(baseJobName); + + // The retrieval token is only required by the legacy artifact service, which the Flink runner + // no longer uses. 
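+ // A fixed placeholder token is therefore passed along, purely so that JobInfo can be populated.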
+ String retrievalToken = + ArtifactApi.CommitManifestResponse.Constants.NO_ARTIFACTS_STAGED_TOKEN + .getValueDescriptor() + .getOptions() + .getExtension(RunnerApi.beamConstant); + + FlinkPipelineOptions flinkOptions = + PipelineOptionsTranslation.fromProto(originalOptions).as(FlinkPipelineOptions.class); + String invocationId = + String.format("%s_%s", flinkOptions.getJobName(), UUID.randomUUID().toString()); + + FlinkPipelineRunner runner = + new FlinkPipelineRunner( + flinkOptions, + configuration.flinkConfDir, + detectClassPathResourcesToStage( + FlinkPipelineRunner.class.getClassLoader(), flinkOptions)); + JobInfo jobInfo = + JobInfo.create( + invocationId, + flinkOptions.getJobName(), + retrievalToken, + PipelineOptionsTranslation.toProto(flinkOptions)); + try { + runner.run(pipeline, jobInfo); + } catch (Exception e) { + throw new RuntimeException(String.format("Job %s failed.", invocationId), e); + } + LOG.info("Job {} finished successfully.", invocationId); + } + + private static class FlinkPipelineRunnerConfiguration { + @Option( + name = "--flink-conf-dir", + usage = + "Directory containing Flink YAML configuration files. " + + "These properties will be set to all jobs submitted to Flink and take precedence " + + "over configurations in FLINK_CONF_DIR.") + private String flinkConfDir = null; + + @Option( + name = "--base-job-name", + usage = + "The job to run. This must correspond to a subdirectory of the jar's BEAM-PIPELINE " + + "directory. *Only needs to be specified if the jar contains multiple pipelines.*") + private String baseJobName = null; + } + + private static FlinkPipelineRunnerConfiguration parseArgs(String[] args) { + FlinkPipelineRunnerConfiguration configuration = new FlinkPipelineRunnerConfiguration(); + CmdLineParser parser = new CmdLineParser(configuration); + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + LOG.error("Unable to parse command line arguments.", e); + parser.printUsage(System.err); + throw new IllegalArgumentException("Unable to parse command line arguments.", e); + } + return configuration; + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineTranslator.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineTranslator.java new file mode 100644 index 000000000000..13d36e6f8150 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkPipelineTranslator.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.beam.sdk.Pipeline; + +/** + * The role of this class is to translate the Beam operators to their Flink counterparts---a {@link + * FlinkStreamingPipelineTranslator}. 
The {@link org.apache.beam.sdk.values.PCollection}-based + * user-provided job is translated into a {@link + * org.apache.flink.streaming.api.datastream.DataStream} program. + */ +abstract class FlinkPipelineTranslator extends Pipeline.PipelineVisitor.Defaults { + + /** + * Translates the pipeline by passing this class as a visitor. + * + * @param pipeline The pipeline to be translated + */ + public void translate(Pipeline pipeline) { + pipeline.traverseTopologically(this); + } + + /** + * Utility formatting method. + * + * @param n number of repetitions to generate + * @return String with "| " repeated n times + */ + protected static String genSpaces(int n) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < n; i++) { + builder.append("| "); + } + return builder.toString(); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java new file mode 100644 index 000000000000..ae918083256a --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java @@ -0,0 +1,1151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static java.lang.String.format; +import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.createOutputMap; +import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy; +import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder; +import static org.apache.beam.sdk.util.construction.ExecutableStageTranslation.generateNameFromStagePayload; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.auto.service.AutoService; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; +import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory; +import org.apache.beam.runners.flink.translation.functions.ImpulseSourceFunction; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; +import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToFlinkKeyKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SdfFlinkKeyKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WorkItemKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.DedupingOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.StreamingImpulseSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.TestStreamSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper; +import org.apache.beam.runners.fnexecution.control.SdkHarnessClient; +import org.apache.beam.runners.fnexecution.provisioning.JobInfo; +import org.apache.beam.runners.fnexecution.wire.WireCoders; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.ViewFn; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.join.UnionCoder; +import 
org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.construction.ModelCoders; +import org.apache.beam.sdk.util.construction.NativeTransforms; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.ReadTranslation; +import org.apache.beam.sdk.util.construction.RehydratedComponents; +import org.apache.beam.sdk.util.construction.RunnerPCollectionView; +import org.apache.beam.sdk.util.construction.TestStreamTranslation; +import org.apache.beam.sdk.util.construction.WindowingStrategyTranslation; +import org.apache.beam.sdk.util.construction.graph.ExecutableStage; +import org.apache.beam.sdk.util.construction.graph.PipelineNode; +import org.apache.beam.sdk.util.construction.graph.QueryablePipeline; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.sdk.values.ValueWithRecordId; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowedValues.WindowedValueCoder; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.BiMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultiset; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; +import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.DataStreamSource; +import org.apache.flink.streaming.api.datastream.KeyedStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.transformations.TwoInputTransformation; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +/** Translate an unbounded portable pipeline representation into a Flink pipeline representation. 
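Bounded and unbounded sources are both translated onto the Flink DataStream API; the resulting job graph is assembled in the {@link StreamExecutionEnvironment} held by the {@link StreamingTranslationContext}. 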
*/ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "keyfor", + "nullness" +}) // TODO(https://github.com/apache/beam/issues/20497) +public class FlinkStreamingPortablePipelineTranslator + implements FlinkPortablePipelineTranslator< + FlinkStreamingPortablePipelineTranslator.StreamingTranslationContext> { + + /** + * Creates a streaming translation context. The resulting Flink execution dag will live in a new + * {@link StreamExecutionEnvironment}. + */ + @Override + public StreamingTranslationContext createTranslationContext( + JobInfo jobInfo, + FlinkPipelineOptions pipelineOptions, + String confDir, + List filesToStage) { + StreamExecutionEnvironment executionEnvironment = + FlinkExecutionEnvironments.createStreamExecutionEnvironment( + pipelineOptions, filesToStage, confDir); + return createTranslationContext(jobInfo, pipelineOptions, executionEnvironment); + } + + /** + * Creates a streaming translation context. The resulting Flink execution dag will live in the + * given {@link StreamExecutionEnvironment}. + */ + public StreamingTranslationContext createTranslationContext( + JobInfo jobInfo, + FlinkPipelineOptions pipelineOptions, + StreamExecutionEnvironment executionEnvironment) { + return new StreamingTranslationContext(jobInfo, pipelineOptions, executionEnvironment); + } + + /** + * Streaming translation context. Stores metadata about known PCollections/DataStreams and holds + * the Flink {@link StreamExecutionEnvironment} that the execution plan will be applied to. + */ + public static class StreamingTranslationContext + implements FlinkPortablePipelineTranslator.TranslationContext, + FlinkPortablePipelineTranslator.Executor { + + private final JobInfo jobInfo; + private final FlinkPipelineOptions options; + private final StreamExecutionEnvironment executionEnvironment; + private final Map> dataStreams; + + private StreamingTranslationContext( + JobInfo jobInfo, + FlinkPipelineOptions options, + StreamExecutionEnvironment executionEnvironment) { + this.jobInfo = jobInfo; + this.options = options; + this.executionEnvironment = executionEnvironment; + dataStreams = new HashMap<>(); + } + + @Override + public JobInfo getJobInfo() { + return jobInfo; + } + + @Override + public FlinkPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public JobExecutionResult execute(String jobName) throws Exception { + return getExecutionEnvironment().execute(jobName); + } + + public StreamExecutionEnvironment getExecutionEnvironment() { + return executionEnvironment; + } + + public void addDataStream(String pCollectionId, DataStream dataStream) { + dataStreams.put(pCollectionId, dataStream); + } + + public DataStream getDataStreamOrThrow(String pCollectionId) { + DataStream dataSet = (DataStream) dataStreams.get(pCollectionId); + if (dataSet == null) { + throw new IllegalArgumentException( + String.format("Unknown datastream for id %s.", pCollectionId)); + } + return dataSet; + } + } + + public interface PTransformTranslator { + void translate(String id, RunnerApi.Pipeline pipeline, T t); + } + + /** @deprecated Legacy non-portable source which can be replaced by a DoFn with timers. 
*/ + @Deprecated + private static final String STREAMING_IMPULSE_TRANSFORM_URN = + "flink:transform:streaming_impulse:v1"; + + private final Map> + urnToTransformTranslator; + + public FlinkStreamingPortablePipelineTranslator() { + this(ImmutableMap.of()); + } + + public FlinkStreamingPortablePipelineTranslator( + Map> extraTranslations) { + ImmutableMap.Builder> translatorMap = + ImmutableMap.builder(); + translatorMap.put(PTransformTranslation.FLATTEN_TRANSFORM_URN, this::translateFlatten); + translatorMap.put(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN, this::translateGroupByKey); + translatorMap.put(PTransformTranslation.IMPULSE_TRANSFORM_URN, this::translateImpulse); + translatorMap.put(ExecutableStage.URN, this::translateExecutableStage); + translatorMap.put(PTransformTranslation.RESHUFFLE_URN, this::translateReshuffle); + + // TODO Legacy transforms which need to be removed + // Consider removing now that timers are supported + translatorMap.put(STREAMING_IMPULSE_TRANSFORM_URN, this::translateStreamingImpulse); + // Remove once unbounded Reads can be wrapped in SDFs + translatorMap.put(PTransformTranslation.READ_TRANSFORM_URN, this::translateRead); + + // For testing only + translatorMap.put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, this::translateTestStream); + + translatorMap.putAll(extraTranslations); + + this.urnToTransformTranslator = translatorMap.build(); + } + + @Override + public Set knownUrns() { + // Do not expose Read as a known URN because TrivialNativeTransformExpander otherwise removes + // the subtransforms which are added in case of bounded reads. We only have a + // translator here for unbounded Reads which are native transforms which do not + // have subtransforms. Unbounded Reads are used by cross-language transforms, e.g. + // KafkaIO. 
+ return Sets.difference( + urnToTransformTranslator.keySet(), + ImmutableSet.of(PTransformTranslation.READ_TRANSFORM_URN)); + } + + @Override + public FlinkPortablePipelineTranslator.Executor translate( + StreamingTranslationContext context, RunnerApi.Pipeline pipeline) { + QueryablePipeline p = + QueryablePipeline.forTransforms( + pipeline.getRootTransformIdsList(), pipeline.getComponents()); + for (PipelineNode.PTransformNode transform : p.getTopologicallyOrderedTransforms()) { + urnToTransformTranslator + .getOrDefault(transform.getTransform().getSpec().getUrn(), this::urnNotFound) + .translate(transform.getId(), pipeline, context); + } + + return context; + } + + private void urnNotFound( + String id, + RunnerApi.Pipeline pipeline, + FlinkStreamingPortablePipelineTranslator.TranslationContext context) { + throw new IllegalArgumentException( + String.format( + "Unknown type of URN %s for PTransform with id %s.", + pipeline.getComponents().getTransformsOrThrow(id).getSpec().getUrn(), id)); + } + + private void translateReshuffle( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + RunnerApi.PTransform transform = pipeline.getComponents().getTransformsOrThrow(id); + DataStream>> inputDataStream = + context.getDataStreamOrThrow(Iterables.getOnlyElement(transform.getInputsMap().values())); + context.addDataStream( + Iterables.getOnlyElement(transform.getOutputsMap().values()), inputDataStream.rebalance()); + } + + private void translateFlatten( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + RunnerApi.PTransform transform = pipeline.getComponents().getTransformsOrThrow(id); + Map allInputs = transform.getInputsMap(); + + if (allInputs.isEmpty()) { + + // create an empty dummy source to satisfy downstream operations + // we cannot create an empty source in Flink, therefore we have to + // add the flatMap that simply never forwards the single element + long shutdownAfterIdleSourcesMs = + context.getPipelineOptions().getShutdownSourcesAfterIdleMs(); + DataStreamSource> dummySource = + context + .getExecutionEnvironment() + .addSource(new ImpulseSourceFunction(shutdownAfterIdleSourcesMs)); + + DataStream> result = + dummySource + .>flatMap( + (s, collector) -> { + // never return anything + }) + .returns( + new CoderTypeInformation<>( + WindowedValues.getFullCoder( + (Coder) VoidCoder.of(), GlobalWindow.Coder.INSTANCE), + context.getPipelineOptions())); + context.addDataStream(Iterables.getOnlyElement(transform.getOutputsMap().values()), result); + } else { + DataStream result = null; + + // Determine DataStreams that we use as input several times. For those, we need to uniquify + // input streams because Flink seems to swallow watermarks when we have a union of one and + // the same stream. + HashMultiset> inputCounts = HashMultiset.create(); + for (String input : allInputs.values()) { + DataStream current = context.getDataStreamOrThrow(input); + inputCounts.add(current, 1); + } + + for (String input : allInputs.values()) { + DataStream current = context.getDataStreamOrThrow(input); + final int timesRequired = inputCounts.count(current); + if (timesRequired > 1) { + current = + current.flatMap( + new FlatMapFunction() { + private static final long serialVersionUID = 1L; + + @Override + public void flatMap(T t, Collector collector) { + collector.collect(t); + } + }); + } + result = (result == null) ? 
current : result.union(current); + } + + context.addDataStream(Iterables.getOnlyElement(transform.getOutputsMap().values()), result); + } + } + + private void translateGroupByKey( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + + RunnerApi.PTransform pTransform = pipeline.getComponents().getTransformsOrThrow(id); + String inputPCollectionId = Iterables.getOnlyElement(pTransform.getInputsMap().values()); + + RehydratedComponents rehydratedComponents = + RehydratedComponents.forComponents(pipeline.getComponents()); + + RunnerApi.WindowingStrategy windowingStrategyProto = + pipeline + .getComponents() + .getWindowingStrategiesOrThrow( + pipeline + .getComponents() + .getPcollectionsOrThrow(inputPCollectionId) + .getWindowingStrategyId()); + + WindowingStrategy windowingStrategy; + try { + windowingStrategy = + WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException( + String.format( + "Unable to hydrate GroupByKey windowing strategy %s.", windowingStrategyProto), + e); + } + + WindowedValueCoder> windowedInputCoder = + (WindowedValueCoder) instantiateCoder(inputPCollectionId, pipeline.getComponents()); + + DataStream>> inputDataStream = + context.getDataStreamOrThrow(inputPCollectionId); + + SingleOutputStreamOperator>>> outputDataStream = + addGBK( + inputDataStream, + windowingStrategy, + windowedInputCoder, + pTransform.getUniqueName(), + context); + // Assign a unique but consistent id to re-map operator state + outputDataStream.uid(pTransform.getUniqueName()); + + context.addDataStream( + Iterables.getOnlyElement(pTransform.getOutputsMap().values()), outputDataStream); + } + + private SingleOutputStreamOperator>>> addGBK( + DataStream>> inputDataStream, + WindowingStrategy windowingStrategy, + WindowedValueCoder> windowedInputCoder, + String operatorName, + StreamingTranslationContext context) { + KvCoder inputElementCoder = (KvCoder) windowedInputCoder.getValueCoder(); + + SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder.of( + inputElementCoder.getKeyCoder(), + inputElementCoder.getValueCoder(), + windowingStrategy.getWindowFn().windowCoder()); + + WindowedValues.FullWindowedValueCoder> windowedWorkItemCoder = + WindowedValues.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); + + WorkItemKeySelector keySelector = + new WorkItemKeySelector<>(inputElementCoder.getKeyCoder()); + + KeyedStream>, FlinkKey> keyedWorkItemStream = + inputDataStream.keyBy(new KvToFlinkKeyKeySelector(inputElementCoder.getKeyCoder())); + + SystemReduceFn, Iterable, BoundedWindow> reduceFn = + SystemReduceFn.buffering(inputElementCoder.getValueCoder()); + + Coder> accumulatorCoder = IterableCoder.of(inputElementCoder.getValueCoder()); + + Coder>>> outputCoder = + WindowedValues.getFullCoder( + KvCoder.of(inputElementCoder.getKeyCoder(), accumulatorCoder), + windowingStrategy.getWindowFn().windowCoder()); + + TypeInformation>>> outputTypeInfo = + new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); + + TupleTag>> mainTag = new TupleTag<>("main output"); + + WindowDoFnOperator> doFnOperator = + new WindowDoFnOperator<>( + reduceFn, + operatorName, + windowedWorkItemCoder, + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, + outputCoder, + new SerializablePipelineOptions(context.getPipelineOptions())), + windowingStrategy, + new HashMap<>(), /* side-input 
mapping */ + Collections.emptyList(), /* side inputs */ + context.getPipelineOptions(), + inputElementCoder.getKeyCoder(), + keySelector /* key selector */); + + return keyedWorkItemStream.transform(operatorName, outputTypeInfo, doFnOperator); + } + + private void translateRead( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + RunnerApi.PTransform transform = pipeline.getComponents().getTransformsOrThrow(id); + String outputCollectionId = Iterables.getOnlyElement(transform.getOutputsMap().values()); + + RunnerApi.ReadPayload payload; + try { + payload = RunnerApi.ReadPayload.parseFrom(transform.getSpec().getPayload()); + } catch (IOException e) { + throw new RuntimeException("Failed to parse ReadPayload from transform", e); + } + + final DataStream> source; + if (payload.getIsBounded() == RunnerApi.IsBounded.Enum.BOUNDED) { + source = + translateBoundedSource( + transform.getUniqueName(), + outputCollectionId, + payload, + pipeline, + context.getPipelineOptions(), + context.getExecutionEnvironment()); + } else { + source = + translateUnboundedSource( + transform.getUniqueName(), + outputCollectionId, + payload, + pipeline, + context.getPipelineOptions(), + context.getExecutionEnvironment()); + } + context.addDataStream(outputCollectionId, source); + } + + private DataStream> translateBoundedSource( + String transformName, + String outputCollectionId, + RunnerApi.ReadPayload payload, + RunnerApi.Pipeline pipeline, + FlinkPipelineOptions pipelineOptions, + StreamExecutionEnvironment env) { + + try { + @SuppressWarnings("unchecked") + BoundedSource boundedSource = + (BoundedSource) ReadTranslation.boundedSourceFromProto(payload); + @SuppressWarnings("unchecked") + WindowedValues.FullWindowedValueCoder wireCoder = + (WindowedValues.FullWindowedValueCoder) + instantiateCoder(outputCollectionId, pipeline.getComponents()); + + WindowedValues.FullWindowedValueCoder sdkCoder = + getSdkCoder(outputCollectionId, pipeline.getComponents()); + + CoderTypeInformation> outputTypeInfo = + new CoderTypeInformation<>(wireCoder, pipelineOptions); + + CoderTypeInformation> sdkTypeInfo = + new CoderTypeInformation<>(sdkCoder, pipelineOptions); + + return env.createInput(new SourceInputFormat<>(transformName, boundedSource, pipelineOptions)) + .name(transformName) + .uid(transformName) + .returns(sdkTypeInfo) + .map(value -> intoWireTypes(sdkCoder, wireCoder, value)) + .returns(outputTypeInfo); + } catch (Exception e) { + throw new RuntimeException("Error while translating UnboundedSource: " + transformName, e); + } + } + + private static DataStream> translateUnboundedSource( + String transformName, + String outputCollectionId, + RunnerApi.ReadPayload payload, + RunnerApi.Pipeline pipeline, + PipelineOptions pipelineOptions, + StreamExecutionEnvironment env) { + + final DataStream> source; + final DataStream>> nonDedupSource; + + @SuppressWarnings("unchecked") + UnboundedSource unboundedSource = + (UnboundedSource) ReadTranslation.unboundedSourceFromProto(payload); + + @SuppressWarnings("unchecked") + WindowingStrategy windowStrategy = + getWindowingStrategy(outputCollectionId, pipeline.getComponents()); + + try { + + @SuppressWarnings("unchecked") + WindowedValues.FullWindowedValueCoder wireCoder = + (WindowedValues.FullWindowedValueCoder) + instantiateCoder(outputCollectionId, pipeline.getComponents()); + + WindowedValues.FullWindowedValueCoder sdkCoder = + getSdkCoder(outputCollectionId, pipeline.getComponents()); + + CoderTypeInformation> outputTypeInfo = + new 
CoderTypeInformation<>(wireCoder, pipelineOptions); + + CoderTypeInformation> sdkTypeInformation = + new CoderTypeInformation<>(sdkCoder, pipelineOptions); + + TypeInformation>> withIdTypeInfo = + new CoderTypeInformation<>( + WindowedValues.getFullCoder( + ValueWithRecordId.ValueWithRecordIdCoder.of(sdkCoder.getValueCoder()), + windowStrategy.getWindowFn().windowCoder()), + pipelineOptions); + + int parallelism = + env.getMaxParallelism() > 0 ? env.getMaxParallelism() : env.getParallelism(); + UnboundedSourceWrapper sourceWrapper = + new UnboundedSourceWrapper<>( + transformName, pipelineOptions, unboundedSource, parallelism); + nonDedupSource = + env.addSource(sourceWrapper) + .name(transformName) + .uid(transformName) + .returns(withIdTypeInfo); + + if (unboundedSource.requiresDeduping()) { + source = + nonDedupSource + .keyBy(new FlinkStreamingTransformTranslators.ValueWithRecordIdKeySelector<>()) + .transform("deduping", sdkTypeInformation, new DedupingOperator<>(pipelineOptions)) + .uid(format("%s/__deduplicated__", transformName)) + .returns(sdkTypeInformation); + } else { + source = + nonDedupSource + .flatMap(new FlinkStreamingTransformTranslators.StripIdsMap<>(pipelineOptions)) + .returns(sdkTypeInformation); + } + + return source.map(value -> intoWireTypes(sdkCoder, wireCoder, value)).returns(outputTypeInfo); + } catch (Exception e) { + throw new RuntimeException("Error while translating UnboundedSource: " + unboundedSource, e); + } + } + + /** + * Get SDK coder for given PCollection. The SDK coder is the coder that the SDK-harness would have + * used to encode data before passing it to the runner over {@link SdkHarnessClient}. + * + * @param pCollectionId ID of PCollection in components + * @param components the Pipeline components (proto) + * @return SDK-side coder for the PCollection + */ + private static WindowedValues.FullWindowedValueCoder getSdkCoder( + String pCollectionId, RunnerApi.Components components) { + + PipelineNode.PCollectionNode pCollectionNode = + PipelineNode.pCollection(pCollectionId, components.getPcollectionsOrThrow(pCollectionId)); + RunnerApi.Components.Builder componentsBuilder = components.toBuilder(); + String coderId = + WireCoders.addSdkWireCoder( + pCollectionNode, + componentsBuilder, + RunnerApi.ExecutableStagePayload.WireCoderSetting.getDefaultInstance()); + RehydratedComponents rehydratedComponents = + RehydratedComponents.forComponents(componentsBuilder.build()); + try { + @SuppressWarnings("unchecked") + WindowedValues.FullWindowedValueCoder res = + (WindowedValues.FullWindowedValueCoder) rehydratedComponents.getCoder(coderId); + return res; + } catch (IOException ex) { + throw new IllegalStateException("Could not get SDK coder.", ex); + } + } + + /** + * Transform types from SDK types to runner types. The runner uses byte array representation for + * non {@link ModelCoders} coders. 
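The value is simply re-encoded with the SDK-side coder and the resulting bytes are decoded with the runner-side (wire) coder. 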
+ * + * @param inCoder the input coder (SDK-side) + * @param outCoder the output coder (runner-side) + * @param value encoded value + * @param SDK-side type + * @param runer-side type + * @return re-encoded {@link WindowedValue} + */ + private static WindowedValue intoWireTypes( + Coder> inCoder, + Coder> outCoder, + WindowedValue value) { + + try { + return CoderUtils.decodeFromByteArray(outCoder, CoderUtils.encodeToByteArray(inCoder, value)); + } catch (CoderException ex) { + throw new IllegalStateException("Could not transform element into wire types", ex); + } + } + + private void translateImpulse( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + RunnerApi.PTransform pTransform = pipeline.getComponents().getTransformsOrThrow(id); + + TypeInformation> typeInfo = + new CoderTypeInformation<>( + WindowedValues.getFullCoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE), + context.getPipelineOptions()); + + long shutdownAfterIdleSourcesMs = context.getPipelineOptions().getShutdownSourcesAfterIdleMs(); + SingleOutputStreamOperator> source = + context + .getExecutionEnvironment() + .addSource(new ImpulseSourceFunction(shutdownAfterIdleSourcesMs), "Impulse") + .returns(typeInfo); + + context.addDataStream(Iterables.getOnlyElement(pTransform.getOutputsMap().values()), source); + } + + /** Predicate to determine whether a URN is a Flink native transform. */ + @AutoService(NativeTransforms.IsNativeTransform.class) + public static class IsFlinkNativeTransform implements NativeTransforms.IsNativeTransform { + @Override + public boolean test(RunnerApi.PTransform pTransform) { + return STREAMING_IMPULSE_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(pTransform)); + } + } + + private void translateStreamingImpulse( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + RunnerApi.PTransform pTransform = pipeline.getComponents().getTransformsOrThrow(id); + + TypeInformation> typeInfo = + new CoderTypeInformation<>( + WindowedValues.getFullCoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE), + context.getPipelineOptions()); + + ObjectMapper objectMapper = new ObjectMapper(); + final int intervalMillis; + final int messageCount; + try { + JsonNode config = objectMapper.readTree(pTransform.getSpec().getPayload().toByteArray()); + intervalMillis = config.path("interval_ms").asInt(100); + messageCount = config.path("message_count").asInt(0); + } catch (IOException e) { + throw new RuntimeException("Failed to parse configuration for streaming impulse", e); + } + + SingleOutputStreamOperator> source = + context + .getExecutionEnvironment() + .addSource( + new StreamingImpulseSource(intervalMillis, messageCount), + StreamingImpulseSource.class.getSimpleName()) + .returns(typeInfo); + + context.addDataStream(Iterables.getOnlyElement(pTransform.getOutputsMap().values()), source); + } + + private void translateExecutableStage( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + // TODO: Fail on splittable DoFns. + // TODO: Special-case single outputs to avoid multiplexing PCollections. 
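+ // The stage is executed by a single ExecutableStageDoFnOperator: side inputs (if any) are unioned + // into one broadcast stream and connected as a second input, and additional outputs are emitted + // through Flink side outputs identified by their OutputTag.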
+ RunnerApi.Components components = pipeline.getComponents(); + RunnerApi.PTransform transform = components.getTransformsOrThrow(id); + Map outputs = transform.getOutputsMap(); + + final RunnerApi.ExecutableStagePayload stagePayload; + try { + stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transform.getSpec().getPayload()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + String inputPCollectionId = stagePayload.getInput(); + final TransformedSideInputs transformedSideInputs; + + if (stagePayload.getSideInputsCount() > 0) { + transformedSideInputs = transformSideInputs(stagePayload, components, context); + } else { + transformedSideInputs = new TransformedSideInputs(Collections.emptyMap(), null); + } + + Map, OutputTag>> tagsToOutputTags = Maps.newLinkedHashMap(); + Map, Coder>> tagsToCoders = Maps.newLinkedHashMap(); + // TODO: does it matter which output we designate as "main" + final TupleTag mainOutputTag = + outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next()); + + // associate output tags with ids, output manager uses these Integer ids to serialize state + BiMap outputIndexMap = createOutputMap(outputs.keySet()); + Map>> outputCoders = Maps.newHashMap(); + Map, Integer> tagsToIds = Maps.newHashMap(); + Map> collectionIdToTupleTag = Maps.newHashMap(); + // order output names for deterministic mapping + for (String localOutputName : new TreeMap<>(outputIndexMap).keySet()) { + String collectionId = outputs.get(localOutputName); + Coder> windowCoder = (Coder) instantiateCoder(collectionId, components); + outputCoders.put(localOutputName, windowCoder); + TupleTag tupleTag = new TupleTag<>(localOutputName); + CoderTypeInformation> typeInformation = + new CoderTypeInformation(windowCoder, context.getPipelineOptions()); + tagsToOutputTags.put(tupleTag, new OutputTag<>(localOutputName, typeInformation)); + tagsToCoders.put(tupleTag, windowCoder); + tagsToIds.put(tupleTag, outputIndexMap.get(localOutputName)); + collectionIdToTupleTag.put(collectionId, tupleTag); + } + + final SingleOutputStreamOperator> outputStream; + DataStream> inputDataStream = + context.getDataStreamOrThrow(inputPCollectionId); + + CoderTypeInformation> outputTypeInformation = + !outputs.isEmpty() + ? new CoderTypeInformation( + outputCoders.get(mainOutputTag.getId()), context.getPipelineOptions()) + : null; + + ArrayList> additionalOutputTags = Lists.newArrayList(); + for (TupleTag tupleTag : tagsToCoders.keySet()) { + if (!mainOutputTag.getId().equals(tupleTag.getId())) { + additionalOutputTags.add(tupleTag); + } + } + + final Coder> windowedInputCoder = + instantiateCoder(inputPCollectionId, components); + + final boolean stateful = + stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0; + final boolean hasSdfProcessFn = + stagePayload.getComponents().getTransformsMap().values().stream() + .anyMatch( + pTransform -> + pTransform + .getSpec() + .getUrn() + .equals( + PTransformTranslation + .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)); + Coder keyCoder = null; + KeySelector, ?> keySelector = null; + if (stateful || hasSdfProcessFn) { + // Stateful/SDF stages are only allowed of KV input. 
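+ // Keying is required here because the operator relies on Flink keyed state and timers for user + // state, timers and SDF restrictions; the concrete key is derived from the input KV coder below.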
+ Coder valueCoder = + ((WindowedValues.FullWindowedValueCoder) windowedInputCoder).getValueCoder(); + if (!(valueCoder instanceof KvCoder)) { + throw new IllegalStateException( + String.format( + Locale.ENGLISH, + "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", + inputPCollectionId, + valueCoder.getClass().getSimpleName())); + } + if (stateful) { + keyCoder = ((KvCoder) valueCoder).getKeyCoder(); + keySelector = new KvToFlinkKeyKeySelector(keyCoder); + } else { + // For an SDF, we know that the input element should be + // KV>, size>. We are going to use the element + // as the key. + if (!(((KvCoder) valueCoder).getKeyCoder() instanceof KvCoder)) { + throw new IllegalStateException( + String.format( + Locale.ENGLISH, + "The element coder for splittable DoFn '%s' must be KVCoder(KvCoder, DoubleCoder) but is: %s", + inputPCollectionId, + valueCoder.getClass().getSimpleName())); + } + keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder(); + keySelector = new SdfFlinkKeyKeySelector(keyCoder); + } + inputDataStream = inputDataStream.keyBy(keySelector); + } + + DoFnOperator.MultiOutputOutputManagerFactory outputManagerFactory = + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainOutputTag, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + new SerializablePipelineOptions(context.getPipelineOptions())); + + DoFnOperator doFnOperator = + new ExecutableStageDoFnOperator<>( + transform.getUniqueName(), + windowedInputCoder, + Collections.emptyMap(), + mainOutputTag, + additionalOutputTags, + outputManagerFactory, + transformedSideInputs.unionTagToView, + new ArrayList<>(transformedSideInputs.unionTagToView.values()), + getSideInputIdToPCollectionViewMap(stagePayload, components), + context.getPipelineOptions(), + stagePayload, + context.getJobInfo(), + FlinkExecutableStageContextFactory.getInstance(), + collectionIdToTupleTag, + getWindowingStrategy(inputPCollectionId, components), + keyCoder, + keySelector); + + final String operatorName = generateNameFromStagePayload(stagePayload); + + if (transformedSideInputs.unionTagToView.isEmpty()) { + outputStream = inputDataStream.transform(operatorName, outputTypeInformation, doFnOperator); + } else { + DataStream sideInputStream = + transformedSideInputs.unionedSideInputs.broadcast(); + if (stateful || hasSdfProcessFn) { + // We have to manually construct the two-input transform because we're not + // allowed to have only one input keyed, normally. Since Flink 1.5.0 it's + // possible to use the Broadcast State Pattern which provides a more elegant + // way to process keyed main input with broadcast state, but it's not feasible + // here because it breaks the DoFnOperator abstraction. 
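+ // Construct the two-input transformation by hand so that only the main input is keyed; the + // broadcast side-input stream remains non-keyed.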
+ TwoInputTransformation>, RawUnionValue, WindowedValue> + rawFlinkTransform = + new TwoInputTransformation( + inputDataStream.getTransformation(), + sideInputStream.getTransformation(), + transform.getUniqueName(), + doFnOperator, + outputTypeInformation, + inputDataStream.getParallelism()); + + rawFlinkTransform.setStateKeyType(((KeyedStream) inputDataStream).getKeyType()); + rawFlinkTransform.setStateKeySelectors( + ((KeyedStream) inputDataStream).getKeySelector(), null); + + outputStream = + new SingleOutputStreamOperator( + inputDataStream.getExecutionEnvironment(), + rawFlinkTransform) {}; // we have to cheat around the ctor being protected + } else { + outputStream = + inputDataStream + .connect(sideInputStream) + .transform(operatorName, outputTypeInformation, doFnOperator); + } + } + // Assign a unique but consistent id to re-map operator state + outputStream.uid(transform.getUniqueName()); + + if (mainOutputTag != null) { + context.addDataStream(outputs.get(mainOutputTag.getId()), outputStream); + } + + for (TupleTag tupleTag : additionalOutputTags) { + context.addDataStream( + outputs.get(tupleTag.getId()), + outputStream.getSideOutput(tagsToOutputTags.get(tupleTag))); + } + } + + private void translateTestStream( + String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) { + RunnerApi.Components components = pipeline.getComponents(); + + SerializableFunction> testStreamDecoder = + bytes -> { + try { + RunnerApi.TestStreamPayload testStreamPayload = + RunnerApi.TestStreamPayload.parseFrom(bytes); + @SuppressWarnings("unchecked") + TestStream testStream = + (TestStream) + TestStreamTranslation.testStreamFromProtoPayload( + testStreamPayload, RehydratedComponents.forComponents(components)); + return testStream; + } catch (Exception e) { + throw new RuntimeException("Can't decode TestStream payload.", e); + } + }; + + RunnerApi.PTransform transform = components.getTransformsOrThrow(id); + String outputPCollectionId = Iterables.getOnlyElement(transform.getOutputsMap().values()); + Coder> coder = instantiateCoder(outputPCollectionId, components); + + DataStream> source = + context + .getExecutionEnvironment() + .addSource( + new TestStreamSource<>( + testStreamDecoder, transform.getSpec().getPayload().toByteArray()), + new CoderTypeInformation<>(coder, context.getPipelineOptions())); + + context.addDataStream(outputPCollectionId, source); + } + + private static LinkedHashMap> + getSideInputIdToPCollectionViewMap( + RunnerApi.ExecutableStagePayload stagePayload, RunnerApi.Components components) { + + RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components); + + LinkedHashMap> sideInputs = + new LinkedHashMap<>(); + // for PCollectionView compatibility, not used to transform materialization + ViewFn>, ?> viewFn = + (ViewFn) + new PCollectionViews.MultimapViewFn<>( + (PCollectionViews.TypeDescriptorSupplier>>) + () -> TypeDescriptors.iterables(new TypeDescriptor>() {}), + (PCollectionViews.TypeDescriptorSupplier) TypeDescriptors::voids); + + for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId : + stagePayload.getSideInputsList()) { + + // TODO: local name is unique as long as only one transform with side input can be within a + // stage + String sideInputTag = sideInputId.getLocalName(); + String collectionId = + components + .getTransformsOrThrow(sideInputId.getTransformId()) + .getInputsOrThrow(sideInputId.getLocalName()); + RunnerApi.WindowingStrategy windowingStrategyProto = + components.getWindowingStrategiesOrThrow( + 
components.getPcollectionsOrThrow(collectionId).getWindowingStrategyId()); + + final WindowingStrategy windowingStrategy; + try { + windowingStrategy = + WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException( + String.format( + "Unable to hydrate side input windowing strategy %s.", windowingStrategyProto), + e); + } + + Coder> coder = instantiateCoder(collectionId, components); + // side input materialization via GBK (T -> Iterable) + WindowedValueCoder wvCoder = (WindowedValueCoder) coder; + coder = wvCoder.withValueCoder(IterableCoder.of(wvCoder.getValueCoder())); + + sideInputs.put( + sideInputId, + new RunnerPCollectionView<>( + null, + new TupleTag<>(sideInputTag), + viewFn, + // TODO: support custom mapping fn + windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), + windowingStrategy, + coder)); + } + return sideInputs; + } + + private TransformedSideInputs transformSideInputs( + RunnerApi.ExecutableStagePayload stagePayload, + RunnerApi.Components components, + StreamingTranslationContext context) { + + LinkedHashMap> sideInputs = + getSideInputIdToPCollectionViewMap(stagePayload, components); + + Map, Integer> tagToIntMapping = new HashMap<>(); + Map> intToViewMapping = new HashMap<>(); + List>> kvCoders = new ArrayList<>(); + List> viewCoders = new ArrayList<>(); + + int count = 0; + for (Map.Entry> sideInput : + sideInputs.entrySet()) { + TupleTag tag = sideInput.getValue().getTagInternal(); + intToViewMapping.put(count, sideInput.getValue()); + tagToIntMapping.put(tag, count); + count++; + String collectionId = + components + .getTransformsOrThrow(sideInput.getKey().getTransformId()) + .getInputsOrThrow(sideInput.getKey().getLocalName()); + DataStream sideInputStream = context.getDataStreamOrThrow(collectionId); + TypeInformation tpe = sideInputStream.getType(); + if (!(tpe instanceof CoderTypeInformation)) { + throw new IllegalStateException("Input Stream TypeInformation is no CoderTypeInformation."); + } + + WindowedValueCoder coder = + (WindowedValueCoder) ((CoderTypeInformation) tpe).getCoder(); + Coder> kvCoder = KvCoder.of(VoidCoder.of(), coder.getValueCoder()); + kvCoders.add(coder.withValueCoder(kvCoder)); + // coder for materialized view matching GBK below + WindowedValueCoder>> viewCoder = + coder.withValueCoder(KvCoder.of(VoidCoder.of(), IterableCoder.of(coder.getValueCoder()))); + viewCoders.add(viewCoder); + } + + // second pass, now that we gathered the input coders + UnionCoder unionCoder = UnionCoder.of(viewCoders); + + CoderTypeInformation unionTypeInformation = + new CoderTypeInformation<>(unionCoder, context.getPipelineOptions()); + + // transform each side input to RawUnionValue and union them + DataStream sideInputUnion = null; + + for (Map.Entry> sideInput : + sideInputs.entrySet()) { + TupleTag tag = sideInput.getValue().getTagInternal(); + final int intTag = tagToIntMapping.get(tag); + RunnerApi.PTransform pTransform = + components.getTransformsOrThrow(sideInput.getKey().getTransformId()); + String collectionId = pTransform.getInputsOrThrow(sideInput.getKey().getLocalName()); + DataStream> sideInputStream = context.getDataStreamOrThrow(collectionId); + + // insert GBK to materialize side input view + String viewName = + sideInput.getKey().getTransformId() + "-" + sideInput.getKey().getLocalName(); + WindowedValueCoder> kvCoder = kvCoders.get(intTag); + DataStream>> keyedSideInputStream = + sideInputStream.map(new 
ToVoidKeyValue(context.getPipelineOptions())); + + SingleOutputStreamOperator>>> viewStream = + addGBK( + keyedSideInputStream, + sideInput.getValue().getWindowingStrategyInternal(), + kvCoder, + viewName, + context); + // Assign a unique but consistent id to re-map operator state + viewStream.uid(pTransform.getUniqueName() + "-" + sideInput.getKey().getLocalName()); + + DataStream unionValueStream = + viewStream + .map( + new FlinkStreamingTransformTranslators.ToRawUnion<>( + intTag, context.getPipelineOptions())) + .returns(unionTypeInformation); + + if (sideInputUnion == null) { + sideInputUnion = unionValueStream; + } else { + sideInputUnion = sideInputUnion.union(unionValueStream); + } + } + + return new TransformedSideInputs(intToViewMapping, sideInputUnion); + } + + private static class TransformedSideInputs { + final Map> unionTagToView; + final DataStream unionedSideInputs; + + TransformedSideInputs( + Map> unionTagToView, + DataStream unionedSideInputs) { + this.unionTagToView = unionTagToView; + this.unionedSideInputs = unionedSideInputs; + } + } + + private static class ToVoidKeyValue + extends RichMapFunction, WindowedValue>> { + + private final SerializablePipelineOptions options; + + public ToVoidKeyValue(PipelineOptions pipelineOptions) { + this.options = new SerializablePipelineOptions(pipelineOptions); + } + + @Override + public void open(OpenContext openContext) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(options.get()); + } + + @Override + public WindowedValue> map(WindowedValue value) { + return value.withValue(KV.of(null, value.getValue())); + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java new file mode 100644 index 000000000000..abeb9daaf044 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -0,0 +1,1440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static java.lang.String.format; +import static org.apache.beam.sdk.util.construction.SplittableParDo.SPLITTABLE_PROCESS_URN; + +import com.google.auto.service.AutoService; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; +import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; +import org.apache.beam.runners.flink.translation.functions.ImpulseSourceFunction; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToFlinkKeyKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItem; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SplittableDoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WorkItemKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.BeamStoppableFunction; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.DedupingOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.TestStreamSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.bounded.FlinkBoundedSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.unbounded.FlinkUnboundedSource; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.join.UnionCoder; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import 
org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.ParDoTranslation; +import org.apache.beam.sdk.util.construction.ReadTranslation; +import org.apache.beam.sdk.util.construction.SplittableParDo; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.ValueWithRecordId; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.state.CheckpointListener; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.DataStreamSource; +import org.apache.flink.streaming.api.datastream.DataStreamUtils; +import org.apache.flink.streaming.api.datastream.KeyedStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction; +import org.apache.flink.streaming.api.transformations.TwoInputTransformation; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * This class contains all the mappings between Beam and Flink streaming transformations. The + * {@link FlinkStreamingPipelineTranslator} traverses the Beam job and comes here to translate the + * encountered Beam transformations into Flink one, based on the mapping available in this class. 
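+ *
+ * <p>Translators are registered per transform URN in the static {@code TRANSLATORS} map below;
+ * {@code getTranslator(...)} resolves a transform's URN via {@code PTransformTranslation} and
+ * returns {@code null} when no translator is registered for that URN.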
+ */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +class FlinkStreamingTransformTranslators { + + // -------------------------------------------------------------------------------------------- + // Transform Translator Registry + // -------------------------------------------------------------------------------------------- + + /** A map from a Transform URN to the translator. */ + @SuppressWarnings("rawtypes") + private static final Map + TRANSLATORS = new HashMap<>(); + + // here you can find all the available translators. + static { + TRANSLATORS.put(PTransformTranslation.IMPULSE_TRANSFORM_URN, new ImpulseTranslator()); + TRANSLATORS.put(PTransformTranslation.READ_TRANSFORM_URN, new ReadSourceTranslator()); + + TRANSLATORS.put(PTransformTranslation.PAR_DO_TRANSFORM_URN, new ParDoStreamingTranslator()); + TRANSLATORS.put(SPLITTABLE_PROCESS_URN, new SplittableProcessElementsStreamingTranslator()); + TRANSLATORS.put(SplittableParDo.SPLITTABLE_GBKIKWI_URN, new GBKIntoKeyedWorkItemsTranslator()); + + TRANSLATORS.put( + PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, new WindowAssignTranslator()); + TRANSLATORS.put( + PTransformTranslation.FLATTEN_TRANSFORM_URN, new FlattenPCollectionTranslator()); + TRANSLATORS.put( + CreateStreamingFlinkView.CREATE_STREAMING_FLINK_VIEW_URN, + new CreateViewStreamingTranslator()); + + TRANSLATORS.put(PTransformTranslation.RESHUFFLE_URN, new ReshuffleTranslatorStreaming()); + TRANSLATORS.put(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN, new GroupByKeyTranslator()); + TRANSLATORS.put( + PTransformTranslation.COMBINE_PER_KEY_TRANSFORM_URN, new CombinePerKeyTranslator()); + + TRANSLATORS.put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, new TestStreamTranslator()); + } + + private static final String FORCED_SLOT_GROUP = "beam"; + + public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getTranslator( + PTransform transform) { + @Nullable String urn = PTransformTranslation.urnForTransformOrNull(transform); + return urn == null ? 
null : TRANSLATORS.get(urn); + } + + @SuppressWarnings("unchecked") + public static String getCurrentTransformName(FlinkStreamingTranslationContext context) { + return context.getCurrentTransform().getFullName(); + } + + // -------------------------------------------------------------------------------------------- + // Transformation Implementations + // -------------------------------------------------------------------------------------------- + + private static class UnboundedReadSourceTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>> { + + @Override + public void translateNode( + PTransform> transform, FlinkStreamingTranslationContext context) { + PCollection output = context.getOutput(transform); + + DataStream> source; + DataStream>> nonDedupSource; + TypeInformation> outputTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + Coder coder = context.getOutput(transform).getCoder(); + + TypeInformation>> withIdTypeInfo = + new CoderTypeInformation<>( + WindowedValues.getFullCoder( + ValueWithRecordId.ValueWithRecordIdCoder.of(coder), + output.getWindowingStrategy().getWindowFn().windowCoder()), + context.getPipelineOptions()); + + UnboundedSource rawSource; + try { + rawSource = + ReadTranslation.unboundedSourceFromTransform( + (AppliedPTransform, PTransform>>) + context.getCurrentTransform()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + String fullName = getCurrentTransformName(context); + try { + int parallelism = + context.getExecutionEnvironment().getMaxParallelism() > 0 + ? context.getExecutionEnvironment().getMaxParallelism() + : context.getExecutionEnvironment().getParallelism(); + + FlinkUnboundedSource unboundedSource = + FlinkSource.unbounded( + transform.getName(), + rawSource, + new SerializablePipelineOptions(context.getPipelineOptions()), + parallelism); + nonDedupSource = + context + .getExecutionEnvironment() + .fromSource( + unboundedSource, WatermarkStrategy.noWatermarks(), fullName, withIdTypeInfo) + .uid(fullName); + + if (rawSource.requiresDeduping()) { + source = + nonDedupSource + .keyBy(new ValueWithRecordIdKeySelector<>()) + .transform( + "deduping", + outputTypeInfo, + new DedupingOperator<>(context.getPipelineOptions())) + .uid(format("%s/__deduplicated__", fullName)); + } else { + source = + nonDedupSource + .flatMap(new StripIdsMap<>(context.getPipelineOptions())) + .returns(outputTypeInfo); + } + } catch (Exception e) { + throw new RuntimeException("Error while translating UnboundedSource: " + rawSource, e); + } + + context.setOutputDataStream(output, source); + } + } + + static class ValueWithRecordIdKeySelector + implements KeySelector>, FlinkKey>, + ResultTypeQueryable { + + @Override + public FlinkKey getKey(WindowedValue> value) throws Exception { + return FlinkKey.of(ByteBuffer.wrap(value.getValue().getId())); + } + + @Override + public TypeInformation getProducedType() { + return ValueTypeInfo.of(FlinkKey.class); + } + } + + public static class StripIdsMap + extends RichFlatMapFunction>, WindowedValue> { + + private final SerializablePipelineOptions options; + + StripIdsMap(PipelineOptions options) { + this.options = new SerializablePipelineOptions(options); + } + + @Override + public void open(OpenContext openContext) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(options.get()); + } + + @Override + public void flatMap( + 
WindowedValue> value, Collector> collector) + throws Exception { + collector.collect(value.withValue(value.getValue().getValue())); + } + } + + private static class ImpulseTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator { + @Override + void translateNode(Impulse transform, FlinkStreamingTranslationContext context) { + + TypeInformation> typeInfo = + new CoderTypeInformation<>( + WindowedValues.getFullCoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE), + context.getPipelineOptions()); + + SingleOutputStreamOperator> impulseOperator; + if (context.isStreaming()) { + long shutdownAfterIdleSourcesMs = + context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getShutdownSourcesAfterIdleMs(); + impulseOperator = + context + .getExecutionEnvironment() + .addSource(new ImpulseSourceFunction(shutdownAfterIdleSourcesMs), "Impulse") + .returns(typeInfo); + } else { + FlinkBoundedSource impulseSource = FlinkSource.boundedImpulse(); + impulseOperator = + context + .getExecutionEnvironment() + .fromSource(impulseSource, WatermarkStrategy.noWatermarks(), "Impulse") + .returns(typeInfo); + + if (!context.isStreaming() + && context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getForceSlotSharingGroup()) { + impulseOperator = impulseOperator.slotSharingGroup(FORCED_SLOT_GROUP); + } + } + context.setOutputDataStream(context.getOutput(transform), impulseOperator); + } + } + + private static class ReadSourceTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>> { + + private final BoundedReadSourceTranslator boundedTranslator = + new BoundedReadSourceTranslator<>(); + private final UnboundedReadSourceTranslator unboundedTranslator = + new UnboundedReadSourceTranslator<>(); + + @Override + void translateNode( + PTransform> transform, FlinkStreamingTranslationContext context) { + if (ReadTranslation.sourceIsBounded(context.getCurrentTransform()) + == PCollection.IsBounded.BOUNDED) { + boundedTranslator.translateNode(transform, context); + } else { + unboundedTranslator.translateNode(transform, context); + } + } + } + + private static class BoundedReadSourceTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>> { + + @Override + public void translateNode( + PTransform> transform, FlinkStreamingTranslationContext context) { + PCollection output = context.getOutput(transform); + + TypeInformation> outputTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + BoundedSource rawSource; + try { + rawSource = + ReadTranslation.boundedSourceFromTransform( + (AppliedPTransform, PTransform>>) + context.getCurrentTransform()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + String fullName = getCurrentTransformName(context); + int parallelism = + context.getExecutionEnvironment().getMaxParallelism() > 0 + ? 
context.getExecutionEnvironment().getMaxParallelism() + : context.getExecutionEnvironment().getParallelism(); + + FlinkBoundedSource flinkBoundedSource = + FlinkSource.bounded( + transform.getName(), + rawSource, + new SerializablePipelineOptions(context.getPipelineOptions()), + parallelism); + + TypeInformation> typeInfo = context.getTypeInfo(output); + + SingleOutputStreamOperator> source; + try { + source = + context + .getExecutionEnvironment() + .fromSource( + flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo) + .uid(fullName) + .returns(typeInfo); + + if (!context.isStreaming() + && context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getForceSlotSharingGroup()) { + source = source.slotSharingGroup(FORCED_SLOT_GROUP); + } + } catch (Exception e) { + throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e); + } + context.setOutputDataStream(output, source); + } + } + + /** Wraps each element in a {@link RawUnionValue} with the given tag id. */ + public static class ToRawUnion extends RichMapFunction { + private final int intTag; + private final SerializablePipelineOptions options; + + ToRawUnion(int intTag, PipelineOptions pipelineOptions) { + this.intTag = intTag; + this.options = new SerializablePipelineOptions(pipelineOptions); + } + + @Override + public void open(OpenContext openContext) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(options.get()); + } + + @Override + public RawUnionValue map(T o) throws Exception { + return new RawUnionValue(intTag, o); + } + } + + public static Tuple2>, DataStream> + transformSideInputs( + Collection> sideInputs, FlinkStreamingTranslationContext context) { + + // collect all side inputs + Map, Integer> tagToIntMapping = new HashMap<>(); + Map> intToViewMapping = new HashMap<>(); + int count = 0; + for (PCollectionView sideInput : sideInputs) { + TupleTag tag = sideInput.getTagInternal(); + intToViewMapping.put(count, sideInput); + tagToIntMapping.put(tag, count); + count++; + } + + List> inputCoders = new ArrayList<>(); + for (PCollectionView sideInput : sideInputs) { + DataStream sideInputStream = context.getInputDataStream(sideInput); + TypeInformation tpe = sideInputStream.getType(); + if (!(tpe instanceof CoderTypeInformation)) { + throw new IllegalStateException("Input Stream TypeInformation is no CoderTypeInformation."); + } + + Coder coder = ((CoderTypeInformation) tpe).getCoder(); + inputCoders.add(coder); + } + + UnionCoder unionCoder = UnionCoder.of(inputCoders); + + CoderTypeInformation unionTypeInformation = + new CoderTypeInformation<>(unionCoder, context.getPipelineOptions()); + + // transform each side input to RawUnionValue and union them + DataStream sideInputUnion = null; + + for (PCollectionView sideInput : sideInputs) { + TupleTag tag = sideInput.getTagInternal(); + final int intTag = tagToIntMapping.get(tag); + DataStream sideInputStream = context.getInputDataStream(sideInput); + DataStream unionValueStream = + sideInputStream + .map(new ToRawUnion<>(intTag, context.getPipelineOptions())) + .returns(unionTypeInformation); + + if (sideInputUnion == null) { + sideInputUnion = unionValueStream; + } else { + sideInputUnion = sideInputUnion.union(unionValueStream); + } + } + + if (sideInputUnion == null) { + throw new IllegalStateException("No unioned side inputs, this indicates a bug."); + } + + return new 
Tuple2<>(intToViewMapping, sideInputUnion); + } + + /** + * Helper for translating {@code ParDo.MultiOutput} and {@link + * SplittableParDoViaKeyedWorkItems.ProcessElements}. + */ + static class ParDoTranslationHelper { + + interface DoFnOperatorFactory { + DoFnOperator createDoFnOperator( + DoFn doFn, + String stepName, + List> sideInputs, + TupleTag mainOutputTag, + List> additionalOutputTags, + FlinkStreamingTranslationContext context, + WindowingStrategy windowingStrategy, + Map, OutputTag>> tagsToOutputTags, + Map, Coder>> tagsToCoders, + Map, Integer> tagsToIds, + Coder> windowedInputCoder, + Map, Coder> outputCoders, + Coder keyCoder, + KeySelector, ?> keySelector, + Map> transformedSideInputs, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping); + } + + static void translateParDo( + String transformName, + DoFn doFn, + PCollection input, + List> sideInputs, + Map, PCollection> outputs, + TupleTag mainOutputTag, + List> additionalOutputTags, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, + FlinkStreamingTranslationContext context, + DoFnOperatorFactory doFnOperatorFactory) { + + // we assume that the transformation does not change the windowing strategy. + WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + + Map, OutputTag>> tagsToOutputTags = Maps.newHashMap(); + Map, Coder>> tagsToCoders = Maps.newHashMap(); + + // We associate output tags with ids, the Integer is easier to serialize than TupleTag. + // The return map of AppliedPTransform.getOutputs() is an ImmutableMap, its implementation is + // RegularImmutableMap, its entrySet order is the same with the order of insertion. + // So we can use the original AppliedPTransform.getOutputs() to produce deterministic ids. + Map, Integer> tagsToIds = Maps.newHashMap(); + int idCount = 0; + tagsToIds.put(mainOutputTag, idCount++); + for (Map.Entry, PCollection> entry : outputs.entrySet()) { + if (!tagsToOutputTags.containsKey(entry.getKey())) { + tagsToOutputTags.put( + entry.getKey(), + new OutputTag>( + entry.getKey().getId(), + (TypeInformation) context.getTypeInfo((PCollection) entry.getValue()))); + tagsToCoders.put( + entry.getKey(), + (Coder) context.getWindowedInputCoder((PCollection) entry.getValue())); + tagsToIds.put(entry.getKey(), idCount++); + } + } + + SingleOutputStreamOperator> outputStream; + + Coder> windowedInputCoder = context.getWindowedInputCoder(input); + Map, Coder> outputCoders = context.getOutputCoders(); + + DataStream> inputDataStream = context.getInputDataStream(input); + + Coder keyCoder = null; + KeySelector, ?> keySelector = null; + boolean stateful = false; + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + if (!signature.stateDeclarations().isEmpty() + || !signature.timerDeclarations().isEmpty() + || !signature.timerFamilyDeclarations().isEmpty()) { + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + keyCoder = ((KvCoder) input.getCoder()).getKeyCoder(); + keySelector = new KvToFlinkKeyKeySelector<>(keyCoder); + final PTransform> producer = context.getProducer(input); + final String previousUrn = + producer != null + ? 
PTransformTranslation.urnForTransformOrNull(context.getProducer(input)) + : null; + // We can skip reshuffle in case previous transform was CPK or GBK + if (PTransformTranslation.COMBINE_PER_KEY_TRANSFORM_URN.equals(previousUrn) + || PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN.equals(previousUrn)) { + inputDataStream = DataStreamUtils.reinterpretAsKeyedStream(inputDataStream, keySelector); + } else { + inputDataStream = inputDataStream.keyBy(keySelector); + } + stateful = true; + } else if (doFn instanceof SplittableParDoViaKeyedWorkItems.ProcessFn) { + // we know that it is keyed on byte[] + keyCoder = ByteArrayCoder.of(); + keySelector = new WorkItemKeySelector<>(keyCoder); + stateful = true; + } + + CoderTypeInformation> outputTypeInformation = + new CoderTypeInformation<>( + context.getWindowedInputCoder((PCollection) outputs.get(mainOutputTag)), + context.getPipelineOptions()); + + if (sideInputs.isEmpty()) { + DoFnOperator doFnOperator = + doFnOperatorFactory.createDoFnOperator( + doFn, + getCurrentTransformName(context), + sideInputs, + mainOutputTag, + additionalOutputTags, + context, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders, + keyCoder, + keySelector, + new HashMap<>() /* side-input mapping */, + doFnSchemaInformation, + sideInputMapping); + + outputStream = + inputDataStream.transform(transformName, outputTypeInformation, doFnOperator); + + } else { + Tuple2>, DataStream> transformedSideInputs = + transformSideInputs(sideInputs, context); + + DoFnOperator doFnOperator = + doFnOperatorFactory.createDoFnOperator( + doFn, + getCurrentTransformName(context), + sideInputs, + mainOutputTag, + additionalOutputTags, + context, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders, + keyCoder, + keySelector, + transformedSideInputs.f0, + doFnSchemaInformation, + sideInputMapping); + + if (stateful) { + // we have to manually construct the two-input transform because we're not + // allowed to have only one input keyed, normally. 
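+ // Flink's connect(...).transform(...) does not support keying only one of the two
+ // inputs, so the TwoInputTransformation is assembled by hand: the state key type and
+ // key selector are set on the main (keyed) input only, and the resulting operator is
+ // registered with the execution environment explicitly via addOperator(...) below.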
+ KeyedStream keyedStream = (KeyedStream) inputDataStream; + TwoInputTransformation< + WindowedValue>, RawUnionValue, WindowedValue> + rawFlinkTransform = + new TwoInputTransformation( + keyedStream.getTransformation(), + transformedSideInputs.f1.broadcast().getTransformation(), + transformName, + doFnOperator, + outputTypeInformation, + keyedStream.getParallelism()); + + rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); + rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); + + outputStream = + new SingleOutputStreamOperator( + keyedStream.getExecutionEnvironment(), + rawFlinkTransform) {}; // we have to cheat around the ctor being protected + + keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + + } else { + outputStream = + inputDataStream + .connect(transformedSideInputs.f1.broadcast()) + .transform(transformName, outputTypeInformation, doFnOperator); + } + } + + outputStream.uid(transformName); + context.setOutputDataStream(outputs.get(mainOutputTag), outputStream); + + for (Map.Entry, PCollection> entry : outputs.entrySet()) { + if (!entry.getKey().equals(mainOutputTag)) { + context.setOutputDataStream( + entry.getValue(), outputStream.getSideOutput(tagsToOutputTags.get(entry.getKey()))); + } + } + } + } + + private static class ParDoStreamingTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform, PCollectionTuple>> { + + @Override + public void translateNode( + PTransform, PCollectionTuple> transform, + FlinkStreamingTranslationContext context) { + + DoFn doFn; + try { + doFn = (DoFn) ParDoTranslation.getDoFn(context.getCurrentTransform()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + TupleTag mainOutputTag; + try { + mainOutputTag = + (TupleTag) ParDoTranslation.getMainOutputTag(context.getCurrentTransform()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + List> sideInputs; + try { + sideInputs = ParDoTranslation.getSideInputs(context.getCurrentTransform()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + + TupleTagList additionalOutputTags; + try { + additionalOutputTags = + ParDoTranslation.getAdditionalOutputTags(context.getCurrentTransform()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + DoFnSchemaInformation doFnSchemaInformation; + doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + + ParDoTranslationHelper.translateParDo( + getCurrentTransformName(context), + doFn, + context.getInput(transform), + sideInputs, + context.getOutputs(transform), + mainOutputTag, + additionalOutputTags.getAll(), + doFnSchemaInformation, + sideInputMapping, + context, + (doFn1, + stepName, + sideInputs1, + mainOutputTag1, + additionalOutputTags1, + context1, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders1, + keyCoder, + keySelector, + transformedSideInputs, + doFnSchemaInformation1, + sideInputMapping1) -> + new DoFnOperator<>( + doFn1, + stepName, + windowedInputCoder, + outputCoders1, + mainOutputTag1, + additionalOutputTags1, + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainOutputTag1, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + new SerializablePipelineOptions(context.getPipelineOptions())), + windowingStrategy, + transformedSideInputs, + sideInputs1, + context1.getPipelineOptions(), + keyCoder, + 
keySelector, + doFnSchemaInformation1, + sideInputMapping1)); + } + } + + private static class SplittableProcessElementsStreamingTranslator< + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + SplittableParDoViaKeyedWorkItems.ProcessElements< + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> { + + @Override + public void translateNode( + SplittableParDoViaKeyedWorkItems.ProcessElements< + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> + transform, + FlinkStreamingTranslationContext context) { + + ParDoTranslationHelper.translateParDo( + getCurrentTransformName(context), + transform.newProcessFn(transform.getFn()), + context.getInput(transform), + transform.getSideInputs(), + context.getOutputs(transform), + transform.getMainOutputTag(), + transform.getAdditionalOutputTags().getAll(), + DoFnSchemaInformation.create(), + Collections.emptyMap(), + context, + (doFn, + stepName, + sideInputs, + mainOutputTag, + additionalOutputTags, + context1, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders1, + keyCoder, + keySelector, + transformedSideInputs, + doFnSchemaInformation, + sideInputMapping) -> + new SplittableDoFnOperator<>( + doFn, + stepName, + windowedInputCoder, + outputCoders1, + mainOutputTag, + additionalOutputTags, + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainOutputTag, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + new SerializablePipelineOptions(context.getPipelineOptions())), + windowingStrategy, + transformedSideInputs, + sideInputs, + context1.getPipelineOptions(), + keyCoder, + keySelector)); + } + } + + private static class CreateViewStreamingTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + CreateStreamingFlinkView.CreateFlinkPCollectionView> { + + @Override + public void translateNode( + CreateStreamingFlinkView.CreateFlinkPCollectionView transform, + FlinkStreamingTranslationContext context) { + // just forward + DataStream>> inputDataSet = + context.getInputDataStream(context.getInput(transform)); + + PCollectionView view = transform.getView(); + + context.setOutputDataStream(view, inputDataSet); + } + } + + private static class WindowAssignTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform, PCollection>> { + + @Override + public void translateNode( + PTransform, PCollection> transform, + FlinkStreamingTranslationContext context) { + + @SuppressWarnings("unchecked") + WindowingStrategy windowingStrategy = + (WindowingStrategy) context.getOutput(transform).getWindowingStrategy(); + + TypeInformation> typeInfo = + context.getTypeInfo(context.getOutput(transform)); + + DataStream> inputDataStream = + context.getInputDataStream(context.getInput(transform)); + + WindowFn windowFn = windowingStrategy.getWindowFn(); + + FlinkAssignWindows assignWindowsFunction = + new FlinkAssignWindows<>(windowFn); + + String fullName = context.getOutput(transform).getName(); + SingleOutputStreamOperator> outputDataStream = + inputDataStream + .flatMap(assignWindowsFunction) + .name(fullName) + .uid(fullName) + .returns(typeInfo); + + context.setOutputDataStream(context.getOutput(transform), outputDataStream); + } + } + + private static class ReshuffleTranslatorStreaming + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>, PCollection>>> { + + @Override + public void translateNode( + 
PTransform>, PCollection>> transform, + FlinkStreamingTranslationContext context) { + + DataStream>> inputDataSet = + context.getInputDataStream(context.getInput(transform)); + + context.setOutputDataStream(context.getOutput(transform), inputDataSet.rebalance()); + } + } + + private static class GroupByKeyTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>, PCollection>>>> { + + @Override + public void translateNode( + PTransform>, PCollection>>> transform, + FlinkStreamingTranslationContext context) { + + PCollection> input = context.getInput(transform); + @SuppressWarnings("unchecked") + WindowingStrategy windowingStrategy = + (WindowingStrategy) input.getWindowingStrategy(); + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + DataStream>> inputDataStream = context.getInputDataStream(input); + String fullName = getCurrentTransformName(context); + + SingleOutputStreamOperator>>> outDataStream; + // Pre-aggregate before shuffle similar to group combine + if (!context.isStreaming()) { + outDataStream = FlinkStreamingAggregationsTranslators.batchGroupByKey(context, transform); + } else { + // No pre-aggregation in Streaming mode. + KvToFlinkKeyKeySelector keySelector = + new KvToFlinkKeyKeySelector<>(inputKvCoder.getKeyCoder()); + + Coder>>> outputCoder = + WindowedValues.getFullCoder( + KvCoder.of( + inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), + windowingStrategy.getWindowFn().windowCoder()); + + TypeInformation>>> outputTypeInfo = + new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); + + WindowDoFnOperator> doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + SystemReduceFn.buffering(inputKvCoder.getValueCoder()), + new HashMap<>(), + Collections.emptyList()); + + outDataStream = + inputDataStream + .keyBy(keySelector) + .transform(fullName, outputTypeInfo, doFnOperator) + .uid(fullName); + } + context.setOutputDataStream(context.getOutput(transform), outDataStream); + } + } + + private static class CombinePerKeyTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>, PCollection>>> { + + @Override + boolean canTranslate( + PTransform>, PCollection>> transform, + FlinkStreamingTranslationContext context) { + // if we have a merging window strategy and side inputs we cannot + // translate as a proper combine. We have to group and then run the combine + // over the final grouped values. 
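+ // When this returns false the transform is not translated as a primitive; its
+ // expansion is translated instead, i.e. group first, then combine the grouped values.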
+ PCollection> input = context.getInput(transform); + + @SuppressWarnings("unchecked") + WindowingStrategy windowingStrategy = + (WindowingStrategy) input.getWindowingStrategy(); + + return !windowingStrategy.needsMerge() + || ((Combine.PerKey) transform).getSideInputs().isEmpty(); + } + + @Override + public void translateNode( + PTransform>, PCollection>> transform, + FlinkStreamingTranslationContext context) { + String fullName = getCurrentTransformName(context); + + PCollection> input = context.getInput(transform); + + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder>> outputCoder = + context.getWindowedInputCoder(context.getOutput(transform)); + + DataStream>> inputDataStream = context.getInputDataStream(input); + + @SuppressWarnings("unchecked") + GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); + + TypeInformation>> outputTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + @SuppressWarnings("unchecked") + List> sideInputs = ((Combine.PerKey) transform).getSideInputs(); + + KeyedStream>, FlinkKey> keyedStream = + inputDataStream.keyBy(new KvToFlinkKeyKeySelector<>(keyCoder)); + + if (sideInputs.isEmpty()) { + SingleOutputStreamOperator>> outDataStream; + + if (!context.isStreaming()) { + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs( + context, transform, combineFn); + } else { + WindowDoFnOperator doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + new HashMap<>(), + Collections.emptyList()); + + outDataStream = + keyedStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + } + + context.setOutputDataStream(context.getOutput(transform), outDataStream); + } else { + Tuple2>, DataStream> transformSideInputs = + transformSideInputs(sideInputs, context); + SingleOutputStreamOperator>> outDataStream; + + if (!context.isStreaming()) { + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKey( + context, transform, combineFn, transformSideInputs.f0, sideInputs); + } else { + WindowDoFnOperator doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + transformSideInputs.f0, + sideInputs); + + outDataStream = + FlinkStreamingAggregationsTranslators.buildTwoInputStream( + keyedStream, + transformSideInputs.f1, + transform.getName(), + doFnOperator, + outputTypeInfo); + } + + context.setOutputDataStream(context.getOutput(transform), outDataStream); + } + } + } + + private static class GBKIntoKeyedWorkItemsTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform>, PCollection>>> { + + @Override + boolean canTranslate( + PTransform>, PCollection>> transform, + FlinkStreamingTranslationContext context) { + return true; + } + + @Override + public void translateNode( + PTransform>, PCollection>> transform, + FlinkStreamingTranslationContext context) { + + PCollection> input = context.getInput(transform); + + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + + SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder.of( + inputKvCoder.getKeyCoder(), + inputKvCoder.getValueCoder(), + input.getWindowingStrategy().getWindowFn().windowCoder()); + + WindowedValues.ValueOnlyWindowedValueCoder> windowedWorkItemCoder = + WindowedValues.getValueOnlyCoder(workItemCoder); + + 
CoderTypeInformation>> workItemTypeInfo = + new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); + + DataStream>> inputDataStream = context.getInputDataStream(input); + + DataStream>> workItemStream = + inputDataStream + .flatMap(new ToKeyedWorkItemInGlobalWindow<>(context.getPipelineOptions())) + .returns(workItemTypeInfo) + .name("ToKeyedWorkItem"); + + KeyedStream>, FlinkKey> keyedWorkItemStream = + workItemStream.keyBy(new WorkItemKeySelector<>(inputKvCoder.getKeyCoder())); + + context.setOutputDataStream(context.getOutput(transform), keyedWorkItemStream); + } + } + + private static class ToKeyedWorkItemInGlobalWindow + extends RichFlatMapFunction< + WindowedValue>, WindowedValue>> { + + private final SerializablePipelineOptions options; + + ToKeyedWorkItemInGlobalWindow(PipelineOptions options) { + this.options = new SerializablePipelineOptions(options); + } + + @Override + public void open(OpenContext openContext) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(options.get()); + } + + @Override + public void flatMap( + WindowedValue> inWithMultipleWindows, + Collector>> out) + throws Exception { + + // we need to wrap each one work item per window for now + // since otherwise the PushbackSideInputRunner will not correctly + // determine whether side inputs are ready + // + // this is tracked as https://github.com/apache/beam/issues/18358 + for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { + SingletonKeyedWorkItem workItem = + new SingletonKeyedWorkItem<>( + in.getValue().getKey(), in.withValue(in.getValue().getValue())); + + out.collect(WindowedValues.valueInGlobalWindow(workItem)); + } + } + } + + private static class FlattenPCollectionTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + PTransform, PCollection>> { + + @Override + public void translateNode( + PTransform, PCollection> transform, + FlinkStreamingTranslationContext context) { + Map, PCollection> allInputs = context.getInputs(transform); + + if (allInputs.isEmpty()) { + + // create an empty dummy source to satisfy downstream operations + // we cannot create an empty source in Flink, therefore we have to + // add the flatMap that simply never forwards the single element + DataStreamSource dummySource = + context.getExecutionEnvironment().fromElements("dummy"); + + DataStream> result = + dummySource + .>flatMap( + (s, collector) -> { + // never return anything + }) + .returns( + new CoderTypeInformation<>( + WindowedValues.getFullCoder( + (Coder) VoidCoder.of(), GlobalWindow.Coder.INSTANCE), + context.getPipelineOptions())); + context.setOutputDataStream(context.getOutput(transform), result); + + } else { + DataStream result = null; + + // Determine DataStreams that we use as input several times. For those, we need to uniquify + // input streams because Flink seems to swallow watermarks when we have a union of one and + // the same stream. 
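+ // The identity flatMap applied below gives every additional use of the same input its
+ // own transformation node, so the union is built from distinct upstream operators and
+ // watermark propagation is preserved.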
+ Map, Integer> duplicates = new HashMap<>(); + for (PValue input : allInputs.values()) { + DataStream current = context.getInputDataStream(input); + Integer oldValue = duplicates.put(current, 1); + if (oldValue != null) { + duplicates.put(current, oldValue + 1); + } + } + + for (PValue input : allInputs.values()) { + DataStream current = context.getInputDataStream(input); + + final Integer timesRequired = duplicates.get(current); + if (timesRequired > 1) { + current = + current.flatMap( + new FlatMapFunction() { + private static final long serialVersionUID = 1L; + + @Override + public void flatMap(T t, Collector collector) throws Exception { + collector.collect(t); + } + }); + } + result = (result == null) ? current : result.union(current); + } + + context.setOutputDataStream(context.getOutput(transform), result); + } + } + } + + /** Registers classes specialized to the Flink runner. */ + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class FlinkTransformsRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + public Map< + ? extends Class, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + ., PTransformTranslation.TransformPayloadTranslator>builder() + .put( + CreateStreamingFlinkView.CreateFlinkPCollectionView.class, + new CreateStreamingFlinkViewPayloadTranslator()) + .build(); + } + } + + /** A translator just to vend the URN. */ + private static class CreateStreamingFlinkViewPayloadTranslator + extends PTransformTranslation.TransformPayloadTranslator.NotSerializable< + CreateStreamingFlinkView.CreateFlinkPCollectionView> { + + private CreateStreamingFlinkViewPayloadTranslator() {} + + @Override + public String getUrn() { + return CreateStreamingFlinkView.CREATE_STREAMING_FLINK_VIEW_URN; + } + } + + /** A translator to support {@link TestStream} with Flink. */ + private static class TestStreamTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + void translateNode(TestStream testStream, FlinkStreamingTranslationContext context) { + Coder valueCoder = testStream.getValueCoder(); + + // Coder for the Elements in the TestStream + TestStream.TestStreamCoder testStreamCoder = TestStream.TestStreamCoder.of(valueCoder); + final byte[] payload; + try { + payload = CoderUtils.encodeToByteArray(testStreamCoder, testStream); + } catch (CoderException e) { + throw new RuntimeException("Could not encode TestStream.", e); + } + + SerializableFunction> testStreamDecoder = + bytes -> { + try { + return CoderUtils.decodeFromByteArray( + TestStream.TestStreamCoder.of(valueCoder), bytes); + } catch (CoderException e) { + throw new RuntimeException("Can't decode TestStream payload.", e); + } + }; + + WindowedValues.FullWindowedValueCoder elementCoder = + WindowedValues.getFullCoder(valueCoder, GlobalWindow.Coder.INSTANCE); + + DataStreamSource> source = + context + .getExecutionEnvironment() + .addSource( + new TestStreamSource<>(testStreamDecoder, payload), + new CoderTypeInformation<>(elementCoder, context.getPipelineOptions())); + + context.setOutputDataStream(context.getOutput(testStream), source); + } + } + + // TODO(https://github.com/apache/beam/issues/37114) migrate off RichParallelSourceFunction + /** + * Wrapper for {@link UnboundedSourceWrapper}, which simplifies output type, namely, removes + * {@link ValueWithRecordId}. 
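+ * <p>All lifecycle, checkpoint and stop/cancel callbacks are forwarded to the wrapped
+ * {@link UnboundedSourceWrapper}; emitted elements are unwrapped from {@link ValueWithRecordId}
+ * by a small {@code SourceContext} adapter.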
+ */ + static class UnboundedSourceWrapperNoValueWithRecordId< + OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> + extends RichParallelSourceFunction> + implements BeamStoppableFunction, + CheckpointListener, + CheckpointedFunction, + ProcessingTimeCallback { + + private final UnboundedSourceWrapper unboundedSourceWrapper; + + @VisibleForTesting + UnboundedSourceWrapper getUnderlyingSource() { + return unboundedSourceWrapper; + } + + UnboundedSourceWrapperNoValueWithRecordId( + UnboundedSourceWrapper unboundedSourceWrapper) { + this.unboundedSourceWrapper = unboundedSourceWrapper; + } + + @Override + public void open(OpenContext openContext) throws Exception { + unboundedSourceWrapper.setRuntimeContext(getRuntimeContext()); + unboundedSourceWrapper.open(openContext); + } + + @Override + public void run(SourceContext> ctx) throws Exception { + unboundedSourceWrapper.run(new SourceContextWrapper(ctx)); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + unboundedSourceWrapper.initializeState(context); + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + unboundedSourceWrapper.snapshotState(context); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + unboundedSourceWrapper.notifyCheckpointComplete(checkpointId); + } + + @Override + public void stop() { + unboundedSourceWrapper.stop(); + } + + @Override + public void cancel() { + unboundedSourceWrapper.cancel(); + } + + @Override + public void onProcessingTime(long timestamp) throws Exception { + unboundedSourceWrapper.onProcessingTime(timestamp); + } + + private final class SourceContextWrapper + implements SourceContext>> { + + private final SourceContext> ctx; + + private SourceContextWrapper(SourceContext> ctx) { + this.ctx = ctx; + } + + @Override + public void collect(WindowedValue> element) { + OutputT originalValue = element.getValue().getValue(); + WindowedValues.builder(element).withValue(originalValue).setReceiver(ctx::collect).output(); + } + + @Override + public void collectWithTimestamp( + WindowedValue> element, long timestamp) { + OutputT originalValue = element.getValue().getValue(); + WindowedValues.builder(element) + .withValue(originalValue) + .setReceiver(wv -> ctx.collectWithTimestamp(wv, timestamp)); + } + + @Override + public void emitWatermark(Watermark mark) { + ctx.emitWatermark(mark); + } + + @Override + public void markAsTemporarilyIdle() { + ctx.markAsTemporarilyIdle(); + } + + @Override + public Object getCheckpointLock() { + return ctx.getCheckpointLock(); + } + + @Override + public void close() { + ctx.close(); + } + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java new file mode 100644 index 000000000000..2cf5f743ca03 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + +import java.util.List; +import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; +import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformMatchers; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.SplittableParDo; +import org.apache.beam.sdk.util.construction.SplittableParDoNaiveBounded; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + +/** {@link PTransform} overrides for Flink runner. */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +class FlinkTransformOverrides { + static List getDefaultOverrides(FlinkPipelineOptions options) { + ImmutableList.Builder builder = ImmutableList.builder(); + if (options.isStreaming()) { + builder.add( + PTransformOverride.of( + FlinkStreamingPipelineTranslator.StreamingShardedWriteFactory + .writeFilesNeedsOverrides(), + new FlinkStreamingPipelineTranslator.StreamingShardedWriteFactory( + checkNotNull(options)))); + } + builder.add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), + CreateStreamingFlinkView.Factory.INSTANCE)); + builder + .add( + PTransformOverride.of( + PTransformMatchers.splittableParDo(), new SplittableParDo.OverrideFactory())) + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN), + options.isStreaming() + ? new SplittableParDoViaKeyedWorkItems.OverrideFactory() + : new SplittableParDoNaiveBounded.OverrideFactory())); + return builder.build(); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java new file mode 100644 index 000000000000..31ef5ee54711 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.flink.translation.utils.Workarounds; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.util.WindowedValueMultiReceiver; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.util.Collector; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Encapsulates a {@link DoFn} inside a Flink {@link + * org.apache.flink.api.common.functions.RichMapPartitionFunction}. + * + *
<p>
We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index and must tag + * all outputs with the output number. Afterwards a filter will filter out those elements that are + * not to be in a specific output. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkDoFnFunction extends AbstractRichFunction + implements FlatMapFunction, WindowedValue> { + + private final SerializablePipelineOptions serializedOptions; + + private final DoFn doFn; + private final String stepName; + private final Map, WindowingStrategy> sideInputs; + + private final WindowingStrategy windowingStrategy; + + private final Map, Integer> outputMap; + private final TupleTag mainOutputTag; + private final Coder inputCoder; + private final Map, Coder> outputCoderMap; + private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; + + private transient CollectorAware collectorAware; + private transient DoFnInvoker doFnInvoker; + private transient DoFnRunner doFnRunner; + private transient FlinkMetricContainer metricContainer; + + private boolean bundleStarted = false; + private boolean exceptionThrownInFlatMap = false; + + public FlinkDoFnFunction( + DoFn doFn, + String stepName, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions options, + Map, Integer> outputMap, + TupleTag mainOutputTag, + Coder inputCoder, + Map, Coder> outputCoderMap, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { + + this.doFn = doFn; + this.stepName = stepName; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializablePipelineOptions(options); + this.windowingStrategy = windowingStrategy; + this.outputMap = outputMap; + this.mainOutputTag = mainOutputTag; + this.inputCoder = inputCoder; + this.outputCoderMap = outputCoderMap; + this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; + } + + @Override + public void flatMap(WindowedValue value, Collector> out) { + try { + if (!bundleStarted) { + bundleStarted = true; + doFnRunner.startBundle(); + } + collectorAware.setCollector(out); + doFnRunner.processElement(value); + } catch (Exception e) { + exceptionThrownInFlatMap = true; + throw e; + } + } + + @Override + public void open(OpenContext parameters) { + // Note that the SerializablePipelineOptions already initialize FileSystems in the readObject() + // deserialization method. However, this is a hack, and we want to properly initialize the + // options where they are needed. 
+ PipelineOptions options = serializedOptions.get(); + FileSystems.setDefaultPipelineOptions(options); + doFnInvoker = DoFnInvokers.tryInvokeSetupFor(doFn, options); + metricContainer = new FlinkMetricContainer(getRuntimeContext()); + + // setup DoFnRunner + final RuntimeContext runtimeContext = getRuntimeContext(); + final WindowedValueMultiReceiver outputManager; + if (outputMap.size() == 1) { + outputManager = new DoFnOutputManager(); + } else { + // it has some additional outputs + outputManager = new MultiDoFnOutputManagerWindowed(outputMap); + } + + final List> additionalOutputTags = Lists.newArrayList(outputMap.keySet()); + + DoFnRunner doFnRunner = + DoFnRunners.simpleRunner( + options, + doFn, + new FlinkSideInputReader(sideInputs, runtimeContext), + outputManager, + mainOutputTag, + additionalOutputTags, + new FlinkNoOpStepContext(), + inputCoder, + outputCoderMap, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping); + + if (!serializedOptions.get().as(FlinkPipelineOptions.class).getDisableMetrics()) { + doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, metricContainer); + } + + this.collectorAware = (CollectorAware) outputManager; + this.doFnRunner = doFnRunner; + } + + @Override + public void close() throws Exception { + Exception suppressed = null; + try { + if (bundleStarted && !exceptionThrownInFlatMap) { + doFnRunner.finishBundle(); + } + } catch (Exception e) { + // Suppress exception, so we can properly teardown DoFn. + suppressed = e; + } + try { + metricContainer.registerMetricsForPipelineResult(); + Optional.ofNullable(doFnInvoker).ifPresent(DoFnInvoker::invokeTeardown); + if (suppressed != null) { + throw suppressed; + } + } finally { + Workarounds.deleteStaticCaches(); + } + } + + interface CollectorAware { + + void setCollector(Collector> collector); + } + + static class DoFnOutputManager implements WindowedValueMultiReceiver, CollectorAware { + + private @MonotonicNonNull Collector> collector; + + DoFnOutputManager() { + this(null); + } + + DoFnOutputManager(@Nullable Collector> collector) { + this.collector = collector; + } + + @Override + public void setCollector(Collector> collector) { + this.collector = Objects.requireNonNull(collector); + } + + @Override + public void output(TupleTag tag, WindowedValue output) { + checkStateNotNull(collector); + WindowedValues.builder(output) + .withValue(new RawUnionValue(0 /* single output */, output.getValue())) + .setReceiver(collector::collect) + .output(); + } + } + + static class MultiDoFnOutputManagerWindowed + implements WindowedValueMultiReceiver, CollectorAware { + + private @MonotonicNonNull Collector> collector; + private final Map, Integer> outputMap; + + MultiDoFnOutputManagerWindowed(Map, Integer> outputMap) { + this.outputMap = outputMap; + } + + MultiDoFnOutputManagerWindowed( + @Nullable Collector> collector, + Map, Integer> outputMap) { + this.collector = collector; + this.outputMap = outputMap; + } + + @Override + public void setCollector(Collector> collector) { + this.collector = Objects.requireNonNull(collector); + } + + @Override + public void output(TupleTag tag, WindowedValue output) { + checkStateNotNull(collector); + WindowedValues.builder(output) + .withValue(new RawUnionValue(outputMap.get(tag), output.getValue())) + .setReceiver(collector::collect) + .output(); + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageContextFactory.java 
b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageContextFactory.java new file mode 100644 index 000000000000..3f42eb93e4e6 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageContextFactory.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.apache.beam.runners.fnexecution.control.DefaultExecutableStageContext; +import org.apache.beam.runners.fnexecution.control.ExecutableStageContext; +import org.apache.beam.runners.fnexecution.control.ReferenceCountingExecutableStageContextFactory; +import org.apache.beam.runners.fnexecution.provisioning.JobInfo; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +/** Singleton class that contains one {@link ExecutableStageContext.Factory} per job. */ +public class FlinkExecutableStageContextFactory implements ExecutableStageContext.Factory { + + private static final FlinkExecutableStageContextFactory instance = + new FlinkExecutableStageContextFactory(); + // This map should only ever have a single element, as each job will have its own + // classloader and therefore its own instance of FlinkExecutableStageContextFactory. This + // code supports multiple JobInfos in order to provide a sensible implementation of + // Factory.get(JobInfo), which in theory could be called with different JobInfos. + private static final ConcurrentMap jobFactories = + new ConcurrentHashMap<>(); + + private FlinkExecutableStageContextFactory() {} + + public static FlinkExecutableStageContextFactory getInstance() { + return instance; + } + + @Override + public ExecutableStageContext get(JobInfo jobInfo) { + ExecutableStageContext.Factory jobFactory = + jobFactories.computeIfAbsent( + jobInfo.jobId(), + k -> { + return ReferenceCountingExecutableStageContextFactory.create( + DefaultExecutableStageContext::create, + // Clean up context immediately if its class is not loaded on Flink parent + // classloader. 
+ (caller) -> + caller.getClass().getClassLoader() + != StreamExecutionEnvironment.class.getClassLoader()); + }); + + return jobFactory.get(jobInfo); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java new file mode 100644 index 000000000000..1298fd3105aa --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.InMemoryTimerInternals; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler; +import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers; +import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler; +import org.apache.beam.runners.fnexecution.control.BundleProgressHandler; +import org.apache.beam.runners.fnexecution.control.ExecutableStageContext; +import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory; +import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors; +import org.apache.beam.runners.fnexecution.control.RemoteBundle; +import org.apache.beam.runners.fnexecution.control.StageBundleFactory; +import org.apache.beam.runners.fnexecution.control.TimerReceiverFactory; +import org.apache.beam.runners.fnexecution.provisioning.JobInfo; +import org.apache.beam.runners.fnexecution.state.InMemoryBagUserStateFactory; +import org.apache.beam.runners.fnexecution.state.StateRequestHandler; +import org.apache.beam.runners.fnexecution.state.StateRequestHandlers; +import 
org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory; +import org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.Timer; +import org.apache.beam.sdk.util.construction.graph.ExecutableStage; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Flink operator that passes its input DataSet through an SDK-executed {@link + * org.apache.beam.sdk.util.construction.graph.ExecutableStage}. + * + *
<p>
The output of this operation is a multiplexed DataSet whose elements are tagged with a union + * coder. The coder's tags are determined by the output coder map. The resulting data set should be + * further processed by a {@link FlinkExecutableStagePruningFunction}. + */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkExecutableStageFunction extends AbstractRichFunction + implements MapPartitionFunction, RawUnionValue>, + GroupReduceFunction, RawUnionValue> { + private static final Logger LOG = LoggerFactory.getLogger(FlinkExecutableStageFunction.class); + + // Main constructor fields. All must be Serializable because Flink distributes Functions to + // task managers via java serialization. + + // Pipeline options for initializing the FileSystems + private final SerializablePipelineOptions pipelineOptions; + // The executable stage this function will run. + private final RunnerApi.ExecutableStagePayload stagePayload; + // Pipeline options. Used for provisioning api. + private final JobInfo jobInfo; + // Map from PCollection id to the union tag used to represent this PCollection in the output. + private final Map outputMap; + private final FlinkExecutableStageContextFactory contextFactory; + private final Coder windowCoder; + private final Coder> inputCoder; + // Unique name for namespacing metrics + private final String stepName; + + // Worker-local fields. These should only be constructed and consumed on Flink TaskManagers. + private transient RuntimeContext runtimeContext; + private transient FlinkMetricContainer metricContainer; + private transient StateRequestHandler stateRequestHandler; + private transient ExecutableStageContext stageContext; + private transient StageBundleFactory stageBundleFactory; + private transient BundleProgressHandler progressHandler; + private transient BundleFinalizationHandler finalizationHandler; + private transient BundleCheckpointHandler bundleCheckpointHandler; + private transient InMemoryTimerInternals sdfTimerInternals; + private transient StateInternals sdfStateInternals; + // Only initialized when the ExecutableStage is stateful + private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory; + private transient ExecutableStage executableStage; + // In state + private transient Object currentTimerKey; + + public FlinkExecutableStageFunction( + String stepName, + PipelineOptions pipelineOptions, + RunnerApi.ExecutableStagePayload stagePayload, + JobInfo jobInfo, + Map outputMap, + FlinkExecutableStageContextFactory contextFactory, + Coder windowCoder, + Coder> inputCoder) { + this.stepName = stepName; + this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions); + this.stagePayload = stagePayload; + this.jobInfo = jobInfo; + this.outputMap = outputMap; + this.contextFactory = contextFactory; + this.windowCoder = windowCoder; + this.inputCoder = inputCoder; + } + + @Override + public void open(OpenContext openContext) { + FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); + // Register standard file systems. + FileSystems.setDefaultPipelineOptions(options); + executableStage = ExecutableStage.fromPayload(stagePayload); + runtimeContext = getRuntimeContext(); + metricContainer = new FlinkMetricContainer(runtimeContext); + // TODO: Wire this into the distributed cache and make it pluggable. 
+ stageContext = contextFactory.get(jobInfo); + stageBundleFactory = stageContext.getStageBundleFactory(executableStage); + // NOTE: It's safe to reuse the state handler between partitions because each partition uses the + // same backing runtime context and broadcast variables. We use checkState below to catch errors + // in backward-incompatible Flink changes. + stateRequestHandler = + getStateRequestHandler( + executableStage, stageBundleFactory.getProcessBundleDescriptor(), runtimeContext); + progressHandler = + new BundleProgressHandler() { + @Override + public void onProgress(ProcessBundleProgressResponse progress) { + metricContainer.updateMetrics(stepName, progress.getMonitoringInfosList()); + } + + @Override + public void onCompleted(ProcessBundleResponse response) { + metricContainer.updateMetrics(stepName, response.getMonitoringInfosList()); + } + }; + // TODO(https://github.com/apache/beam/issues/19526): Support bundle finalization in portable + // batch. + finalizationHandler = + bundleId -> { + throw new UnsupportedOperationException( + "Portable Flink runner doesn't support bundle finalization in batch mode. For more details, please refer to https://github.com/apache/beam/issues/19526."); + }; + bundleCheckpointHandler = getBundleCheckpointHandler(executableStage); + } + + private boolean hasSDF(ExecutableStage executableStage) { + return executableStage.getTransforms().stream() + .anyMatch( + pTransformNode -> + pTransformNode + .getTransform() + .getSpec() + .getUrn() + .equals( + PTransformTranslation + .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)); + } + + private BundleCheckpointHandler getBundleCheckpointHandler(ExecutableStage executableStage) { + if (!hasSDF(executableStage)) { + sdfStateInternals = null; + sdfStateInternals = null; + return response -> { + throw new UnsupportedOperationException( + "Self-checkpoint is only supported on splittable DoFn."); + }; + } + sdfTimerInternals = new InMemoryTimerInternals(); + sdfStateInternals = InMemoryStateInternals.forKey("sdf_state"); + return new BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler( + key -> sdfTimerInternals, key -> sdfStateInternals, inputCoder, windowCoder); + } + + private StateRequestHandler getStateRequestHandler( + ExecutableStage executableStage, + ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor, + RuntimeContext runtimeContext) { + final StateRequestHandler sideInputHandler; + StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory = + BatchSideInputHandlerFactory.forStage( + executableStage, runtimeContext::getBroadcastVariable); + try { + sideInputHandler = + StateRequestHandlers.forSideInputHandlerFactory( + ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory); + } catch (IOException e) { + throw new RuntimeException("Failed to setup state handler", e); + } + + final StateRequestHandler userStateHandler; + if (executableStage.getUserStates().size() > 0) { + bagUserStateHandlerFactory = new InMemoryBagUserStateFactory<>(); + userStateHandler = + StateRequestHandlers.forBagUserStateHandlerFactory( + processBundleDescriptor, bagUserStateHandlerFactory); + } else { + userStateHandler = StateRequestHandler.unsupported(); + } + + EnumMap handlerMap = + new EnumMap<>(StateKey.TypeCase.class); + handlerMap.put(StateKey.TypeCase.ITERABLE_SIDE_INPUT, sideInputHandler); + handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler); + handlerMap.put(StateKey.TypeCase.MULTIMAP_KEYS_SIDE_INPUT, 
sideInputHandler); + handlerMap.put(StateKey.TypeCase.BAG_USER_STATE, userStateHandler); + + return StateRequestHandlers.delegateBasedUponType(handlerMap); + } + + /** For non-stateful processing via a simple MapPartitionFunction. */ + @Override + public void mapPartition( + Iterable> iterable, Collector collector) + throws Exception { + + ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap); + if (sdfStateInternals != null) { + sdfTimerInternals.advanceProcessingTime(Instant.now()); + sdfTimerInternals.advanceSynchronizedProcessingTime(Instant.now()); + } + try (RemoteBundle bundle = + stageBundleFactory.getBundle( + receiverFactory, + stateRequestHandler, + progressHandler, + finalizationHandler, + bundleCheckpointHandler)) { + processElements(iterable, bundle); + } + if (sdfTimerInternals != null) { + // Finally, advance the processing time to infinity to fire any timers. + sdfTimerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + sdfTimerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + + // Now we fire the SDF timers and process elements generated by timers. + while (sdfTimerInternals.hasPendingTimers()) { + try (RemoteBundle bundle = + stageBundleFactory.getBundle( + receiverFactory, + stateRequestHandler, + progressHandler, + finalizationHandler, + bundleCheckpointHandler)) { + List> residuals = new ArrayList<>(); + TimerInternals.TimerData timer; + while ((timer = sdfTimerInternals.removeNextProcessingTimer()) != null) { + WindowedValue stateValue = + sdfStateInternals + .state(timer.getNamespace(), StateTags.value(timer.getTimerId(), inputCoder)) + .read(); + + residuals.add(stateValue); + } + processElements(residuals, bundle); + } + } + } + } + + /** For stateful and timer processing via a GroupReduceFunction. */ + @Override + public void reduce(Iterable> iterable, Collector collector) + throws Exception { + + // Need to discard the old key's state + if (bagUserStateHandlerFactory != null) { + bagUserStateHandlerFactory.resetForNewKey(); + } + + // Used with Batch, we know that all the data is available for this key. We can't use the + // timer manager from the context because it doesn't exist. So we create one and advance + // time to the end after processing all elements. + final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals(); + timerInternals.advanceProcessingTime(Instant.now()); + timerInternals.advanceSynchronizedProcessingTime(Instant.now()); + + ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap); + + TimerReceiverFactory timerReceiverFactory = + new TimerReceiverFactory( + stageBundleFactory, + (Timer timer, TimerInternals.TimerData timerData) -> { + currentTimerKey = timer.getUserKey(); + if (timer.getClearBit()) { + timerInternals.deleteTimer(timerData); + } else { + timerInternals.setTimer(timerData); + } + }, + windowCoder); + + // First process all elements and make sure no more elements can arrive + try (RemoteBundle bundle = + stageBundleFactory.getBundle( + receiverFactory, timerReceiverFactory, stateRequestHandler, progressHandler)) { + processElements(iterable, bundle); + } + + // Finish any pending windows by advancing the input watermark to infinity. + timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE); + // Finally, advance the processing time to infinity to fire any timers. 
+ timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + + // Now we fire the timers and process elements generated by timers (which may be timers itself) + while (timerInternals.hasPendingTimers()) { + try (RemoteBundle bundle = + stageBundleFactory.getBundle( + receiverFactory, timerReceiverFactory, stateRequestHandler, progressHandler)) { + PipelineTranslatorUtils.fireEligibleTimers( + timerInternals, bundle.getTimerReceivers(), currentTimerKey); + } + } + } + + private void processElements(Iterable> iterable, RemoteBundle bundle) + throws Exception { + Preconditions.checkArgument(bundle != null, "RemoteBundle must not be null"); + + FnDataReceiver> mainReceiver = + Iterables.getOnlyElement(bundle.getInputReceivers().values()); + for (WindowedValue input : iterable) { + mainReceiver.accept(input); + } + } + + @Override + public void close() throws Exception { + metricContainer.registerMetricsForPipelineResult(); + // close may be called multiple times when an exception is thrown + if (stageContext != null) { + try (AutoCloseable bundleFactoryCloser = stageBundleFactory; + AutoCloseable closable = stageContext) { + } catch (Exception e) { + LOG.error("Error in close: ", e); + throw e; + } + } + stageContext = null; + } + + /** + * Receiver factory that wraps outgoing elements with the corresponding union tag for a + * multiplexed PCollection and optionally handles timer items. + */ + private static class ReceiverFactory implements OutputReceiverFactory { + + private final Object collectorLock = new Object(); + + @GuardedBy("collectorLock") + private final Collector collector; + + private final Map outputMap; + + ReceiverFactory(Collector collector, Map outputMap) { + this.collector = collector; + this.outputMap = outputMap; + } + + @Override + public FnDataReceiver create(String collectionId) { + Integer unionTag = outputMap.get(collectionId); + if (unionTag != null) { + int tagInt = unionTag; + return receivedElement -> { + synchronized (collectorLock) { + collector.collect(new RawUnionValue(tagInt, receivedElement)); + } + }; + } else { + throw new IllegalStateException( + String.format(Locale.ENGLISH, "Unknown PCollectionId %s", collectionId)); + } + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java new file mode 100644 index 000000000000..9079d347772f --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
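The reduce path above relies on a recurring batch pattern: once all input for a key has been consumed, the in-memory watermarks are advanced to the end of time and pending timers are drained until quiescence (the same pattern reappears in FlinkStatefulDoFnFunction below). A minimal sketch of that loop, with a hypothetical TimerCallback standing in for the RemoteBundle/DoFnRunner plumbing:

  import org.apache.beam.runners.core.InMemoryTimerInternals;
  import org.apache.beam.runners.core.TimerInternals;
  import org.apache.beam.sdk.transforms.windowing.BoundedWindow;

  public class DrainTimersExample {
    interface TimerCallback {
      void onTimer(TimerInternals.TimerData timer) throws Exception;
    }

    static void fireAllTimers(InMemoryTimerInternals timerInternals, TimerCallback callback)
        throws Exception {
      // No more elements can arrive, so move event and processing time to the end of time.
      timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
      timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
      timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);

      TimerInternals.TimerData timer;
      // Firing a timer may set further timers, so keep draining until nothing is pending.
      while (timerInternals.hasPendingTimers()) {
        while ((timer = timerInternals.removeNextEventTimer()) != null) {
          callback.onTimer(timer);
        }
        while ((timer = timerInternals.removeNextProcessingTimer()) != null) {
          callback.onTimer(timer);
        }
        while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) {
          callback.onTimer(timer);
        }
      }
    }
  }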
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.util.Collector; + +/** A Flink function that demultiplexes output from a {@link FlinkExecutableStageFunction}. */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkExecutableStagePruningFunction + extends RichFlatMapFunction> { + + private final int unionTag; + private final SerializablePipelineOptions options; + + /** + * Creates a {@link FlinkExecutableStagePruningFunction} that extracts elements of the given union + * tag. + */ + public FlinkExecutableStagePruningFunction(int unionTag, PipelineOptions pipelineOptions) { + this.unionTag = unionTag; + this.options = new SerializablePipelineOptions(pipelineOptions); + } + + @Override + public void open(OpenContext parameters) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(options.get()); + } + + @Override + public void flatMap(RawUnionValue rawUnionValue, Collector> collector) { + if (rawUnionValue.getUnionTag() == unionTag) { + collector.collect((WindowedValue) rawUnionValue.getValue()); + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java new file mode 100644 index 000000000000..15080c053d46 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
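The ReceiverFactory in FlinkExecutableStageFunction multiplexes every stage output into RawUnionValues keyed by a per-collection union tag, and the pruning function above keeps exactly one tag and unwraps the payload. A small sketch of that round trip, with hypothetical tag values:

  import java.util.ArrayList;
  import java.util.List;
  import org.apache.beam.sdk.transforms.join.RawUnionValue;

  public class UnionTagRoundTripExample {
    // The stage side: wrap an element with the union tag assigned to its output PCollection.
    static RawUnionValue tag(int unionTag, Object element) {
      return new RawUnionValue(unionTag, element);
    }

    // The pruning side: keep only elements carrying the wanted tag and unwrap them.
    static List<Object> pruneTo(int wantedTag, List<RawUnionValue> multiplexed) {
      List<Object> kept = new ArrayList<>();
      for (RawUnionValue value : multiplexed) {
        if (value.getUnionTag() == wantedTag) {
          kept.add(value.getValue());
        }
      }
      return kept;
    }
  }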
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.Map; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.util.Collector; + +/** + * Special version of {@link FlinkReduceFunction} that supports merging windows. + * + *
<p>
This is different from the pair of function for the non-merging windows case in that we cannot + * do combining before the shuffle because elements would not yet be in their correct windows for + * side-input access. + */ +public class FlinkMergingNonShuffleReduceFunction< + K, InputT, AccumT, OutputT, W extends BoundedWindow> + extends RichGroupReduceFunction>, WindowedValue>> { + + private final CombineFnBase.GlobalCombineFn combineFn; + + private final WindowingStrategy windowingStrategy; + + private final Map, WindowingStrategy> sideInputs; + + private final SerializablePipelineOptions serializedOptions; + + public FlinkMergingNonShuffleReduceFunction( + CombineFnBase.GlobalCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + + this.combineFn = combineFn; + + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + + this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); + } + + @Override + public void open(OpenContext parameters) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(serializedOptions.get()); + } + + @Override + public void reduce( + Iterable>> elements, Collector>> out) + throws Exception { + + PipelineOptions options = serializedOptions.get(); + + FlinkSideInputReader sideInputReader = + new FlinkSideInputReader(sideInputs, getRuntimeContext()); + + AbstractFlinkCombineRunner reduceRunner; + if (windowingStrategy.getWindowFn() instanceof Sessions) { + reduceRunner = new SortingFlinkCombineRunner<>(); + } else { + reduceRunner = new HashingFlinkCombineRunner<>(); + } + + reduceRunner.combine( + new AbstractFlinkCombineRunner.CompleteFlinkCombiner<>(combineFn), + windowingStrategy, + sideInputReader, + options, + elements, + out); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java new file mode 100644 index 000000000000..379dcce6b1e7 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
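A user-level example of the case this function exists for: Sessions is a merging WindowFn, so the combine cannot be lifted before the shuffle and the entire reduction happens here. The input PCollection is assumed to exist:

  import org.apache.beam.sdk.transforms.Combine;
  import org.apache.beam.sdk.transforms.Sum;
  import org.apache.beam.sdk.transforms.windowing.Sessions;
  import org.apache.beam.sdk.transforms.windowing.Window;
  import org.apache.beam.sdk.values.KV;
  import org.apache.beam.sdk.values.PCollection;
  import org.joda.time.Duration;

  public class SessionCombineExample {
    static PCollection<KV<String, Long>> sumPerSession(PCollection<KV<String, Long>> input) {
      return input
          // Sessions merge as data arrives, which disables the pre-shuffle combine lift.
          .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(10))))
          .apply(Combine.perKey(Sum.ofLongs()));
    }
  }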
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.util.Collector; + +/** + * A {@link FlatMapFunction} function that filters out those elements that don't belong in this + * output. We need this to implement MultiOutput ParDo functions in combination with {@link + * FlinkDoFnFunction}. + */ +public class FlinkMultiOutputPruningFunction + extends RichFlatMapFunction, WindowedValue> { + + private final int ourOutputTag; + private final SerializablePipelineOptions options; + + public FlinkMultiOutputPruningFunction(int ourOutputTag, PipelineOptions options) { + this.ourOutputTag = ourOutputTag; + this.options = new SerializablePipelineOptions(options); + } + + @Override + public void open(OpenContext parameters) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(options.get()); + } + + @Override + @SuppressWarnings("unchecked") + public void flatMap( + WindowedValue windowedValue, Collector> collector) + throws Exception { + int unionTag = windowedValue.getValue().getUnionTag(); + if (unionTag == ourOutputTag) { + collector.collect( + (WindowedValue) windowedValue.withValue(windowedValue.getValue().getValue())); + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java new file mode 100644 index 000000000000..f277cef058f9 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
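Seen from the SDK side, the tagging in FlinkDoFnFunction plus this pruning step implement a multi-output ParDo. An illustrative example with hypothetical tag names:

  import org.apache.beam.sdk.transforms.DoFn;
  import org.apache.beam.sdk.transforms.ParDo;
  import org.apache.beam.sdk.values.PCollection;
  import org.apache.beam.sdk.values.PCollectionTuple;
  import org.apache.beam.sdk.values.TupleTag;
  import org.apache.beam.sdk.values.TupleTagList;

  public class MultiOutputParDoExample {
    static final TupleTag<String> SHORT_WORDS = new TupleTag<String>() {};
    static final TupleTag<String> LONG_WORDS = new TupleTag<String>() {};

    static PCollectionTuple split(PCollection<String> words) {
      return words.apply(
          ParDo.of(
                  new DoFn<String, String>() {
                    @ProcessElement
                    public void process(@Element String word, MultiOutputReceiver out) {
                      // Each element is routed to exactly one of the tagged outputs; the runner
                      // encodes the choice as a union tag and prunes per output.
                      if (word.length() <= 5) {
                        out.get(SHORT_WORDS).output(word);
                      } else {
                        out.get(LONG_WORDS).output(word);
                      }
                    }
                  })
              .withOutputTags(SHORT_WORDS, TupleTagList.of(LONG_WORDS)));
    }
  }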
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.Map; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichGroupCombineFunction; +import org.apache.flink.util.Collector; + +/** + * This is the first step for executing a {@link org.apache.beam.sdk.transforms.Combine.PerKey} on + * Flink. The second part is {@link FlinkReduceFunction}. This function performs a local combine + * step before shuffling while the latter does the final combination after a shuffle. + * + *
<p>
The input to {@link #combine(Iterable, Collector)} are elements of the same key but for + * different windows. We have to ensure that we only combine elements of matching windows. + */ +public class FlinkPartialReduceFunction + extends RichGroupCombineFunction>, WindowedValue>> { + + protected final CombineFnBase.GlobalCombineFn combineFn; + + protected final WindowingStrategy windowingStrategy; + + protected final SerializablePipelineOptions serializedOptions; + + // TODO: Remove side input functionality since liftable Combines no longer have side inputs. + protected final Map, WindowingStrategy> sideInputs; + + /** WindowedValues has been exploded and pre-grouped by window. */ + private final boolean groupedByWindow; + + public FlinkPartialReduceFunction( + CombineFnBase.GlobalCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + this(combineFn, windowingStrategy, sideInputs, pipelineOptions, false); + } + + public FlinkPartialReduceFunction( + CombineFnBase.GlobalCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions, + boolean groupedByWindow) { + this.combineFn = combineFn; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); + this.groupedByWindow = groupedByWindow; + } + + @Override + public void open(OpenContext parameters) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(serializedOptions.get()); + } + + @Override + public void combine( + Iterable>> elements, Collector>> out) + throws Exception { + + PipelineOptions options = serializedOptions.get(); + + FlinkSideInputReader sideInputReader = + new FlinkSideInputReader(sideInputs, getRuntimeContext()); + + AbstractFlinkCombineRunner reduceRunner; + + if (groupedByWindow) { + reduceRunner = new SingleWindowFlinkCombineRunner<>(); + } else { + if (windowingStrategy.needsMerge() && windowingStrategy.getWindowFn() instanceof Sessions) { + reduceRunner = new SortingFlinkCombineRunner<>(); + } else { + reduceRunner = new HashingFlinkCombineRunner<>(); + } + } + + reduceRunner.combine( + new AbstractFlinkCombineRunner.PartialFlinkCombiner<>(combineFn), + windowingStrategy, + sideInputReader, + options, + elements, + out); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java new file mode 100644 index 000000000000..72e99bb4151f --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
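A sketch of how a CombineFn's lifecycle is split across the partial and final reduce functions: createAccumulator/addInput run before the shuffle, mergeAccumulators/extractOutput after it. The averaging CombineFn below is illustrative only:

  import java.util.Arrays;
  import org.apache.beam.sdk.transforms.Combine;

  public class TwoPhaseCombineExample {
    static class MeanFn extends Combine.CombineFn<Long, long[], Double> {
      @Override
      public long[] createAccumulator() {
        return new long[] {0, 0}; // {sum, count}
      }

      @Override
      public long[] addInput(long[] accum, Long input) {
        accum[0] += input;
        accum[1] += 1;
        return accum;
      }

      @Override
      public long[] mergeAccumulators(Iterable<long[]> accums) {
        long[] merged = createAccumulator();
        for (long[] accum : accums) {
          merged[0] += accum[0];
          merged[1] += accum[1];
        }
        return merged;
      }

      @Override
      public Double extractOutput(long[] accum) {
        return accum[1] == 0 ? 0.0 : ((double) accum[0]) / accum[1];
      }
    }

    public static void main(String[] args) {
      MeanFn fn = new MeanFn();
      // Partial (pre-shuffle) combine, as done by FlinkPartialReduceFunction on each worker:
      long[] partialA = fn.addInput(fn.addInput(fn.createAccumulator(), 2L), 4L);
      long[] partialB = fn.addInput(fn.createAccumulator(), 6L);
      // Final (post-shuffle) combine of the partial accumulators, as done by FlinkReduceFunction:
      System.out.println(fn.extractOutput(fn.mergeAccumulators(Arrays.asList(partialA, partialB))));
    }
  }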
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.Map; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.util.Collector; + +/** + * This is the second part for executing a {@link org.apache.beam.sdk.transforms.Combine.PerKey} on + * Flink, the second part is {@link FlinkReduceFunction}. This function performs the final + * combination of the pre-combined values after a shuffle. + * + *
<p>
The input to {@link #reduce(Iterable, Collector)} are elements of the same key but for + * different windows. We have to ensure that we only combine elements of matching windows. + */ +public class FlinkReduceFunction + extends RichGroupReduceFunction>, WindowedValue>> { + + protected final CombineFnBase.GlobalCombineFn combineFn; + + protected final WindowingStrategy windowingStrategy; + + // TODO: Remove side input functionality since liftable Combines no longer have side inputs. + protected final Map, WindowingStrategy> sideInputs; + + protected final SerializablePipelineOptions serializedOptions; + + /** WindowedValues has been exploded and pre-grouped by window. */ + private final boolean groupedByWindow; + + public FlinkReduceFunction( + CombineFnBase.GlobalCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + this(combineFn, windowingStrategy, sideInputs, pipelineOptions, false); + } + + public FlinkReduceFunction( + CombineFnBase.GlobalCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions, + boolean groupedByWindow) { + this.combineFn = combineFn; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); + this.groupedByWindow = groupedByWindow; + } + + @Override + public void open(OpenContext parameters) { + // Initialize FileSystems for any coders which may want to use the FileSystem, + // see https://issues.apache.org/jira/browse/BEAM-8303 + FileSystems.setDefaultPipelineOptions(serializedOptions.get()); + } + + @Override + public void reduce( + Iterable>> elements, Collector>> out) + throws Exception { + + PipelineOptions options = serializedOptions.get(); + + FlinkSideInputReader sideInputReader = + new FlinkSideInputReader(sideInputs, getRuntimeContext()); + + AbstractFlinkCombineRunner reduceRunner; + + if (groupedByWindow) { + reduceRunner = new SingleWindowFlinkCombineRunner<>(); + } else { + if (windowingStrategy.needsMerge() && windowingStrategy.getWindowFn() instanceof Sessions) { + reduceRunner = new SortingFlinkCombineRunner<>(); + } else { + reduceRunner = new HashingFlinkCombineRunner<>(); + } + } + + reduceRunner.combine( + new AbstractFlinkCombineRunner.FinalFlinkCombiner<>(combineFn), + windowingStrategy, + sideInputReader, + options, + elements, + out); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java new file mode 100644 index 000000000000..2a208d30a87e --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import static org.apache.flink.util.Preconditions.checkArgument; + +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.InMemoryTimerInternals; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaces; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.flink.translation.utils.Workarounds; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValueMultiReceiver; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.util.Collector; +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** A {@link RichGroupReduceFunction} for stateful {@link ParDo} in Flink Batch Runner. 
*/ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkStatefulDoFnFunction + extends RichGroupReduceFunction>, WindowedValue> { + + private final DoFn, OutputT> dofn; + private final boolean usesOnWindowExpiration; + private String stepName; + private final WindowingStrategy windowingStrategy; + private final Map, WindowingStrategy> sideInputs; + private final SerializablePipelineOptions serializedOptions; + private final Map, Integer> outputMap; + private final TupleTag mainOutputTag; + private final Coder> inputCoder; + private final Map, Coder> outputCoderMap; + private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; + + private transient DoFnInvoker doFnInvoker; + private transient FlinkMetricContainer metricContainer; + + public FlinkStatefulDoFnFunction( + DoFn, OutputT> dofn, + String stepName, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions, + Map, Integer> outputMap, + TupleTag mainOutputTag, + Coder> inputCoder, + Map, Coder> outputCoderMap, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { + + this.dofn = dofn; + this.usesOnWindowExpiration = + DoFnSignatures.signatureForDoFn(dofn).onWindowExpiration() != null; + this.stepName = stepName; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); + this.outputMap = outputMap; + this.mainOutputTag = mainOutputTag; + this.inputCoder = inputCoder; + this.outputCoderMap = outputCoderMap; + this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; + } + + @Override + public void reduce( + Iterable>> values, Collector> out) + throws Exception { + RuntimeContext runtimeContext = getRuntimeContext(); + + WindowedValueMultiReceiver outputManager; + if (outputMap.size() == 1) { + outputManager = new FlinkDoFnFunction.DoFnOutputManager(out); + } else { + // it has some additional Outputs + outputManager = new FlinkDoFnFunction.MultiDoFnOutputManagerWindowed(out, outputMap); + } + + final Iterator>> iterator = values.iterator(); + + // get the first value, we need this for initializing the state internals with the key. + // we are guaranteed to have a first value, otherwise reduce() would not have been called. + WindowedValue> currentValue = iterator.next(); + final K key = currentValue.getValue().getKey(); + + final InMemoryStateInternals stateInternals = InMemoryStateInternals.forKey(key); + + // Used with Batch, we know that all the data is available for this key. We can't use the + // timer manager from the context because it doesn't exist. So we create one and advance + // time to the end after processing all elements. 
+ final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals(); + timerInternals.advanceProcessingTime(Instant.now()); + timerInternals.advanceSynchronizedProcessingTime(Instant.now()); + + final Set windowsSeen = new HashSet<>(); + + List> additionalOutputTags = Lists.newArrayList(outputMap.keySet()); + + DoFnRunner, OutputT> doFnRunner = + DoFnRunners.simpleRunner( + serializedOptions.get(), + dofn, + new FlinkSideInputReader(sideInputs, runtimeContext), + outputManager, + mainOutputTag, + additionalOutputTags, + new FlinkNoOpStepContext() { + @Override + public StateInternals stateInternals() { + return stateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + }, + inputCoder, + outputCoderMap, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping); + + FlinkPipelineOptions pipelineOptions = serializedOptions.get().as(FlinkPipelineOptions.class); + if (!pipelineOptions.getDisableMetrics()) { + doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, metricContainer); + } + + doFnRunner.startBundle(); + + doFnRunner.processElement(currentValue); + if (usesOnWindowExpiration) { + windowsSeen.addAll(currentValue.getWindows()); + } + while (iterator.hasNext()) { + currentValue = iterator.next(); + if (usesOnWindowExpiration) { + windowsSeen.addAll(currentValue.getWindows()); + } + doFnRunner.processElement(currentValue); + } + + // Finish any pending windows by advancing the input watermark to infinity. + timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE); + + // Finally, advance the processing time to infinity to fire any timers. + timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + + fireEligibleTimers(key, timerInternals, doFnRunner); + + if (usesOnWindowExpiration) { + for (BoundedWindow window : windowsSeen) { + doFnRunner.onWindowExpiration(window, window.maxTimestamp().minus(Duration.millis(1)), key); + } + } + + doFnRunner.finishBundle(); + } + + private void fireEligibleTimers( + final K key, InMemoryTimerInternals timerInternals, DoFnRunner, OutputT> runner) + throws Exception { + + while (true) { + + TimerInternals.TimerData timer; + boolean hasFired = false; + + while ((timer = timerInternals.removeNextEventTimer()) != null) { + hasFired = true; + fireTimer(key, timer, runner); + } + while ((timer = timerInternals.removeNextProcessingTimer()) != null) { + hasFired = true; + fireTimer(key, timer, runner); + } + while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) { + hasFired = true; + fireTimer(key, timer, runner); + } + if (!hasFired) { + break; + } + } + } + + private void fireTimer( + final K key, TimerInternals.TimerData timer, DoFnRunner, OutputT> doFnRunner) { + StateNamespace namespace = timer.getNamespace(); + checkArgument(namespace instanceof StateNamespaces.WindowNamespace); + BoundedWindow window = ((StateNamespaces.WindowNamespace) namespace).getWindow(); + doFnRunner.onTimer( + timer.getTimerId(), + timer.getTimerFamilyId(), + key, + window, + timer.getTimestamp(), + timer.getOutputTimestamp(), + timer.getDomain()); + } + + @Override + public void open(OpenContext parameters) { + // Note that the SerializablePipelineOptions already initialize FileSystems in the readObject() + // deserialization method. However, this is a hack, and we want to properly initialize the + // options where they are needed. 
+ PipelineOptions options = serializedOptions.get(); + FileSystems.setDefaultPipelineOptions(options); + metricContainer = new FlinkMetricContainer(getRuntimeContext()); + doFnInvoker = DoFnInvokers.tryInvokeSetupFor(dofn, options); + } + + @Override + public void close() throws Exception { + try { + metricContainer.registerMetricsForPipelineResult(); + Optional.ofNullable(doFnInvoker).ifPresent(DoFnInvoker::invokeTeardown); + } finally { + Workarounds.deleteStaticCaches(); + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunction.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunction.java new file mode 100644 index 000000000000..1c8edf8b0c59 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunction.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction; +import org.apache.flink.streaming.api.watermark.Watermark; + +/** + * Source function which sends a single global impulse to a downstream operator. It may keep the + * source alive although its work is already done. It will only shutdown when the streaming job is + * cancelled. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class ImpulseSourceFunction + implements SourceFunction>, CheckpointedFunction { + + /** The idle time before the source shuts down. */ + private final long idleTimeoutMs; + + /** Indicates the streaming job is running and the source can produce elements. */ + private volatile boolean running; + + /** Checkpointed state which indicates whether the impulse has finished. 
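+   * On restore, {@code run()} checks this state and only re-emits the final watermark instead of
+   * a second impulse element.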
*/ + private transient ListState impulseEmitted; + + public ImpulseSourceFunction(long idleTimeoutMs) { + this.idleTimeoutMs = idleTimeoutMs; + this.running = true; + } + + @Override + public void run(SourceContext> sourceContext) throws Exception { + if (Iterables.isEmpty(impulseEmitted.get())) { + synchronized (sourceContext.getCheckpointLock()) { + // emit single impulse element + sourceContext.collect(WindowedValues.valueInGlobalWindow(new byte[0])); + impulseEmitted.add(true); + } + } + // Always emit a final watermark. + // (1) In case we didn't restore the pipeline, this is important to close the global window; + // if no operator holds back this watermark. + // (2) In case we are restoring the pipeline, this is needed to initialize the operators with + // the current watermark and trigger execution of any pending timers. + sourceContext.emitWatermark(Watermark.MAX_WATERMARK); + // Wait to allow checkpoints of the pipeline + waitToEnsureCheckpointingWorksCorrectly(); + } + + private void waitToEnsureCheckpointingWorksCorrectly() { + // Do nothing, but still look busy ... + // we can't return here since Flink requires that all operators stay up, + // otherwise checkpointing would not work correctly anymore + // + // See https://issues.apache.org/jira/browse/FLINK-2491 for progress on this issue + long idleStart = System.currentTimeMillis(); + // wait until this is canceled + final Object waitLock = new Object(); + while (running && (System.currentTimeMillis() - idleStart < idleTimeoutMs)) { + try { + // Flink will interrupt us at some point + //noinspection SynchronizationOnLocalVariableOrMethodParameter + synchronized (waitLock) { + // don't wait indefinitely, in case something goes horribly wrong + waitLock.wait(1000); + } + } catch (InterruptedException e) { + if (!running) { + // restore the interrupted state, and fall through the loop + Thread.currentThread().interrupt(); + } + } + } + } + + @Override + public void cancel() { + this.running = false; + } + + @Override + public void snapshotState(FunctionSnapshotContext context) {} + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + impulseEmitted = + context + .getOperatorStateStore() + .getListState(new ListStateDescriptor<>("impulse-emitted", BooleanSerializer.INSTANCE)); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java new file mode 100644 index 000000000000..12e74a64faa7 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.serialization.SerializerConfig; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Flink {@link org.apache.flink.api.common.typeinfo.TypeInformation} for Beam {@link + * org.apache.beam.sdk.coders.Coder}s. + */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +public class CoderTypeInformation extends TypeInformation implements AtomicType { + + private final Coder coder; + private final SerializablePipelineOptions pipelineOptions; + + public CoderTypeInformation(Coder coder, PipelineOptions pipelineOptions) { + this(coder, new SerializablePipelineOptions(pipelineOptions)); + } + + public CoderTypeInformation(Coder coder, SerializablePipelineOptions pipelineOptions) { + checkNotNull(coder); + checkNotNull(pipelineOptions); + this.coder = coder; + this.pipelineOptions = pipelineOptions; + } + + public Coder getCoder() { + return coder; + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + @SuppressWarnings("unchecked") + public Class getTypeClass() { + return (Class) coder.getEncodedTypeDescriptor().getRawType(); + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + public TypeSerializer createSerializer(SerializerConfig config) { + return new CoderTypeSerializer<>(coder, pipelineOptions); + } + + @Override + public int getTotalFields() { + return 2; + } + + /** + * Creates a new {@link CoderTypeInformation} with {@link PipelineOptions}, that can be used for + * {@link org.apache.beam.sdk.io.FileSystems} registration. + * + * @see Jira issue. + * @param pipelineOptions Options of current pipeline. + * @return New type information. 
+ */ + public CoderTypeInformation withPipelineOptions(PipelineOptions pipelineOptions) { + return new CoderTypeInformation<>(getCoder(), pipelineOptions); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + CoderTypeInformation that = (CoderTypeInformation) o; + + return coder.equals(that.coder); + } + + @Override + public int hashCode() { + return coder.hashCode(); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CoderTypeInformation; + } + + @Override + public String toString() { + return "CoderTypeInformation{coder=" + coder + '}'; + } + + @Override + public TypeComparator createComparator( + boolean sortOrderAscending, ExecutionConfig executionConfig) { + throw new UnsupportedOperationException("Non-encoded values cannot be compared directly."); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueSerializer.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueSerializer.java new file mode 100644 index 000000000000..1703a7dca0e9 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueSerializer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import java.io.IOException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +/** {@link TypeSerializer} for values that were encoded using a {@link Coder}. 
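+ * The wire format is a 4-byte length prefix followed by the raw encoded bytes; {@link #getLength()}
+ * returns -1 because records are variable-length.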
*/ +public final class EncodedValueSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + private static final byte[] EMPTY = new byte[0]; + + @Override + public boolean isImmutableType() { + return true; + } + + @Override + public byte[] createInstance() { + return EMPTY; + } + + @Override + public byte[] copy(byte[] from) { + return from; + } + + @Override + public byte[] copy(byte[] from, byte[] reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(byte[] record, DataOutputView target) throws IOException { + if (record == null) { + throw new IllegalArgumentException("The record must not be null."); + } + + final int len = record.length; + target.writeInt(len); + target.write(record); + } + + @Override + public byte[] deserialize(DataInputView source) throws IOException { + final int len = source.readInt(); + byte[] result = new byte[len]; + source.readFully(result); + return result; + } + + @Override + public byte[] deserialize(byte[] reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + final int len = source.readInt(); + target.writeInt(len); + target.write(source, len); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new TypeSerializerSnapshot() { + @Override + public int getCurrentVersion() { + return 2; + } + + @Override + public void writeSnapshot(DataOutputView out) throws IOException {} + + @Override + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) + throws IOException {} + + @Override + public TypeSerializer restoreSerializer() { + return new EncodedValueSerializer(); + } + + @Override + public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( + TypeSerializerSnapshot oldSerializerSnapshot) { + // For maintainer: handle future incompatible change here + if (oldSerializerSnapshot.restoreSerializer() instanceof EncodedValueSerializer) { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } else { + return TypeSerializerSchemaCompatibility.compatibleAfterMigration(); + } + } + }; + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueTypeInformation.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueTypeInformation.java new file mode 100644 index 000000000000..075ef0ef453e --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/EncodedValueTypeInformation.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.serialization.SerializerConfig; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Flink {@link TypeInformation} for Beam values that have been encoded to byte data by a {@link + * Coder}. + */ +public class EncodedValueTypeInformation extends TypeInformation + implements AtomicType { + + private static final long serialVersionUID = 1L; + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 0; + } + + @Override + public int getTotalFields() { + return 0; + } + + @Override + public Class getTypeClass() { + return byte[].class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + public TypeSerializer createSerializer(SerializerConfig executionConfig) { + return new EncodedValueSerializer(); + } + + @Override + public boolean equals(@Nullable Object other) { + return other instanceof EncodedValueTypeInformation; + } + + @Override + public int hashCode() { + return this.getClass().hashCode(); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof EncodedValueTypeInformation; + } + + @Override + public String toString() { + return "EncodedValueTypeInformation"; + } + + @Override + public TypeComparator createComparator( + boolean sortOrderAscending, ExecutionConfig executionConfig) { + return new EncodedValueComparator(sortOrderAscending); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/UnversionedTypeSerializerSnapshot.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/UnversionedTypeSerializerSnapshot.java new file mode 100644 index 000000000000..4f94fb631554 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/types/UnversionedTypeSerializerSnapshot.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.types; + +import java.io.IOException; +import javax.annotation.Nullable; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.core.io.VersionedIOReadableWritable; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.TemporaryClassLoaderContext; + +/** A legacy snapshot which does not care about schema compatibility. */ +@SuppressWarnings("allcheckers") +public class UnversionedTypeSerializerSnapshot implements TypeSerializerSnapshot { + + private @Nullable CoderTypeSerializer serializer; + + /** Needs to be public to work with {@link VersionedIOReadableWritable}. */ + public UnversionedTypeSerializerSnapshot() { + this(null); + } + + @SuppressWarnings("initialization") + public UnversionedTypeSerializerSnapshot(CoderTypeSerializer serializer) { + this.serializer = serializer; + } + + @Override + public int getCurrentVersion() { + return 1; + } + + @Override + public void writeSnapshot(DataOutputView dataOutputView) throws IOException { + byte[] bytes = SerializableUtils.serializeToByteArray(serializer); + dataOutputView.writeInt(bytes.length); + dataOutputView.write(bytes); + } + + @SuppressWarnings("unchecked") + @Override + public void readSnapshot(int version, DataInputView dataInputView, ClassLoader classLoader) + throws IOException { + + try (TemporaryClassLoaderContext context = TemporaryClassLoaderContext.of(classLoader)) { + int length = dataInputView.readInt(); + byte[] bytes = new byte[length]; + dataInputView.readFully(bytes); + this.serializer = + (CoderTypeSerializer) + SerializableUtils.deserializeFromByteArray( + bytes, CoderTypeSerializer.class.getName()); + } + } + + @Override + public TypeSerializer restoreSerializer() { + return serializer; + } + + @Override + public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( + TypeSerializerSnapshot oldSerializerSnapshot) { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java new file mode 100644 index 000000000000..f5ce658de4fd --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -0,0 +1,1785 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import static org.apache.flink.util.Preconditions.checkArgument; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.locks.Lock; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.InMemoryBundleFinalizer; +import org.apache.beam.runners.core.NullSideInputReader; +import org.apache.beam.runners.core.ProcessFnRunner; +import org.apache.beam.runners.core.PushbackSideInputDoFnRunner; +import org.apache.beam.runners.core.SideInputHandler; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.core.SimplePushbackSideInputDoFnRunner; +import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaces.WindowNamespace; +import org.apache.beam.runners.core.StatefulDoFnRunner; +import org.apache.beam.runners.core.StepContext; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.TimerInternals.TimerData; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; +import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import org.apache.beam.runners.flink.translation.utils.CheckpointStats; +import org.apache.beam.runners.flink.translation.utils.Workarounds; +import org.apache.beam.runners.flink.translation.wrappers.streaming.stableinput.BufferingDoFnRunner; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StructuredCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import 
org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.NoopLock; +import org.apache.beam.sdk.util.WindowedValueMultiReceiver; +import org.apache.beam.sdk.util.WindowedValueReceiver; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.runtime.state.InternalPriorityQueue; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.CheckpointingMode; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManagerImpl; +import org.apache.flink.streaming.api.operators.InternalTimer; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.InternalTimerServiceImpl; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionInternalTimeService; +import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionInternalTimeServiceManager; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.function.BiConsumerWithException; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Flink operator for executing {@link DoFn DoFns}. 
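+ *
+ * <p>The operator implements both the one-input and the two-input stream operator interfaces so
+ * that side inputs can arrive on a second input as {@link RawUnionValue} records.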
+ * + * @param the input type of the {@link DoFn} + * @param the output type of the {@link DoFn} + */ +// We use Flink's lifecycle methods to initialize transient fields +@SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED") +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "keyfor", + "nullness" +}) // TODO(https://github.com/apache/beam/issues/20497) +public class DoFnOperator + extends AbstractStreamOperator> + implements OneInputStreamOperator, WindowedValue>, + TwoInputStreamOperator, RawUnionValue, WindowedValue>, + Triggerable { + + private static final Logger LOG = LoggerFactory.getLogger(DoFnOperator.class); + private final boolean isStreaming; + + protected DoFn doFn; + + protected final SerializablePipelineOptions serializedOptions; + + protected final TupleTag mainOutputTag; + protected final List> additionalOutputTags; + + protected final Collection> sideInputs; + protected final Map> sideInputTagMapping; + + protected final WindowingStrategy windowingStrategy; + + protected final OutputManagerFactory outputManagerFactory; + + protected transient DoFnRunner doFnRunner; + protected transient PushbackSideInputDoFnRunner pushbackDoFnRunner; + protected transient BufferingDoFnRunner bufferingDoFnRunner; + + protected transient SideInputHandler sideInputHandler; + + protected transient SideInputReader sideInputReader; + + protected transient BufferedOutputManager outputManager; + + private transient DoFnInvoker doFnInvoker; + + protected transient FlinkStateInternals keyedStateInternals; + protected transient FlinkTimerInternals timerInternals; + + protected final String stepName; + + final Coder> windowedInputCoder; + + final Map, Coder> outputCoders; + + final Coder keyCoder; + + final KeySelector, ?> keySelector; + + final TimerInternals.TimerDataCoderV2 timerCoder; + + /** Max number of elements to include in a bundle. */ + private final long maxBundleSize; + /** Max duration of a bundle. */ + private final long maxBundleTimeMills; + + private final DoFnSchemaInformation doFnSchemaInformation; + + private final Map> sideInputMapping; + + /** If true, we must process elements only after a checkpoint is finished. */ + final boolean requiresStableInput; + + /** + * If both requiresStableInput and this parameter are true, we must flush the buffer during drain + * operation. + */ + final boolean enableStableInputDrain; + + final int numConcurrentCheckpoints; + + private final boolean usesOnWindowExpiration; + + private final boolean finishBundleBeforeCheckpointing; + + /** Stores new finalizations being gathered. */ + private transient InMemoryBundleFinalizer bundleFinalizer; + /** Pending bundle finalizations which have not been acknowledged yet. */ + private transient LinkedHashMap> + pendingFinalizations; + /** + * Keep a maximum of 32 bundle finalizations for {@link + * BundleFinalizer.Callback#onBundleSuccess()}. + */ + private static final int MAX_NUMBER_PENDING_BUNDLE_FINALIZATIONS = 32; + + protected transient InternalTimerService timerService; + // Flink 1.20 moved timeServiceManager to protected scope. No longer need delegate + // private transient InternalTimeServiceManager timeServiceManager; + + private transient PushedBackElementsHandler> pushedBackElementsHandler; + + /** Metrics container for reporting Beam metrics to Flink (null if metrics are disabled). */ + transient @Nullable FlinkMetricContainer flinkMetricContainer; + + /** Helper class to report the checkpoint duration. 
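+   * Only created in {@code open()} when metrics are enabled and the {@code
+   * reportCheckpointDuration} pipeline option is set; otherwise it stays null.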
*/ + private transient @Nullable CheckpointStats checkpointStats; + + /** A timer that finishes the current bundle after a fixed amount of time. */ + private transient ScheduledFuture checkFinishBundleTimer; + + /** + * This and the below fields need to be volatile because we use multiple threads to access these. + * (a) the main processing thread (b) a timer thread to finish bundles by a timeout instead of the + * number of element However, we do not need a lock because Flink makes sure to acquire the + * "checkpointing" lock for the main processing but also for timer set via its {@code + * timerService}. + * + *

The volatile flag can be removed once https://issues.apache.org/jira/browse/FLINK-12481 has + * been addressed. + */ + private transient volatile boolean bundleStarted; + /** Number of processed elements in the current bundle. */ + private transient volatile long elementCount; + /** Time that the last bundle was finished (to set the timer). */ + private transient volatile long lastFinishBundleTime; + /** Callback to be executed before the current bundle is started. */ + private transient volatile Runnable preBundleCallback; + /** Callback to be executed after the current bundle was finished. */ + private transient volatile Runnable bundleFinishedCallback; + + // Watermark state. + // Volatile because these can be set in two mutually exclusive threads (see above). + private transient volatile long currentInputWatermark; + private transient volatile long currentSideInputWatermark; + private transient volatile long currentOutputWatermark; + private transient volatile long pushedBackWatermark; + + /** Constructor for DoFnOperator. */ + public DoFnOperator( + @Nullable DoFn doFn, + String stepName, + Coder> inputWindowedCoder, + Map, Coder> outputCoders, + TupleTag mainOutputTag, + List> additionalOutputTags, + OutputManagerFactory outputManagerFactory, + WindowingStrategy windowingStrategy, + Map> sideInputTagMapping, + Collection> sideInputs, + PipelineOptions options, + @Nullable Coder keyCoder, + @Nullable KeySelector, ?> keySelector, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { + this.doFn = doFn; + this.stepName = stepName; + this.windowedInputCoder = inputWindowedCoder; + this.outputCoders = outputCoders; + this.mainOutputTag = mainOutputTag; + this.additionalOutputTags = additionalOutputTags; + this.sideInputTagMapping = sideInputTagMapping; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializablePipelineOptions(options); + this.isStreaming = serializedOptions.get().as(FlinkPipelineOptions.class).isStreaming(); + this.windowingStrategy = windowingStrategy; + this.outputManagerFactory = outputManagerFactory; + + // API removed in Flink 2.0. setChainingStrategy is now set internally. + // setChainingStrategy(ChainingStrategy.ALWAYS); + + this.keyCoder = keyCoder; + this.keySelector = keySelector; + + this.timerCoder = + TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder()); + + FlinkPipelineOptions flinkOptions = options.as(FlinkPipelineOptions.class); + + this.maxBundleSize = flinkOptions.getMaxBundleSize(); + Preconditions.checkArgument(maxBundleSize > 0, "Bundle size must be at least 1"); + this.maxBundleTimeMills = flinkOptions.getMaxBundleTimeMills(); + Preconditions.checkArgument(maxBundleTimeMills > 0, "Bundle time must be at least 1"); + this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; + + this.requiresStableInput = isRequiresStableInput(doFn); + + this.usesOnWindowExpiration = + doFn != null && DoFnSignatures.getSignature(doFn.getClass()).onWindowExpiration() != null; + + if (requiresStableInput) { + Preconditions.checkState( + CheckpointingMode.valueOf(flinkOptions.getCheckpointingMode()) + == CheckpointingMode.EXACTLY_ONCE, + "Checkpointing mode is not set to exactly once but @RequiresStableInput is used."); + Preconditions.checkState( + flinkOptions.getCheckpointingInterval() > 0, + "No checkpointing configured but pipeline uses @RequiresStableInput"); + LOG.warn( + "Enabling stable input for transform {}. 
Will only process elements at most every {} milliseconds.", + stepName, + flinkOptions.getCheckpointingInterval() + + Math.max(0, flinkOptions.getMinPauseBetweenCheckpoints())); + } + + this.enableStableInputDrain = flinkOptions.getEnableStableInputDrain(); + + this.numConcurrentCheckpoints = flinkOptions.getNumConcurrentCheckpoints(); + + this.finishBundleBeforeCheckpointing = flinkOptions.getFinishBundleBeforeCheckpointing(); + } + + private boolean isRequiresStableInput(DoFn doFn) { + // WindowDoFnOperator does not use a DoFn + return doFn != null + && DoFnSignatures.getSignature(doFn.getClass()).processElement().requiresStableInput(); + } + + @VisibleForTesting + boolean getRequiresStableInput() { + return requiresStableInput; + } + + // allow overriding this in WindowDoFnOperator because this one dynamically creates + // the DoFn + protected DoFn getDoFn() { + return doFn; + } + + protected Iterable> preProcess(WindowedValue input) { + // Assume Input is PreInputT + return Collections.singletonList((WindowedValue) input); + } + + // allow overriding this, for example SplittableDoFnOperator will not create a + // stateful DoFn runner because ProcessFn, which is used for executing a Splittable DoFn + // doesn't play by the normal DoFn rules and WindowDoFnOperator uses LateDataDroppingDoFnRunner + protected DoFnRunner createWrappingDoFnRunner( + DoFnRunner wrappedRunner, StepContext stepContext) { + + if (keyCoder != null) { + StatefulDoFnRunner.CleanupTimer cleanupTimer = + new StatefulDoFnRunner.TimeInternalsCleanupTimer( + timerInternals, windowingStrategy) { + @Override + public void setForWindow(InputT input, BoundedWindow window) { + if (!window.equals(GlobalWindow.INSTANCE) || usesOnWindowExpiration) { + // Skip setting a cleanup timer for the global window as these timers + // lead to potentially unbounded state growth in the runner, depending on key + // cardinality. Cleanup for global window will be performed upon arrival of the + // final watermark. + // In the case of OnWindowExpiration, we still set the timer. 
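+              // (The timer is what eventually triggers the @OnWindowExpiration callback for the
+              // global window before its state is cleaned up.)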
+ super.setForWindow(input, window); + } + } + }; + + // we don't know the window type + // @SuppressWarnings({"unchecked", "rawtypes"}) + Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); + + @SuppressWarnings({"unchecked"}) + StatefulDoFnRunner.StateCleaner stateCleaner = + new StatefulDoFnRunner.StateInternalsStateCleaner<>( + doFn, keyedStateInternals, windowCoder); + + return DoFnRunners.defaultStatefulDoFnRunner( + doFn, + getInputCoder(), + wrappedRunner, + stepContext, + windowingStrategy, + cleanupTimer, + stateCleaner, + true /* requiresTimeSortedInput is supported */); + + } else { + return doFnRunner; + } + } + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output>> output) { + + // make sure that FileSystems is initialized correctly + FileSystems.setDefaultPipelineOptions(serializedOptions.get()); + + super.setup(containingTask, config, output); + } + + protected boolean shoudBundleElements() { + return isStreaming; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + ListStateDescriptor> pushedBackStateDescriptor = + new ListStateDescriptor<>( + "pushed-back-elements", + new CoderTypeSerializer<>(windowedInputCoder, serializedOptions)); + + if (keySelector != null) { + pushedBackElementsHandler = + KeyedPushedBackElementsHandler.create( + keySelector, getKeyedStateBackend(), pushedBackStateDescriptor); + } else { + ListState> listState = + getOperatorStateBackend().getListState(pushedBackStateDescriptor); + pushedBackElementsHandler = NonKeyedPushedBackElementsHandler.create(listState); + } + + currentInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis(); + currentSideInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis(); + currentOutputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis(); + + sideInputReader = NullSideInputReader.of(sideInputs); + + if (!sideInputs.isEmpty()) { + + FlinkBroadcastStateInternals sideInputStateInternals = + new FlinkBroadcastStateInternals<>( + getContainingTask().getIndexInSubtaskGroup(), + getOperatorStateBackend(), + serializedOptions); + + sideInputHandler = new SideInputHandler(sideInputs, sideInputStateInternals); + sideInputReader = sideInputHandler; + + Stream> pushedBack = pushedBackElementsHandler.getElements(); + long min = + pushedBack.map(v -> v.getTimestamp().getMillis()).reduce(Long.MAX_VALUE, Math::min); + pushedBackWatermark = min; + } else { + pushedBackWatermark = Long.MAX_VALUE; + } + + // StatefulPardo or WindowDoFn + if (keyCoder != null) { + keyedStateInternals = + new FlinkStateInternals<>( + (KeyedStateBackend) getKeyedStateBackend(), + keyCoder, + windowingStrategy.getWindowFn().windowCoder(), + serializedOptions); + + if (timerService == null) { + timerService = + getInternalTimerService( + "beam-timer", new CoderTypeSerializer<>(timerCoder, serializedOptions), this); + } + + timerInternals = new FlinkTimerInternals(timerService); + Preconditions.checkNotNull(getTimeServiceManager(), "Time service manager is not set."); + } + + outputManager = + outputManagerFactory.create( + output, getLockToAcquireForStateAccessDuringBundles(), getOperatorStateBackend()); + } + + /** + * Subclasses may provide a lock to ensure that the state backend is not accessed concurrently + * during bundle execution. 
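+   *
+   * <p>The default implementation returns a no-op lock.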
+ */ + protected Lock getLockToAcquireForStateAccessDuringBundles() { + return NoopLock.get(); + } + + @Override + public void open() throws Exception { + // WindowDoFnOperator need use state and timer to get DoFn. + // So must wait StateInternals and TimerInternals ready. + // This will be called after initializeState() + this.doFn = getDoFn(); + + FlinkPipelineOptions options = serializedOptions.get().as(FlinkPipelineOptions.class); + doFnInvoker = DoFnInvokers.tryInvokeSetupFor(doFn, options); + + StepContext stepContext = new FlinkStepContext(); + doFnRunner = + DoFnRunners.simpleRunner( + options, + doFn, + sideInputReader, + outputManager, + mainOutputTag, + additionalOutputTags, + stepContext, + getInputCoder(), + outputCoders, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping); + + doFnRunner = + createBufferingDoFnRunnerIfNeeded(createWrappingDoFnRunner(doFnRunner, stepContext)); + earlyBindStateIfNeeded(); + + if (!options.getDisableMetrics()) { + flinkMetricContainer = new FlinkMetricContainer(getRuntimeContext()); + doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, flinkMetricContainer); + String checkpointMetricNamespace = options.getReportCheckpointDuration(); + if (checkpointMetricNamespace != null) { + MetricName checkpointMetric = + MetricName.named(checkpointMetricNamespace, "checkpoint_duration"); + checkpointStats = + new CheckpointStats( + () -> + flinkMetricContainer + .getMetricsContainer(stepName) + .getDistribution(checkpointMetric)); + } + } + + elementCount = 0L; + lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime(); + + // Schedule timer to check timeout of finish bundle. + long bundleCheckPeriod = Math.max(maxBundleTimeMills / 2, 1); + checkFinishBundleTimer = + getProcessingTimeService() + .scheduleAtFixedRate( + timestamp -> checkInvokeFinishBundleByTime(), bundleCheckPeriod, bundleCheckPeriod); + + if (doFn instanceof SplittableParDoViaKeyedWorkItems.ProcessFn) { + pushbackDoFnRunner = + new ProcessFnRunner<>((DoFnRunner) doFnRunner, sideInputs, sideInputHandler); + } else { + pushbackDoFnRunner = + SimplePushbackSideInputDoFnRunner.create(doFnRunner, sideInputs, sideInputHandler); + } + + bundleFinalizer = new InMemoryBundleFinalizer(); + pendingFinalizations = new LinkedHashMap<>(); + } + + DoFnRunner createBufferingDoFnRunnerIfNeeded( + DoFnRunner wrappedRunner) throws Exception { + + if (requiresStableInput) { + // put this in front of the root FnRunner before any additional wrappers + return this.bufferingDoFnRunner = + BufferingDoFnRunner.create( + wrappedRunner, + "stable-input-buffer", + windowedInputCoder, + windowingStrategy.getWindowFn().windowCoder(), + getOperatorStateBackend(), + getBufferingKeyedStateBackend(), + numConcurrentCheckpoints, + serializedOptions); + } + return wrappedRunner; + } + + /** + * Retrieve a keyed state backend that should be used to buffer elements for {@link @{code @} + * RequiresStableInput} functionality. By default this is the default keyed backend, but can be + * override in @{link ExecutableStageDoFnOperator}. 
+ * + * @return the keyed backend to use for element buffering + */ + @Nullable KeyedStateBackend getBufferingKeyedStateBackend() { + return getKeyedStateBackend(); + } + + private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAccessException { + if (keyCoder != null) { + if (doFn != null) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + FlinkStateInternals.EarlyBinder earlyBinder = + new FlinkStateInternals.EarlyBinder( + getKeyedStateBackend(), + serializedOptions, + windowingStrategy.getWindowFn().windowCoder()); + for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) { + StateSpec spec = + (StateSpec) signature.stateDeclarations().get(value.id()).field().get(doFn); + spec.bind(value.id(), earlyBinder); + } + if (doFnRunner instanceof StatefulDoFnRunner) { + ((StatefulDoFnRunner) doFnRunner) + .getSystemStateTags() + .forEach(tag -> tag.getSpec().bind(tag.getId(), earlyBinder)); + } + } + } + } + + void cleanUp() throws Exception { + Optional.ofNullable(flinkMetricContainer) + .ifPresent(FlinkMetricContainer::registerMetricsForPipelineResult); + Optional.ofNullable(checkFinishBundleTimer).ifPresent(timer -> timer.cancel(true)); + Workarounds.deleteStaticCaches(); + Optional.ofNullable(doFnInvoker).ifPresent(DoFnInvoker::invokeTeardown); + } + + void flushData() throws Exception { + // This is our last change to block shutdown of this operator while + // there are still remaining processing-time timers. Flink will ignore pending + // processing-time timers when upstream operators have shut down and will also + // shut down this operator with pending processing-time timers. + if (numProcessingTimeTimers() > 0) { + timerInternals.processPendingProcessingTimeTimers(); + } + if (numProcessingTimeTimers() > 0) { + throw new RuntimeException( + "There are still " + + numProcessingTimeTimers() + + " processing-time timers left, this indicates a bug"); + } + // make sure we send a +Inf watermark downstream. It can happen that we receive +Inf + // in processWatermark*() but have holds, so we have to re-evaluate here. + processWatermark(new Watermark(Long.MAX_VALUE)); + // Make sure to finish the current bundle + while (bundleStarted) { + invokeFinishBundle(); + } + if (requiresStableInput && enableStableInputDrain) { + // Flush any buffered events here before draining the pipeline. Note that this is best-effort + // and requiresStableInput contract might be violated in cases where buffer processing fails. + bufferingDoFnRunner.checkpointCompleted(Long.MAX_VALUE); + updateOutputWatermark(); + } + if (currentOutputWatermark < Long.MAX_VALUE) { + throw new RuntimeException( + String.format( + "There are still watermark holds left when terminating operator %s Watermark held %d", + getOperatorName(), currentOutputWatermark)); + } + + // sanity check: these should have been flushed out by +Inf watermarks + if (!sideInputs.isEmpty()) { + + List> pushedBackElements = + pushedBackElementsHandler.getElements().collect(Collectors.toList()); + + if (pushedBackElements.size() > 0) { + String pushedBackString = Joiner.on(",").join(pushedBackElements); + throw new RuntimeException( + "Leftover pushed-back data: " + pushedBackString + ". 
This indicates a bug."); + } + } + } + + @Override + public void finish() throws Exception { + try { + flushData(); + } finally { + super.finish(); + } + } + + @Override + public void close() throws Exception { + try { + cleanUp(); + } finally { + super.close(); + } + } + + protected int numProcessingTimeTimers() { + return getTimeServiceManager() + .map( + manager -> { + if (timeServiceManager instanceof InternalTimeServiceManagerImpl) { + final InternalTimeServiceManagerImpl cast = + (InternalTimeServiceManagerImpl) timeServiceManager; + return cast.numProcessingTimeTimers(); + } else if (timeServiceManager instanceof BatchExecutionInternalTimeServiceManager) { + return 0; + } else { + throw new IllegalStateException( + String.format( + "Unknown implementation of InternalTimerServiceManager. %s", + timeServiceManager)); + } + }) + .orElse(0); + } + + public long getEffectiveInputWatermark() { + // hold back by the pushed back values waiting for side inputs + long combinedPushedBackWatermark = pushedBackWatermark; + if (requiresStableInput) { + combinedPushedBackWatermark = + Math.min(combinedPushedBackWatermark, bufferingDoFnRunner.getOutputWatermarkHold()); + } + return Math.min(combinedPushedBackWatermark, currentInputWatermark); + } + + public long getCurrentOutputWatermark() { + return currentOutputWatermark; + } + + protected final void setPreBundleCallback(Runnable callback) { + this.preBundleCallback = callback; + } + + protected final void setBundleFinishedCallback(Runnable callback) { + this.bundleFinishedCallback = callback; + } + + @Override + public final void processElement(StreamRecord> streamRecord) { + for (WindowedValue e : preProcess(streamRecord.getValue())) { + checkInvokeStartBundle(); + LOG.trace("Processing element {} in {}", streamRecord.getValue().getValue(), doFn.getClass()); + long oldHold = keyCoder != null ? keyedStateInternals.minWatermarkHoldMs() : -1L; + doFnRunner.processElement(e); + checkInvokeFinishBundleByCount(); + emitWatermarkIfHoldChanged(oldHold); + } + } + + @Override + public final void processElement1(StreamRecord> streamRecord) + throws Exception { + for (WindowedValue e : preProcess(streamRecord.getValue())) { + checkInvokeStartBundle(); + Iterable> justPushedBack = + pushbackDoFnRunner.processElementInReadyWindows(e); + + long min = pushedBackWatermark; + for (WindowedValue pushedBackValue : justPushedBack) { + min = Math.min(min, pushedBackValue.getTimestamp().getMillis()); + pushedBackElementsHandler.pushBack(pushedBackValue); + } + pushedBackWatermark = min; + + checkInvokeFinishBundleByCount(); + } + } + + /** + * Add the side input value. Here we are assuming that views have already been materialized and + * are sent over the wire as {@link Iterable}. Subclasses may elect to perform materialization in + * state and receive side input incrementally instead. + * + * @param streamRecord + */ + protected void addSideInputValue(StreamRecord streamRecord) { + @SuppressWarnings("unchecked") + WindowedValue> value = + (WindowedValue>) streamRecord.getValue().getValue(); + + PCollectionView sideInput = sideInputTagMapping.get(streamRecord.getValue().getUnionTag()); + sideInputHandler.addSideInputValue(sideInput, value); + } + + @Override + public final void processElement2(StreamRecord streamRecord) throws Exception { + // we finish the bundle because the newly arrived side-input might + // make a view available that was previously not ready. 
+ // The PushbackSideInputRunner will only reset its cache of non-ready windows when + // finishing a bundle. + invokeFinishBundle(); + checkInvokeStartBundle(); + + // add the side input, which may cause pushed back elements become eligible for processing + addSideInputValue(streamRecord); + + List> newPushedBack = new ArrayList<>(); + + Iterator> it = pushedBackElementsHandler.getElements().iterator(); + + while (it.hasNext()) { + WindowedValue element = it.next(); + // we need to set the correct key in case the operator is + // a (keyed) window operator + if (keySelector != null) { + setCurrentKey(keySelector.getKey(element)); + } + + Iterable> justPushedBack = + pushbackDoFnRunner.processElementInReadyWindows(element); + Iterables.addAll(newPushedBack, justPushedBack); + } + + pushedBackElementsHandler.clear(); + long min = Long.MAX_VALUE; + for (WindowedValue pushedBackValue : newPushedBack) { + min = Math.min(min, pushedBackValue.getTimestamp().getMillis()); + pushedBackElementsHandler.pushBack(pushedBackValue); + } + pushedBackWatermark = min; + + checkInvokeFinishBundleByCount(); + + // maybe output a new watermark + processWatermark1(new Watermark(currentInputWatermark)); + } + + @Override + public final void processWatermark(Watermark mark) throws Exception { + LOG.trace("Processing watermark {} in {}", mark.getTimestamp(), doFn.getClass()); + processWatermark1(mark); + } + + @Override + public final void processWatermark1(Watermark mark) throws Exception { + // Flush any data buffered during snapshotState(). + outputManager.flushBuffer(); + + // We do the check here because we are guaranteed to at least get the +Inf watermark on the + // main input when the job finishes. + if (currentSideInputWatermark >= BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) { + // this means we will never see any more side input + // we also do the check here because we might have received the side-input MAX watermark + // before receiving any main-input data + emitAllPushedBackData(); + } + + currentInputWatermark = mark.getTimestamp(); + processInputWatermark(true); + } + + private void processInputWatermark(boolean advanceInputWatermark) throws Exception { + long inputWatermarkHold = applyInputWatermarkHold(getEffectiveInputWatermark()); + if (keyCoder != null && advanceInputWatermark) { + timeServiceManager.advanceWatermark(new Watermark(inputWatermarkHold)); + } + + long potentialOutputWatermark = + applyOutputWatermarkHold( + currentOutputWatermark, computeOutputWatermark(inputWatermarkHold)); + + maybeEmitWatermark(potentialOutputWatermark); + } + + /** + * Allows to apply a hold to the input watermark. By default, just passes the input watermark + * through. + */ + public long applyInputWatermarkHold(long inputWatermark) { + return inputWatermark; + } + + /** + * Allows to apply a hold to the output watermark before it is sent out. Used to apply hold on + * output watermark for delayed (asynchronous or buffered) processing. + * + * @param currentOutputWatermark the current output watermark + * @param potentialOutputWatermark The potential new output watermark which can be adjusted, if + * needed. The input watermark hold has already been applied. + * @return The new output watermark which will be emitted. 
+ */ + public long applyOutputWatermarkHold(long currentOutputWatermark, long potentialOutputWatermark) { + return potentialOutputWatermark; + } + + private long computeOutputWatermark(long inputWatermarkHold) { + final long potentialOutputWatermark; + if (keyCoder == null) { + potentialOutputWatermark = inputWatermarkHold; + } else { + potentialOutputWatermark = + Math.min(keyedStateInternals.minWatermarkHoldMs(), inputWatermarkHold); + } + return potentialOutputWatermark; + } + + private void maybeEmitWatermark(long watermark) { + if (watermark > currentOutputWatermark) { + // Must invoke finishBatch before emit the +Inf watermark otherwise there are some late + // events. + if (watermark >= BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) { + invokeFinishBundle(); + } + + if (bundleStarted) { + // do not update watermark in the middle of bundle, because it might cause + // user-buffered data to be emitted past watermark + return; + } + + LOG.debug("Emitting watermark {} from {}", watermark, getOperatorName()); + currentOutputWatermark = watermark; + output.emitWatermark(new Watermark(watermark)); + + // Check if the final watermark was triggered to perform state cleanup for global window + // TODO: Do we need to do this when OnWindowExpiration is set, since in that case we have a + // cleanup timer? + if (keyedStateInternals != null + && currentOutputWatermark + > adjustTimestampForFlink(GlobalWindow.INSTANCE.maxTimestamp().getMillis())) { + keyedStateInternals.clearGlobalState(); + } + } + } + + @Override + public final void processWatermark2(Watermark mark) throws Exception { + currentSideInputWatermark = mark.getTimestamp(); + if (mark.getTimestamp() >= BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) { + // this means we will never see any more side input + emitAllPushedBackData(); + + // maybe output a new watermark + processWatermark1(new Watermark(currentInputWatermark)); + } + } + + /** + * Emits all pushed-back data. This should be used once we know that there will not be any future + * side input, i.e. that there is no point in waiting. + */ + private void emitAllPushedBackData() throws Exception { + + Iterator> it = pushedBackElementsHandler.getElements().iterator(); + + while (it.hasNext()) { + checkInvokeStartBundle(); + WindowedValue element = it.next(); + // we need to set the correct key in case the operator is + // a (keyed) window operator + setKeyContextElement1(new StreamRecord<>(element)); + + doFnRunner.processElement(element); + } + + pushedBackElementsHandler.clear(); + pushedBackWatermark = Long.MAX_VALUE; + } + + /** + * Check whether invoke startBundle, if it is, need to output elements that were buffered as part + * of finishing a bundle in snapshot() first. + * + *
<p>
In order to avoid having {@link DoFnRunner#processElement(WindowedValue)} or {@link + * DoFnRunner#onTimer(String, String, Object, BoundedWindow, Instant, Instant, TimeDomain)} not + * between StartBundle and FinishBundle, this method needs to be called in each processElement and + * each processWatermark and onProcessingTime. Do not need to call in onEventTime, because it has + * been guaranteed in the processWatermark. + */ + private void checkInvokeStartBundle() { + if (!bundleStarted) { + // Flush any data buffered during snapshotState(). + outputManager.flushBuffer(); + LOG.debug("Starting bundle."); + if (preBundleCallback != null) { + preBundleCallback.run(); + } + pushbackDoFnRunner.startBundle(); + bundleStarted = true; + } + } + + /** Check whether invoke finishBundle by elements count. Called in processElement. */ + @SuppressWarnings("NonAtomicVolatileUpdate") + @SuppressFBWarnings("VO_VOLATILE_INCREMENT") + private void checkInvokeFinishBundleByCount() { + if (!shoudBundleElements()) { + return; + } + // We do not access this statement concurrently, but we want to make sure that each thread + // sees the latest value, which is why we use volatile. See the class field section above + // for more information. + //noinspection NonAtomicOperationOnVolatileField + elementCount++; + if (elementCount >= maxBundleSize) { + invokeFinishBundle(); + updateOutputWatermark(); + } + } + + /** Check whether invoke finishBundle by timeout. */ + private void checkInvokeFinishBundleByTime() { + if (!shoudBundleElements()) { + return; + } + long now = getProcessingTimeService().getCurrentProcessingTime(); + if (now - lastFinishBundleTime >= maxBundleTimeMills) { + invokeFinishBundle(); + scheduleForCurrentProcessingTime(ts -> updateOutputWatermark()); + } + } + + @SuppressWarnings("FutureReturnValueIgnored") + protected void scheduleForCurrentProcessingTime(ProcessingTimeCallback callback) { + // We are scheduling a timer for advancing the watermark, to not delay finishing the bundle + // and temporarily release the checkpoint lock. Otherwise, we could potentially loop when a + // timer keeps scheduling a timer for the same timestamp. + ProcessingTimeService timeService = getProcessingTimeService(); + timeService.registerTimer(timeService.getCurrentProcessingTime(), callback); + } + + void updateOutputWatermark() { + try { + processInputWatermark(false); + } catch (Exception ex) { + failBundleFinalization(ex); + } + } + + protected final void invokeFinishBundle() { + long previousBundleFinishTime = lastFinishBundleTime; + if (bundleStarted) { + LOG.debug("Finishing bundle."); + pushbackDoFnRunner.finishBundle(); + LOG.debug("Finished bundle. Element count: {}", elementCount); + elementCount = 0L; + lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime(); + bundleStarted = false; + // callback only after current bundle was fully finalized + // it could start a new bundle, for example resulting from timer processing + if (bundleFinishedCallback != null) { + LOG.debug("Invoking bundle finish callback."); + bundleFinishedCallback.run(); + } + } + try { + if (previousBundleFinishTime - getProcessingTimeService().getCurrentProcessingTime() + > maxBundleTimeMills) { + processInputWatermark(false); + } + } catch (Exception ex) { + LOG.warn("Failed to update downstream watermark", ex); + } + } + + @Override + public void prepareSnapshotPreBarrier(long checkpointId) { + if (finishBundleBeforeCheckpointing) { + // We finish the bundle and flush any pending data. 
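+ // (The loop below is needed because finishing a bundle may itself start a new one, e.g. via timer callbacks.)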
+ // This avoids buffering any data as part of snapshotState() below. + while (bundleStarted) { + invokeFinishBundle(); + } + updateOutputWatermark(); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + if (checkpointStats != null) { + checkpointStats.snapshotStart(context.getCheckpointId()); + } + + if (requiresStableInput) { + // We notify the BufferingDoFnRunner to associate buffered state with this + // snapshot id and start a new buffer for elements arriving after this snapshot. + bufferingDoFnRunner.checkpoint(context.getCheckpointId()); + } + + int diff = pendingFinalizations.size() - MAX_NUMBER_PENDING_BUNDLE_FINALIZATIONS; + if (diff >= 0) { + for (Iterator iterator = pendingFinalizations.keySet().iterator(); diff >= 0; diff--) { + iterator.next(); + iterator.remove(); + } + } + pendingFinalizations.put(context.getCheckpointId(), bundleFinalizer.getAndClearFinalizations()); + + try { + outputManager.openBuffer(); + // Ensure that no new bundle gets started as part of finishing a bundle + while (bundleStarted) { + invokeFinishBundle(); + } + outputManager.closeBuffer(); + } catch (Exception e) { + failBundleFinalization(e); + } + + super.snapshotState(context); + } + + private void failBundleFinalization(Exception e) { + // https://jira.apache.org/jira/browse/FLINK-14653 + // Any regular exception during checkpointing will be tolerated by Flink because those + // typically do not affect the execution flow. We need to fail hard here because errors + // in bundle execution are application errors which are not related to checkpointing. + throw new Error("Checkpointing failed because bundle failed to finalize.", e); + } + + public BundleFinalizer getBundleFinalizer() { + return bundleFinalizer; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + if (checkpointStats != null) { + checkpointStats.reportCheckpointDuration(checkpointId); + } + + if (requiresStableInput) { + // We can now release all buffered data which was held back for + // @RequiresStableInput guarantees. + bufferingDoFnRunner.checkpointCompleted(checkpointId); + updateOutputWatermark(); + } + + List finalizations = + pendingFinalizations.remove(checkpointId); + if (finalizations != null) { + // confirm all finalizations that were associated with the checkpoint + for (InMemoryBundleFinalizer.Finalization finalization : finalizations) { + finalization.getCallback().onBundleSuccess(); + } + } + + super.notifyCheckpointComplete(checkpointId); + } + + @Override + public void onEventTime(InternalTimer timer) { + checkInvokeStartBundle(); + fireTimerInternal(timer.getKey(), timer.getNamespace()); + } + + @Override + public void onProcessingTime(InternalTimer timer) { + checkInvokeStartBundle(); + fireTimerInternal(timer.getKey(), timer.getNamespace()); + } + + // allow overriding this in ExecutableStageDoFnOperator to set the key context + protected void fireTimerInternal(FlinkKey key, TimerData timerData) { + long oldHold = keyCoder != null ? 
keyedStateInternals.minWatermarkHoldMs() : -1L; + fireTimer(timerData); + emitWatermarkIfHoldChanged(oldHold); + } + + void emitWatermarkIfHoldChanged(long currentWatermarkHold) { + if (keyCoder != null) { + long newWatermarkHold = keyedStateInternals.minWatermarkHoldMs(); + if (newWatermarkHold > currentWatermarkHold) { + try { + processInputWatermark(false); + } catch (Exception ex) { + // should not happen + throw new IllegalStateException(ex); + } + } + } + } + + // allow overriding this in WindowDoFnOperator + protected void fireTimer(TimerData timerData) { + LOG.debug( + "Firing timer: {} at {} with output time {}", + timerData.getTimerId(), + timerData.getTimestamp().getMillis(), + timerData.getOutputTimestamp().getMillis()); + StateNamespace namespace = timerData.getNamespace(); + // This is a user timer, so namespace must be WindowNamespace + checkArgument(namespace instanceof WindowNamespace); + BoundedWindow window = ((WindowNamespace) namespace).getWindow(); + timerInternals.onFiredOrDeletedTimer(timerData); + + pushbackDoFnRunner.onTimer( + timerData.getTimerId(), + timerData.getTimerFamilyId(), + keyedStateInternals.getKey(), + window, + timerData.getTimestamp(), + timerData.getOutputTimestamp(), + timerData.getDomain()); + } + + @SuppressWarnings("unchecked") + Coder getInputCoder() { + return (Coder) Iterables.getOnlyElement(windowedInputCoder.getCoderArguments()); + } + + /** Factory for creating an {@link BufferedOutputManager} from a Flink {@link Output}. */ + interface OutputManagerFactory extends Serializable { + BufferedOutputManager create( + Output>> output, + Lock bufferLock, + OperatorStateBackend operatorStateBackend) + throws Exception; + } + + /** + * A {@link WindowedValueReceiver} that can buffer its outputs. Uses {@link + * PushedBackElementsHandler} to buffer the data. Buffering data is necessary because no elements + * can be emitted during {@code snapshotState} which is called when the checkpoint barrier already + * has been sent downstream. Emitting elements would break the flow of checkpoint barrier and + * violate exactly-once semantics. + * + *
<p>
This buffering can be deactived using {@code + * FlinkPipelineOptions#setFinishBundleBeforeCheckpointing(true)}. If activated, we flush out + * bundle data before the barrier is sent downstream. This is done via {@code + * prepareSnapshotPreBarrier}. When Flink supports unaligned checkpoints, this should become the + * default and this class should be removed as in https://github.com/apache/beam/pull/9652. + */ + public static class BufferedOutputManager implements WindowedValueMultiReceiver { + + private final TupleTag mainTag; + private final Map, OutputTag>> tagsToOutputTags; + private final Map, Integer> tagsToIds; + /** + * A lock to be acquired before writing to the buffer. This lock will only be acquired during + * buffering. It will not be acquired during flushing the buffer. + */ + private final Lock bufferLock; + + private final boolean isStreaming; + + private Map> idsToTags; + /** Elements buffered during a snapshot, by output id. */ + @VisibleForTesting + final PushedBackElementsHandler>> pushedBackElementsHandler; + + protected final Output>> output; + + /** Indicates whether we are buffering data as part of snapshotState(). */ + private boolean openBuffer = false; + /** For performance, to avoid having to access the state backend when the buffer is empty. */ + private boolean bufferIsEmpty = false; + + BufferedOutputManager( + Output>> output, + TupleTag mainTag, + Map, OutputTag>> tagsToOutputTags, + Map, Integer> tagsToIds, + Lock bufferLock, + PushedBackElementsHandler>> pushedBackElementsHandler, + boolean isStreaming) { + this.output = output; + this.mainTag = mainTag; + this.tagsToOutputTags = tagsToOutputTags; + this.tagsToIds = tagsToIds; + this.bufferLock = bufferLock; + this.idsToTags = new HashMap<>(); + for (Map.Entry, Integer> entry : tagsToIds.entrySet()) { + idsToTags.put(entry.getValue(), entry.getKey()); + } + this.pushedBackElementsHandler = pushedBackElementsHandler; + this.isStreaming = isStreaming; + } + + void openBuffer() { + this.openBuffer = true; + } + + void closeBuffer() { + this.openBuffer = false; + } + + @Override + public void output(TupleTag tag, WindowedValue value) { + // Don't buffer elements in Batch mode + if (!openBuffer || !isStreaming) { + emit(tag, value); + } else { + buffer(KV.of(tagsToIds.get(tag), value)); + } + } + + private void buffer(KV> taggedValue) { + bufferLock.lock(); + try { + pushedBackElementsHandler.pushBack(taggedValue); + } catch (Exception e) { + throw new RuntimeException("Couldn't pushback element.", e); + } finally { + bufferLock.unlock(); + bufferIsEmpty = false; + } + } + + /** + * Flush elements of bufferState to Flink Output. This method should not be invoked in {@link + * #snapshotState(StateSnapshotContext)} because the checkpoint barrier has already been sent + * downstream; emitting elements at this point would violate the checkpoint barrier alignment. + * + *
<p>
The buffer should be flushed before starting a new bundle when the buffer cannot be + * concurrently accessed and thus does not need to be guarded by a lock. + */ + void flushBuffer() { + if (openBuffer || bufferIsEmpty) { + // Checkpoint currently in progress or nothing buffered, do not proceed + return; + } + try { + pushedBackElementsHandler + .getElements() + .forEach( + element -> + emit(idsToTags.get(element.getKey()), (WindowedValue) element.getValue())); + pushedBackElementsHandler.clear(); + bufferIsEmpty = true; + } catch (Exception e) { + throw new RuntimeException("Couldn't flush pushed back elements.", e); + } + } + + private void emit(TupleTag tag, WindowedValue value) { + if (tag.equals(mainTag)) { + // with tagged outputs we can't get around this because we don't + // know our own output type... + @SuppressWarnings("unchecked") + WindowedValue castValue = (WindowedValue) value; + output.collect(new StreamRecord<>(castValue)); + } else { + @SuppressWarnings("unchecked") + OutputTag> outputTag = (OutputTag) tagsToOutputTags.get(tag); + output.collect(outputTag, new StreamRecord<>(value)); + } + } + } + + /** Coder for KV of id and value. It will be serialized in Flink checkpoint. */ + private static class TaggedKvCoder extends StructuredCoder>> { + + private final Map>> idsToCoders; + + TaggedKvCoder(Map>> idsToCoders) { + this.idsToCoders = idsToCoders; + } + + @Override + public void encode(KV> kv, OutputStream out) throws IOException { + Coder> coder = idsToCoders.get(kv.getKey()); + VarIntCoder.of().encode(kv.getKey(), out); + coder.encode(kv.getValue(), out); + } + + @Override + public KV> decode(InputStream in) throws IOException { + Integer id = VarIntCoder.of().decode(in); + Coder> coder = idsToCoders.get(id); + WindowedValue value = coder.decode(in); + return KV.of(id, value); + } + + @Override + public List> getCoderArguments() { + return new ArrayList<>(idsToCoders.values()); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + for (Coder coder : idsToCoders.values()) { + verifyDeterministic(this, "Coder must be deterministic", coder); + } + } + } + + /** + * Implementation of {@link OutputManagerFactory} that creates an {@link BufferedOutputManager} + * that can write to multiple logical outputs by Flink side output. + */ + public static class MultiOutputOutputManagerFactory + implements OutputManagerFactory { + + private final TupleTag mainTag; + private final Map, Integer> tagsToIds; + private final Map, OutputTag>> tagsToOutputTags; + private final Map, Coder>> tagsToCoders; + private final SerializablePipelineOptions pipelineOptions; + private final boolean isStreaming; + + // There is no side output. 
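+ // This convenience constructor covers the single-output case; the main tag is mapped to output id 0.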
+ @SuppressWarnings("unchecked") + public MultiOutputOutputManagerFactory( + TupleTag mainTag, + Coder> mainCoder, + SerializablePipelineOptions pipelineOptions) { + this( + mainTag, + new HashMap<>(), + ImmutableMap., Coder>>builder() + .put(mainTag, (Coder) mainCoder) + .build(), + ImmutableMap., Integer>builder().put(mainTag, 0).build(), + pipelineOptions); + } + + public MultiOutputOutputManagerFactory( + TupleTag mainTag, + Map, OutputTag>> tagsToOutputTags, + Map, Coder>> tagsToCoders, + Map, Integer> tagsToIds, + SerializablePipelineOptions pipelineOptions) { + this.mainTag = mainTag; + this.tagsToOutputTags = tagsToOutputTags; + this.tagsToCoders = tagsToCoders; + this.tagsToIds = tagsToIds; + this.pipelineOptions = pipelineOptions; + this.isStreaming = pipelineOptions.get().as(FlinkPipelineOptions.class).isStreaming(); + } + + @Override + public BufferedOutputManager create( + Output>> output, + Lock bufferLock, + OperatorStateBackend operatorStateBackend) + throws Exception { + Preconditions.checkNotNull(output); + Preconditions.checkNotNull(bufferLock); + Preconditions.checkNotNull(operatorStateBackend); + + TaggedKvCoder taggedKvCoder = buildTaggedKvCoder(); + ListStateDescriptor>> taggedOutputPushbackStateDescriptor = + new ListStateDescriptor<>( + "bundle-buffer-tag", new CoderTypeSerializer<>(taggedKvCoder, pipelineOptions)); + ListState>> listStateBuffer = + operatorStateBackend.getListState(taggedOutputPushbackStateDescriptor); + PushedBackElementsHandler>> pushedBackElementsHandler = + NonKeyedPushedBackElementsHandler.create(listStateBuffer); + + return new BufferedOutputManager<>( + output, + mainTag, + tagsToOutputTags, + tagsToIds, + bufferLock, + pushedBackElementsHandler, + isStreaming); + } + + private TaggedKvCoder buildTaggedKvCoder() { + ImmutableMap.Builder>> idsToCodersBuilder = + ImmutableMap.builder(); + for (Map.Entry, Integer> entry : tagsToIds.entrySet()) { + idsToCodersBuilder.put(entry.getValue(), tagsToCoders.get(entry.getKey())); + } + return new TaggedKvCoder(idsToCodersBuilder.build()); + } + } + + /** + * {@link StepContext} for running {@link DoFn DoFns} on Flink. This does not allow accessing + * state or timer internals. + */ + protected class FlinkStepContext implements StepContext { + + @Override + public StateInternals stateInternals() { + return keyedStateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + + @Override + public BundleFinalizer bundleFinalizer() { + return bundleFinalizer; + } + } + + class FlinkTimerInternals implements TimerInternals { + + private static final String PENDING_TIMERS_STATE_NAME = "pending-timers"; + + /** + * Pending Timers (=not been fired yet) by context id. The id is generated from the state + * namespace of the timer and the timer's id. Necessary for supporting removal of existing + * timers. In Flink removal of timers can only be done by providing id and time of the timer. + * + *
<p>
CAUTION: This map is scoped by the current active key. Do not attempt to perform any + * calculations which span across keys. + */ + @VisibleForTesting final MapState pendingTimersById; + + private final InternalTimerService timerService; + + private FlinkTimerInternals(InternalTimerService timerService) throws Exception { + MapStateDescriptor pendingTimersByIdStateDescriptor = + new MapStateDescriptor<>( + PENDING_TIMERS_STATE_NAME, + new StringSerializer(), + new CoderTypeSerializer<>(timerCoder, serializedOptions)); + + this.pendingTimersById = getKeyedStateStore().getMapState(pendingTimersByIdStateDescriptor); + this.timerService = timerService; + populateOutputTimestampQueue(timerService); + } + + /** + * Processes all pending processing timers. This is intended for use during shutdown. From Flink + * 1.10 on, processing timer execution is stopped when the operator is closed. This leads to + * problems for applications which assume all pending timers will be completed. Although Flink + * does drain the remaining timers after close(), this is not sufficient because no new timers + * are allowed to be scheduled anymore. This breaks Beam pipelines which rely on all processing + * timers to be scheduled and executed. + */ + void processPendingProcessingTimeTimers() { + final KeyedStateBackend keyedStateBackend = getKeyedStateBackend(); + final InternalPriorityQueue> processingTimeTimersQueue = + Workarounds.retrieveInternalProcessingTimerQueue(timerService); + + InternalTimer internalTimer; + while ((internalTimer = processingTimeTimersQueue.poll()) != null) { + keyedStateBackend.setCurrentKey(internalTimer.getKey()); + TimerData timer = internalTimer.getNamespace(); + checkInvokeStartBundle(); + fireTimerInternal((FlinkKey) internalTimer.getKey(), timer); + } + } + + private void populateOutputTimestampQueue(InternalTimerService timerService) + throws Exception { + + BiConsumerWithException consumer = + (timerData, stamp) -> + keyedStateInternals.addWatermarkHoldUsage(timerData.getOutputTimestamp()); + if (timerService instanceof InternalTimerServiceImpl) { + timerService.forEachEventTimeTimer(consumer); + timerService.forEachProcessingTimeTimer(consumer); + } + } + + private String constructTimerId(String timerFamilyId, String timerId) { + return timerFamilyId + "+" + timerId; + } + + @Override + public void setTimer( + StateNamespace namespace, + String timerId, + String timerFamilyId, + Instant target, + Instant outputTimestamp, + TimeDomain timeDomain) { + setTimer( + TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain)); + } + + /** + * @deprecated use {@link #setTimer(StateNamespace, String, String, Instant, Instant, + * TimeDomain)}. + */ + @Deprecated + @Override + public void setTimer(TimerData timer) { + try { + LOG.debug( + "Setting timer: {} at {} with output time {}", + timer.getTimerId(), + timer.getTimestamp().getMillis(), + timer.getOutputTimestamp().getMillis()); + String contextTimerId = + getContextTimerId( + constructTimerId(timer.getTimerFamilyId(), timer.getTimerId()), + timer.getNamespace()); + @Nullable final TimerData oldTimer = pendingTimersById.get(contextTimerId); + if (!timer.equals(oldTimer)) { + // Only one timer can exist at a time for a given timer id and context. + // If a timer gets set twice in the same context, the second must + // override the first. Thus, we must cancel any pending timers + // before we set the new one. 
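+ // (cancelPendingTimer is a no-op when oldTimer is null, i.e. when no timer was previously registered for this id.)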
+ cancelPendingTimer(oldTimer); + registerTimer(timer, contextTimerId); + } + } catch (Exception e) { + throw new RuntimeException("Failed to set timer", e); + } + } + + private void registerTimer(TimerData timer, String contextTimerId) throws Exception { + LOG.debug("Registering timer {}", timer); + pendingTimersById.put(contextTimerId, timer); + long time = timer.getTimestamp().getMillis(); + switch (timer.getDomain()) { + case EVENT_TIME: + timerService.registerEventTimeTimer(timer, adjustTimestampForFlink(time)); + break; + case PROCESSING_TIME: + case SYNCHRONIZED_PROCESSING_TIME: + timerService.registerProcessingTimeTimer(timer, adjustTimestampForFlink(time)); + break; + default: + throw new UnsupportedOperationException("Unsupported time domain: " + timer.getDomain()); + } + keyedStateInternals.addWatermarkHoldUsage(timer.getOutputTimestamp()); + } + + /** + * Looks up a timer by its id. This is necessary to support canceling existing timers with the + * same id. Flink does not provide this functionality. + * + * @param contextTimerId Timer ID o cancel. + */ + private void cancelPendingTimerById(String contextTimerId) throws Exception { + cancelPendingTimer(pendingTimersById.get(contextTimerId)); + } + + /** + * Cancels a pending timer. + * + * @param timer Timer to cancel. + */ + private void cancelPendingTimer(@Nullable TimerData timer) { + if (timer != null) { + deleteTimerInternal(timer); + } + } + + /** + * Hook which must be called when a timer is fired or deleted to perform cleanup. Note: Make + * sure that the state backend key is set correctly. It is best to run this in the fireTimer() + * method. + */ + void onFiredOrDeletedTimer(TimerData timer) { + try { + pendingTimersById.remove( + getContextTimerId( + constructTimerId(timer.getTimerFamilyId(), timer.getTimerId()), + timer.getNamespace())); + keyedStateInternals.removeWatermarkHoldUsage(timer.getOutputTimestamp()); + } catch (Exception e) { + throw new RuntimeException("Failed to cleanup pending timers state.", e); + } + } + + /** @deprecated use {@link #deleteTimer(StateNamespace, String, TimeDomain)}. */ + @Deprecated + @Override + public void deleteTimer(StateNamespace namespace, String timerId, String timerFamilyId) { + throw new UnsupportedOperationException("Canceling of a timer by ID is not yet supported."); + } + + @Override + public void deleteTimer( + StateNamespace namespace, String timerId, String timerFamilyId, TimeDomain timeDomain) { + try { + cancelPendingTimerById(getContextTimerId(timerId, namespace)); + } catch (Exception e) { + throw new RuntimeException("Failed to cancel timer", e); + } + } + + /** @deprecated use {@link #deleteTimer(StateNamespace, String, TimeDomain)}. 
*/ + @Override + @Deprecated + public void deleteTimer(TimerData timer) { + deleteTimer( + timer.getNamespace(), + constructTimerId(timer.getTimerFamilyId(), timer.getTimerId()), + timer.getTimerFamilyId(), + timer.getDomain()); + } + + void deleteTimerInternal(TimerData timer) { + long time = timer.getTimestamp().getMillis(); + switch (timer.getDomain()) { + case EVENT_TIME: + timerService.deleteEventTimeTimer(timer, adjustTimestampForFlink(time)); + break; + case PROCESSING_TIME: + case SYNCHRONIZED_PROCESSING_TIME: + timerService.deleteProcessingTimeTimer(timer, adjustTimestampForFlink(time)); + break; + default: + throw new UnsupportedOperationException("Unsupported time domain: " + timer.getDomain()); + } + onFiredOrDeletedTimer(timer); + } + + @Override + public Instant currentProcessingTime() { + return new Instant(timerService.currentProcessingTime()); + } + + @Override + public @Nullable Instant currentSynchronizedProcessingTime() { + return new Instant(timerService.currentProcessingTime()); + } + + @Override + public Instant currentInputWatermarkTime() { + if (timerService instanceof BatchExecutionInternalTimeService) { + // In batch mode, this method will only either return BoundedWindow.TIMESTAMP_MIN_VALUE, + // or BoundedWindow.TIMESTAMP_MAX_VALUE. + // + // For batch execution mode, the currentInputWatermark variable will never be updated + // until all the records are processed. However, every time a record with a new + // key arrives, the Flink timer service watermark will be set to + // MAX_WATERMARK(LONG.MAX_VALUE) so that all the timers associated with the current + // key can fire. After that the Flink timer service watermark will be reset to + // LONG.MIN_VALUE, so the next key will start from a fresh env as if the previous + // records of a different key never existed. So the watermark is either Long.MIN_VALUE + // or Long.MAX_VALUE. So we should just use the Flink time service watermark in batch mode. + // + // In Flink the watermark ranges from + // [LONG.MIN_VALUE (-9223372036854775808), LONG.MAX_VALUE (9223372036854775807)] while the + // Beam + // watermark range is [BoundedWindow.TIMESTAMP_MIN_VALUE (-9223372036854775), + // BoundedWindow.TIMESTAMP_MAX_VALUE (9223372036854775)]. To ensure the timestamps visible to + // the users follow the Beam convention, we just use the Beam range instead. + return timerService.currentWatermark() == Long.MAX_VALUE + ? new Instant(Long.MAX_VALUE) + : BoundedWindow.TIMESTAMP_MIN_VALUE; + } else { + return new Instant(getEffectiveInputWatermark()); + } + } + + @Override + public @Nullable Instant currentOutputWatermarkTime() { + return new Instant(currentOutputWatermark); + } + + /** + * Check whether event time timers lower than or equal to the given timestamp exist. Caution: This is + * scoped by the current key. + */ + public boolean hasPendingEventTimeTimers(long maxTimestamp) throws Exception { + for (TimerData timer : pendingTimersById.values()) { + if (timer.getDomain() == TimeDomain.EVENT_TIME + && timer.getTimestamp().getMillis() <= maxTimestamp) { + return true; + } + } + return false; + } + + /** Unique contextual id of a timer. Used to look up any existing timers in a context. */ + private String getContextTimerId(String timerId, StateNamespace namespace) { + return timerId + namespace.stringKey(); + } + } + + /** + * In Beam, a timer with timestamp {@code T} is only eligible for firing when the time has moved + * past this time stamp, i.e. {@code T < current_time}.
In the case of event time, current_time is + * the watermark; in the case of processing time, it is the system time. + + *
<p>
Flink's TimerService has different semantics because it only ensures {@code T <= + * current_time}. + * + *
<p>
To make up for this, we need to add one millisecond to Flink's internal timer timestamp. + * Note that we do not modify Beam's timestamp and we are not exposing Flink's timestamp. + * + *
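<p>For example, a Beam timer set for timestamp {@code T} is registered with Flink for {@code T + 1}, so it fires only once time has strictly moved past {@code T}.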
<p>
See also https://jira.apache.org/jira/browse/BEAM-3863 + */ + static long adjustTimestampForFlink(long beamTimerTimestamp) { + if (beamTimerTimestamp == Long.MAX_VALUE) { + // We would overflow, do not adjust timestamp + return Long.MAX_VALUE; + } + return beamTimerTimestamp + 1; + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/StreamingImpulseSource.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/StreamingImpulseSource.java new file mode 100644 index 000000000000..63c4cfb6b034 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/StreamingImpulseSource.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import java.nio.charset.StandardCharsets; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// TODO(https://github.com/apache/beam/issues/37114) migrate off RichParallelSourceFunction +/** + * A streaming source that periodically produces a byte array. This is mostly useful for debugging, + * or for triggering periodic behavior in a portable pipeline. + * + * @deprecated Legacy non-portable source which can be replaced by a DoFn with timers. 
+ * https://jira.apache.org/jira/browse/BEAM-8353 + */ +@Deprecated +public class StreamingImpulseSource extends RichParallelSourceFunction> { + private static final Logger LOG = LoggerFactory.getLogger(StreamingImpulseSource.class); + + private final int intervalMillis; + private final int messageCount; + + private volatile boolean running = true; + private long count; + + public StreamingImpulseSource(int intervalMillis, int messageCount) { + this.intervalMillis = intervalMillis; + this.messageCount = messageCount; + } + + @Override + public void run(SourceContext> ctx) { + // in order to produce messageCount messages across all parallel subtasks, we divide by + // the total number of subtasks + int subtaskCount = + messageCount / getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); + // if the message count is not evenly divisible by the number of subtasks, add an estra + // message to the first (messageCount % subtasksCount) subtasks + if (getRuntimeContext().getTaskInfo().getIndexOfThisSubtask() + < (messageCount % getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks())) { + subtaskCount++; + } + + while (running && (messageCount == 0 || count < subtaskCount)) { + synchronized (ctx.getCheckpointLock()) { + ctx.collect( + WindowedValues.valueInGlobalWindow( + String.valueOf(count).getBytes(StandardCharsets.UTF_8))); + count++; + } + + try { + if (intervalMillis > 0) { + Thread.sleep(intervalMillis); + } + } catch (InterruptedException e) { + LOG.warn("Interrupted while sleeping", e); + } + } + } + + @Override + public void cancel() { + this.running = false; + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/TestStreamSource.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/TestStreamSource.java new file mode 100644 index 000000000000..6f6b2d7bc3ed --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/TestStreamSource.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import java.util.List; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.flink.streaming.api.functions.source.legacy.RichSourceFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.joda.time.Instant; + +/** Flink source for executing {@link org.apache.beam.sdk.testing.TestStream}. */ +public class TestStreamSource extends RichSourceFunction> { + + private final SerializableFunction> testStreamDecoder; + private final byte[] payload; + + private volatile boolean isRunning = true; + + public TestStreamSource( + SerializableFunction> testStreamDecoder, byte[] payload) { + this.testStreamDecoder = testStreamDecoder; + this.payload = payload; + } + + @Override + public void run(SourceContext> ctx) throws CoderException { + TestStream testStream = testStreamDecoder.apply(payload); + List> events = testStream.getEvents(); + + for (int eventId = 0; isRunning && eventId < events.size(); eventId++) { + TestStream.Event event = events.get(eventId); + + synchronized (ctx.getCheckpointLock()) { + if (event instanceof TestStream.ElementEvent) { + for (TimestampedValue element : ((TestStream.ElementEvent) event).getElements()) { + Instant timestamp = element.getTimestamp(); + WindowedValue value = + WindowedValues.of( + element.getValue(), timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING); + ctx.collectWithTimestamp(value, timestamp.getMillis()); + } + } else if (event instanceof TestStream.WatermarkEvent) { + long millis = ((TestStream.WatermarkEvent) event).getWatermark().getMillis(); + ctx.emitWatermark(new Watermark(millis)); + } else if (event instanceof TestStream.ProcessingTimeEvent) { + // There seems to be no clean way to implement this + throw new UnsupportedOperationException( + "Advancing Processing time is not supported by the Flink Runner."); + } else { + throw new IllegalStateException("Unknown event type " + event); + } + } + } + } + + @Override + public void cancel() { + this.isRunning = false; + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java new file mode 100644 index 000000000000..25cf9879766f --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.flink.metrics.ReaderInvocationUtil; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.utils.Workarounds; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.construction.UnboundedReadFromBoundedSource; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.ValueWithRecordId; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.serialization.SerializerConfigImpl; +import org.apache.flink.api.common.state.CheckpointListener; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// TODO(https://github.com/apache/beam/issues/37114) migrate off RichParallelSourceFunction +/** Wrapper for executing {@link UnboundedSource UnboundedSources} as a Flink Source. 
*/ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class UnboundedSourceWrapper + extends RichParallelSourceFunction>> + implements BeamStoppableFunction, + CheckpointListener, + CheckpointedFunction, + ProcessingTimeCallback { + + private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceWrapper.class); + + private final String stepName; + /** Keep the options so that we can initialize the localReaders. */ + private final SerializablePipelineOptions serializedOptions; + + /** + * We are processing bounded data and should read from the sources sequentially instead of reading + * round-robin from all the sources. In case of file sources this avoids having too many open + * files/connections at once. + */ + private final boolean isConvertedBoundedSource; + + /** For snapshot and restore. */ + private final KvCoder, CheckpointMarkT> + checkpointCoder; + + /** + * The split sources. We split them in the constructor to ensure that all parallel sources are + * consistent about the split sources. + */ + private final List> splitSources; + + /** The idle time before the source shuts down. */ + private final long idleTimeoutMs; + + /** The local split sources. Assigned at runtime when the wrapper is executed in parallel. */ + private transient List> localSplitSources; + + /** + * The local split readers. Assigned at runtime when the wrapper is executed in parallel. Make it + * a field so that we can access it in {@link #onProcessingTime(long)} for emitting watermarks. + */ + private transient List> localReaders; + + /** + * Flag to indicate whether the source is running. Initialize here and not in run() to prevent + * races where we cancel a job before run() is ever called or run() is called after cancel(). + */ + private volatile boolean isRunning = true; + + /** + * Make it a field so that we can access it in {@link #onProcessingTime(long)} for registering new + * triggers. + */ + private transient StreamingRuntimeContext runtimeContext; + + /** + * Make it a field so that we can access it in {@link #onProcessingTime(long)} for emitting + * watermarks. + */ + private transient SourceContext>> context; + + /** Pending checkpoints which have not been acknowledged yet. */ + private transient LinkedHashMap> pendingCheckpoints; + /** Keep a maximum of 32 checkpoints for {@code CheckpointMark.finalizeCheckpoint()}. */ + private static final int MAX_NUMBER_PENDING_CHECKPOINTS = 32; + + private transient ListState< + KV, CheckpointMarkT>> + stateForCheckpoint; + + /** False if checkpointCoder is null or there is no restored state when starting for the first time. */ + private transient boolean isRestored = false; + + /** Flag to indicate whether all readers have reached the maximum watermark. */ + private transient boolean maxWatermarkReached; + + /** Metrics container which will be reported as Flink accumulators at the end of the job.
*/ + private transient FlinkMetricContainer metricContainer; + + @SuppressWarnings("unchecked") + public UnboundedSourceWrapper( + String stepName, + PipelineOptions pipelineOptions, + UnboundedSource source, + int parallelism) + throws Exception { + this.stepName = stepName; + this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); + this.isConvertedBoundedSource = + source instanceof UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter; + + if (source.requiresDeduping()) { + LOG.warn("Source {} requires deduping but Flink runner doesn't support this yet.", source); + } + + Coder checkpointMarkCoder = source.getCheckpointMarkCoder(); + if (checkpointMarkCoder == null) { + LOG.info("No CheckpointMarkCoder specified for this source. Won't create snapshots."); + checkpointCoder = null; + } else { + + Coder> sourceCoder = + (Coder) SerializableCoder.of(new TypeDescriptor() {}); + + checkpointCoder = KvCoder.of(sourceCoder, checkpointMarkCoder); + } + + // get the splits early. we assume that the generated splits are stable, + // this is necessary so that the mapping of state to source is correct + // when restoring + splitSources = source.split(parallelism, pipelineOptions); + + FlinkPipelineOptions options = pipelineOptions.as(FlinkPipelineOptions.class); + idleTimeoutMs = options.getShutdownSourcesAfterIdleMs(); + } + + /** Initialize and restore state before starting execution of the source. */ + @Override + public void open(OpenContext openContext) throws Exception { + FileSystems.setDefaultPipelineOptions(serializedOptions.get()); + runtimeContext = (StreamingRuntimeContext) getRuntimeContext(); + metricContainer = new FlinkMetricContainer(runtimeContext); + + // figure out which split sources we're responsible for + int subtaskIndex = runtimeContext.getTaskInfo().getIndexOfThisSubtask(); + int numSubtasks = runtimeContext.getTaskInfo().getNumberOfParallelSubtasks(); + + localSplitSources = new ArrayList<>(); + localReaders = new ArrayList<>(); + + pendingCheckpoints = new LinkedHashMap<>(); + + if (isRestored) { + // restore the splitSources from the checkpoint to ensure consistent ordering + for (KV, CheckpointMarkT> restored : + stateForCheckpoint.get()) { + localSplitSources.add(restored.getKey()); + localReaders.add( + restored.getKey().createReader(serializedOptions.get(), restored.getValue())); + } + } else { + // initialize localReaders and localSources from scratch + for (int i = 0; i < splitSources.size(); i++) { + if (i % numSubtasks == subtaskIndex) { + UnboundedSource source = splitSources.get(i); + UnboundedSource.UnboundedReader reader = + source.createReader(serializedOptions.get(), null); + localSplitSources.add(source); + localReaders.add(reader); + } + } + } + + LOG.info( + "Unbounded Flink Source {}/{} is reading from sources: {}", + subtaskIndex + 1, + numSubtasks, + localSplitSources); + } + + @Override + public void run(SourceContext>> ctx) throws Exception { + + context = ctx; + + ReaderInvocationUtil> readerInvoker = + new ReaderInvocationUtil<>(stepName, serializedOptions.get(), metricContainer); + + setNextWatermarkTimer(this.runtimeContext); + + if (localReaders.isEmpty()) { + // It can happen when value of parallelism is greater than number of IO readers (for example, + // parallelism is 2 and number of Kafka topic partitions is 1). In this case, we just fall + // through to idle this executor. 
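+ // (Execution still falls through to finalizeSource() below so the operator stays alive and checkpointing keeps working; see FLINK-2491.)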
+ LOG.info("Number of readers is 0 for this task executor, idle"); + // Do nothing here but still execute the rest of the source logic + } else if (isConvertedBoundedSource) { + + // We read sequentially from all bounded sources + for (int i = 0; i < localReaders.size() && isRunning; i++) { + UnboundedSource.UnboundedReader reader = localReaders.get(i); + + synchronized (ctx.getCheckpointLock()) { + boolean dataAvailable = readerInvoker.invokeStart(reader); + if (dataAvailable) { + emitElement(ctx, reader); + } + } + + boolean dataAvailable; + do { + synchronized (ctx.getCheckpointLock()) { + dataAvailable = readerInvoker.invokeAdvance(reader); + + if (dataAvailable) { + emitElement(ctx, reader); + } + } + } while (dataAvailable && isRunning); + } + } else { + // Read from multiple unbounded sources, + // loop through them and sleep if none of them had any data + + int numReaders = localReaders.size(); + int currentReader = 0; + + // start each reader and emit data if immediately available + for (UnboundedSource.UnboundedReader reader : localReaders) { + synchronized (ctx.getCheckpointLock()) { + boolean dataAvailable = readerInvoker.invokeStart(reader); + if (dataAvailable) { + emitElement(ctx, reader); + } + } + } + + // a flag telling us whether any of the localReaders had data + // if no reader had data, sleep for bit + boolean hadData = false; + while (isRunning && !maxWatermarkReached) { + UnboundedSource.UnboundedReader reader = localReaders.get(currentReader); + + synchronized (ctx.getCheckpointLock()) { + if (readerInvoker.invokeAdvance(reader)) { + emitElement(ctx, reader); + hadData = true; + } + } + + currentReader = (currentReader + 1) % numReaders; + if (currentReader == 0 && !hadData) { + // We have visited all the readers and none had data + // Wait for a bit and check if more data is available + Thread.sleep(50); + } else if (currentReader == 0) { + // Reset the flag for another round across the readers + hadData = false; + } + } + } + + ctx.emitWatermark(new Watermark(Long.MAX_VALUE)); + finalizeSource(); + } + + private void finalizeSource() { + // do nothing, but still look busy ... + // we can't return here since Flink requires that all operators stay up, + // otherwise checkpointing would not work correctly anymore + // + // See https://issues.apache.org/jira/browse/FLINK-2491 for progress on this issue + long idleStart = System.currentTimeMillis(); + while (isRunning && System.currentTimeMillis() - idleStart < idleTimeoutMs) { + try { + // Flink will interrupt us at some point + Thread.sleep(1000); + } catch (InterruptedException e) { + if (!isRunning) { + // restore the interrupted state, and fall through the loop + Thread.currentThread().interrupt(); + } + } + } + } + + /** Emit the current element from the given Reader. The reader is guaranteed to have data. 
*/ + private void emitElement( + SourceContext>> ctx, + UnboundedSource.UnboundedReader reader) { + // make sure that reader state update and element emission are atomic + // with respect to snapshots + OutputT item = reader.getCurrent(); + byte[] recordId = reader.getCurrentRecordId(); + Instant timestamp = reader.getCurrentTimestamp(); + + WindowedValue> windowedValue = + WindowedValues.of( + new ValueWithRecordId<>(item, recordId), + timestamp, + GlobalWindow.INSTANCE, + PaneInfo.NO_FIRING); + ctx.collect(windowedValue); + } + + @Override + public void close() throws Exception { + try { + if (metricContainer != null) { + metricContainer.registerMetricsForPipelineResult(); + } + super.close(); + if (localReaders != null) { + for (UnboundedSource.UnboundedReader reader : localReaders) { + reader.close(); + } + } + } finally { + Workarounds.deleteStaticCaches(); + } + } + + @Override + public void cancel() { + isRunning = false; + } + + @Override + public void stop() { + isRunning = false; + } + + // ------------------------------------------------------------------------ + // Checkpoint and restore + // ------------------------------------------------------------------------ + + @Override + public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception { + if (!isRunning) { + // This implies that stop/drain is invoked and final checkpoint is triggered. This method + // should not be skipped in this scenario so that the notifyCheckpointComplete method is still + // invoked and performs the finalization step after commit is complete. + LOG.debug("snapshotState() called on closed source"); + } + + if (checkpointCoder == null) { + // no checkpoint coder available in this source + return; + } + + stateForCheckpoint.clear(); + + long checkpointId = functionSnapshotContext.getCheckpointId(); + + // we checkpoint the sources along with the CheckpointMarkT to ensure + // than we have a correct mapping of checkpoints to sources when + // restoring + List checkpointMarks = new ArrayList<>(localSplitSources.size()); + + for (int i = 0; i < localSplitSources.size(); i++) { + UnboundedSource source = localSplitSources.get(i); + UnboundedSource.UnboundedReader reader = localReaders.get(i); + + @SuppressWarnings("unchecked") + CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark(); + checkpointMarks.add(mark); + KV, CheckpointMarkT> kv = KV.of(source, mark); + stateForCheckpoint.add(kv); + } + + // cleanup old pending checkpoints and add new checkpoint + int diff = pendingCheckpoints.size() - MAX_NUMBER_PENDING_CHECKPOINTS; + if (diff >= 0) { + for (Iterator iterator = pendingCheckpoints.keySet().iterator(); diff >= 0; diff--) { + iterator.next(); + iterator.remove(); + } + } + pendingCheckpoints.put(checkpointId, checkpointMarks); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + if (checkpointCoder == null) { + // no checkpoint coder available in this source + return; + } + + OperatorStateStore stateStore = context.getOperatorStateStore(); + @SuppressWarnings("unchecked") + CoderTypeInformation, CheckpointMarkT>> + typeInformation = + (CoderTypeInformation) new CoderTypeInformation<>(checkpointCoder, serializedOptions); + stateForCheckpoint = + stateStore.getListState( + new ListStateDescriptor<>( + DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, + typeInformation.createSerializer(new SerializerConfigImpl()))); + + if (context.isRestored()) { + isRestored = true; + LOG.info("Restoring state in 
the UnboundedSourceWrapper."); + } else { + LOG.info("No restore state for UnboundedSourceWrapper."); + } + } + + @Override + public void onProcessingTime(long timestamp) { + if (this.isRunning) { + synchronized (context.getCheckpointLock()) { + // find minimum watermark over all localReaders + long watermarkMillis = Long.MAX_VALUE; + for (UnboundedSource.UnboundedReader reader : localReaders) { + Instant watermark = reader.getWatermark(); + if (watermark != null) { + watermarkMillis = Math.min(watermark.getMillis(), watermarkMillis); + } + } + context.emitWatermark(new Watermark(watermarkMillis)); + + if (watermarkMillis < BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) { + setNextWatermarkTimer(this.runtimeContext); + } else { + this.maxWatermarkReached = true; + } + } + } + } + + // the callback is ourselves so there is nothing meaningful we can do with the ScheduledFuture + @SuppressWarnings("FutureReturnValueIgnored") + private void setNextWatermarkTimer(StreamingRuntimeContext runtime) { + if (this.isRunning) { + java.time.Duration autoWaterMarkDuration = + runtime + .getJobConfiguration() + .get(org.apache.flink.configuration.PipelineOptions.AUTO_WATERMARK_INTERVAL); + long watermarkInterval = autoWaterMarkDuration.toMillis(); + synchronized (context.getCheckpointLock()) { + long currentProcessingTime = runtime.getProcessingTimeService().getCurrentProcessingTime(); + if (currentProcessingTime < Long.MAX_VALUE) { + long nextTriggerTime = currentProcessingTime + watermarkInterval; + if (nextTriggerTime < currentProcessingTime) { + // overflow, just trigger once for the max timestamp + nextTriggerTime = Long.MAX_VALUE; + } + runtime.getProcessingTimeService().registerTimer(nextTriggerTime, this); + } + } + } + } + + /** Visible so that we can check this in tests. Must not be used for anything else. */ + @VisibleForTesting + public List> getSplitSources() { + return splitSources; + } + + /** Visible so that we can check this in tests. Must not be used for anything else. */ + @VisibleForTesting + List> getLocalSplitSources() { + return localSplitSources; + } + + /** Visible so that we can check this in tests. Must not be used for anything else. */ + @VisibleForTesting + List> getLocalReaders() { + return localReaders; + } + + /** Visible so that we can check this in tests. Must not be used for anything else. */ + @VisibleForTesting + boolean isRunning() { + return isRunning; + } + + /** + * Visible so that we can set this in tests. This is only set in the run method which is + * inconvenient for the tests where the context is assumed to be set when run is called. Must not + * be used for anything else. 
+ */ + @VisibleForTesting + public void setSourceContext(SourceContext>> ctx) { + context = ctx; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + List checkpointMarks = pendingCheckpoints.get(checkpointId); + + if (checkpointMarks != null) { + + // remove old checkpoints including the current one + Iterator iterator = pendingCheckpoints.keySet().iterator(); + long currentId; + do { + currentId = iterator.next(); + iterator.remove(); + } while (currentId != checkpointId); + + // confirm all marks + for (CheckpointMarkT mark : checkpointMarks) { + mark.finalizeCheckpoint(); + } + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java new file mode 100644 index 000000000000..4bec4c59f9de --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java @@ -0,0 +1,697 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.MapCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.MultimapState; +import org.apache.beam.sdk.state.OrderedListState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.SetState; +import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.StateContext; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.CombineContextFactory; +import org.apache.flink.api.common.serialization.SerializerConfigImpl; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * {@link StateInternals} that uses a Flink {@link OperatorStateBackend} to manage the broadcast + * state. The state is the same on all parallel instances of the operator. So we just need store + * state of operator-0 in OperatorStateBackend. + * + *

Note: Ignore index of key. Mainly for SideInputs. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkBroadcastStateInternals implements StateInternals { + + private int indexInSubtaskGroup; + private final OperatorStateBackend stateBackend; + // stateName -> + private Map> stateForNonZeroOperator; + + private final SerializablePipelineOptions pipelineOptions; + + public FlinkBroadcastStateInternals( + int indexInSubtaskGroup, + OperatorStateBackend stateBackend, + SerializablePipelineOptions pipelineOptions) { + this.stateBackend = stateBackend; + this.indexInSubtaskGroup = indexInSubtaskGroup; + this.pipelineOptions = pipelineOptions; + if (indexInSubtaskGroup != 0) { + stateForNonZeroOperator = new HashMap<>(); + } + } + + @Override + public @Nullable K getKey() { + return null; + } + + @Override + public T state( + final StateNamespace namespace, StateTag address, final StateContext context) { + + return address.bind( + new StateTag.StateBinder() { + + @Override + public ValueState bindValue(StateTag> address, Coder coder) { + + return new FlinkBroadcastValueState<>( + stateBackend, address, namespace, coder, pipelineOptions); + } + + @Override + public BagState bindBag(StateTag> address, Coder elemCoder) { + + return new FlinkBroadcastBagState<>( + stateBackend, address, namespace, elemCoder, pipelineOptions); + } + + @Override + public SetState bindSet(StateTag> address, Coder elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } + + @Override + public MapState bindMap( + StateTag> spec, + Coder mapKeyCoder, + Coder mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } + + @Override + public OrderedListState bindOrderedList( + StateTag> spec, Coder elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", OrderedListState.class.getSimpleName())); + } + + @Override + public MultimapState bindMultimap( + StateTag> spec, + Coder keyCoder, + Coder valueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MultimapState.class.getSimpleName())); + } + + @Override + public + CombiningState bindCombiningValue( + StateTag> address, + Coder accumCoder, + Combine.CombineFn combineFn) { + + return new FlinkCombiningState<>( + stateBackend, address, combineFn, namespace, accumCoder, pipelineOptions); + } + + @Override + public + CombiningState bindCombiningValueWithContext( + StateTag> address, + Coder accumCoder, + CombineWithContext.CombineFnWithContext combineFn) { + return new FlinkCombiningStateWithContext<>( + stateBackend, + address, + combineFn, + namespace, + accumCoder, + CombineContextFactory.createFromStateContext(context)); + } + + @Override + public WatermarkHoldState bindWatermark( + StateTag address, TimestampCombiner timestampCombiner) { + throw new UnsupportedOperationException( + String.format("%s is not supported", WatermarkHoldState.class.getSimpleName())); + } + }); + } + + /** + * 1. The way we would use it is to only checkpoint anything from the operator with subtask index + * 0 because we assume that the state is the same on all parallel instances of the operator. + * + *

<p>2. Use map to support namespace.
+   */
+  private abstract class AbstractBroadcastState<T> {
+
+    private String name;
+    private final StateNamespace namespace;
+    private final ListStateDescriptor<Map<String, T>> flinkStateDescriptor;
+    private final OperatorStateStore flinkStateBackend;
+
+    AbstractBroadcastState(
+        OperatorStateBackend flinkStateBackend,
+        String name,
+        StateNamespace namespace,
+        Coder<T> coder,
+        SerializablePipelineOptions pipelineOptions) {
+      this.name = name;
+
+      this.namespace = namespace;
+      this.flinkStateBackend = flinkStateBackend;
+
+      CoderTypeInformation<Map<String, T>> typeInfo =
+          new CoderTypeInformation<>(MapCoder.of(StringUtf8Coder.of(), coder), pipelineOptions);
+
+      flinkStateDescriptor =
+          new ListStateDescriptor<>(name, typeInfo.createSerializer(new SerializerConfigImpl()));
+    }
+
+    /** Get map (namespace -> T) from index 0. */
+    Map<String, T> getMap() throws Exception {
+      if (indexInSubtaskGroup == 0) {
+        return getMapFromBroadcastState();
+      } else {
+        Map<String, T> result = (Map<String, T>) stateForNonZeroOperator.get(name);
+        // maybe restore from BroadcastState of Operator-0
+        if (result == null) {
+          result = getMapFromBroadcastState();
+          if (result != null) {
+            stateForNonZeroOperator.put(name, result);
+            // we don't need it anymore, must clear it.
+            flinkStateBackend.getUnionListState(flinkStateDescriptor).clear();
+          }
+        }
+        return result;
+      }
+    }
+
+    Map<String, T> getMapFromBroadcastState() throws Exception {
+      ListState<Map<String, T>> state = flinkStateBackend.getUnionListState(flinkStateDescriptor);
+      Iterable<Map<String, T>> iterable = state.get();
+      Map<String, T> ret = null;
+      if (iterable != null) {
+        // just use index 0
+        Iterator<Map<String, T>> iterator = iterable.iterator();
+        if (iterator.hasNext()) {
+          ret = iterator.next();
+        }
+      }
+      return ret;
+    }
+
+    /** Update map (namespace -> T) from index 0. */
+    void updateMap(Map<String, T> map) throws Exception {
+      if (indexInSubtaskGroup == 0) {
+        ListState<Map<String, T>> state = flinkStateBackend.getUnionListState(flinkStateDescriptor);
+        state.clear();
+        if (map.size() > 0) {
+          state.add(map);
+        }
+      } else {
+        if (map.isEmpty()) {
+          stateForNonZeroOperator.remove(name);
+          // updateMap is always behind getMap,
+          // getMap will clear map in BroadcastOperatorState,
+          // we don't need clear here.
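For context, the broadcast-state layout implemented by AbstractBroadcastState can be condensed into the following standalone sketch (illustrative only, not part of the patch; the class name BroadcastMapSketch and the String value type are invented, and only Flink's public OperatorStateStore and ListState APIs are assumed): all namespaces of one state cell live in a single map, and only subtask 0 writes that map into a union ListState because the value is identical on every parallel instance.

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.ListStateDescriptor;
    import org.apache.flink.api.common.state.OperatorStateStore;
    import org.apache.flink.api.common.typeinfo.Types;

    // Condensed sketch of the pattern above: namespace -> value entries of one logical Beam
    // state cell are kept in a local map; only subtask 0 checkpoints the map into a union
    // ListState so the same value is available on every subtask after a restore.
    class BroadcastMapSketch {

      private final OperatorStateStore store; // provided by the operator at initialization time
      private final int subtaskIndex;
      private final ListStateDescriptor<Map<String, String>> descriptor =
          new ListStateDescriptor<>("broadcast-sketch", Types.MAP(Types.STRING, Types.STRING));
      private final Map<String, String> localCopy = new HashMap<>();

      BroadcastMapSketch(OperatorStateStore store, int subtaskIndex) {
        this.store = store;
        this.subtaskIndex = subtaskIndex;
      }

      void write(String namespace, String value) throws Exception {
        localCopy.put(namespace, value);
        if (subtaskIndex == 0) {
          // Only subtask 0 checkpoints; the state is identical on all parallel instances.
          ListState<Map<String, String>> state = store.getUnionListState(descriptor);
          state.clear();
          state.add(localCopy);
        }
      }

      String read(String namespace) {
        return localCopy.get(namespace);
      }
    }

On restore, every subtask receives the full union list, which is why getMap() above copies the value locally and then clears the redundant Flink state on non-zero subtasks.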
+ } else { + stateForNonZeroOperator.put(name, map); + } + } + } + + void writeInternal(T input) { + try { + Map map = getMap(); + if (map == null) { + map = new HashMap<>(); + } + map.put(namespace.stringKey(), input); + updateMap(map); + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + T readInternal() { + try { + Map map = getMap(); + if (map == null) { + return null; + } else { + return map.get(namespace.stringKey()); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + void clearInternal() { + try { + Map map = getMap(); + if (map != null) { + map.remove(namespace.stringKey()); + updateMap(map); + } + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + } + + private class FlinkBroadcastValueState extends AbstractBroadcastState + implements ValueState { + + private final StateNamespace namespace; + private final StateTag> address; + + FlinkBroadcastValueState( + OperatorStateBackend flinkStateBackend, + StateTag> address, + StateNamespace namespace, + Coder coder, + SerializablePipelineOptions pipelineOptions) { + super(flinkStateBackend, address.getId(), namespace, coder, pipelineOptions); + + this.namespace = namespace; + this.address = address; + } + + @Override + public void write(T input) { + writeInternal(input); + } + + @Override + public ValueState readLater() { + return this; + } + + @Override + public T read() { + return readInternal(); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBroadcastValueState that = (FlinkBroadcastValueState) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + + @Override + public void clear() { + clearInternal(); + } + } + + private class FlinkBroadcastBagState extends AbstractBroadcastState> + implements BagState { + + private final StateNamespace namespace; + private final StateTag> address; + + FlinkBroadcastBagState( + OperatorStateBackend flinkStateBackend, + StateTag> address, + StateNamespace namespace, + Coder coder, + SerializablePipelineOptions pipelineOptions) { + super(flinkStateBackend, address.getId(), namespace, ListCoder.of(coder), pipelineOptions); + + this.namespace = namespace; + this.address = address; + } + + @Override + public void add(T input) { + List list = readInternal(); + if (list == null) { + list = new ArrayList<>(); + } + list.add(input); + writeInternal(list); + } + + @Override + public BagState readLater() { + return this; + } + + @Override + public Iterable read() { + List result = readInternal(); + return result != null ? 
result : Collections.emptyList(); + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + List result = readInternal(); + return result == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBroadcastBagState that = (FlinkBroadcastBagState) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private class FlinkCombiningState extends AbstractBroadcastState + implements CombiningState { + + private final StateNamespace namespace; + private final StateTag> address; + private final Combine.CombineFn combineFn; + + FlinkCombiningState( + OperatorStateBackend flinkStateBackend, + StateTag> address, + Combine.CombineFn combineFn, + StateNamespace namespace, + Coder accumCoder, + SerializablePipelineOptions pipelineOptions) { + super(flinkStateBackend, address.getId(), namespace, accumCoder, pipelineOptions); + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + } + + @Override + public CombiningState readLater() { + return this; + } + + @Override + public void add(InputT value) { + AccumT current = readInternal(); + if (current == null) { + current = combineFn.createAccumulator(); + } + current = combineFn.addInput(current, value); + writeInternal(current); + } + + @Override + public void addAccum(AccumT accum) { + AccumT current = readInternal(); + + if (current == null) { + writeInternal(accum); + } else { + current = combineFn.mergeAccumulators(Arrays.asList(current, accum)); + writeInternal(current); + } + } + + @Override + public AccumT getAccum() { + AccumT accum = readInternal(); + return accum != null ? 
accum : combineFn.createAccumulator(); + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT read() { + AccumT accum = readInternal(); + if (accum != null) { + return combineFn.extractOutput(accum); + } else { + return combineFn.extractOutput(combineFn.createAccumulator()); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + return readInternal() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkCombiningState that = (FlinkCombiningState) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private class FlinkCombiningStateWithContext + extends AbstractBroadcastState implements CombiningState { + + private final StateNamespace namespace; + private final StateTag> address; + private final CombineWithContext.CombineFnWithContext combineFn; + private final CombineWithContext.Context context; + + FlinkCombiningStateWithContext( + OperatorStateBackend flinkStateBackend, + StateTag> address, + CombineWithContext.CombineFnWithContext combineFn, + StateNamespace namespace, + Coder accumCoder, + CombineWithContext.Context context) { + super(flinkStateBackend, address.getId(), namespace, accumCoder, pipelineOptions); + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.context = context; + } + + @Override + public CombiningState readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + AccumT current = readInternal(); + if (current == null) { + current = combineFn.createAccumulator(context); + } + current = combineFn.addInput(current, value, context); + writeInternal(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + + AccumT current = readInternal(); + if (current == null) { + writeInternal(accum); + } else { + current = combineFn.mergeAccumulators(Arrays.asList(current, accum), context); + writeInternal(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + AccumT accum = readInternal(); + return accum != null ? 
accum : combineFn.createAccumulator(context); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(accumulators, context); + } + + @Override + public OutputT read() { + try { + AccumT accum = readInternal(); + if (accum == null) { + accum = combineFn.createAccumulator(context); + } + return combineFn.extractOutput(accum, context); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + return readInternal() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkCombiningStateWithContext that = + (FlinkCombiningStateWithContext) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } +} diff --git a/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java new file mode 100644 index 000000000000..501207b32e97 --- /dev/null +++ b/runners/flink/2.0/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -0,0 +1,1851 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.function.Function; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaces; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.GroupingState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.MultimapState; +import org.apache.beam.sdk.state.OrderedListState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; +import org.apache.beam.sdk.state.SetState; +import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.StateBinder; +import org.apache.beam.sdk.state.StateContext; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.CombineContextFactory; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.TreeMultiset; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import 
org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.JavaSerializer; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.joda.time.Instant; + +/** + * {@link StateInternals} that uses a Flink {@link KeyedStateBackend} to manage state. + * + *

Note: In the Flink streaming runner the key is always encoded using an {@link Coder} and + * stored in a {@link FlinkKey}. + */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class FlinkStateInternals implements StateInternals { + + private static final StateNamespace globalWindowNamespace = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE); + + private final KeyedStateBackend flinkStateBackend; + private final Coder keyCoder; + FlinkStateNamespaceKeySerializer namespaceKeySerializer; + + private static class StateAndNamespaceDescriptor { + static StateAndNamespaceDescriptor of( + StateDescriptor stateDescriptor, T namespace, TypeSerializer namespaceSerializer) { + return new StateAndNamespaceDescriptor<>(stateDescriptor, namespace, namespaceSerializer); + } + + private final StateDescriptor stateDescriptor; + private final T namespace; + private final TypeSerializer namespaceSerializer; + + private StateAndNamespaceDescriptor( + StateDescriptor stateDescriptor, T namespace, TypeSerializer namespaceSerializer) { + this.stateDescriptor = stateDescriptor; + this.namespace = namespace; + this.namespaceSerializer = namespaceSerializer; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + StateAndNamespaceDescriptor other = (StateAndNamespaceDescriptor) o; + return Objects.equals(stateDescriptor, other.stateDescriptor); + } + + @Override + public int hashCode() { + return Objects.hash(stateDescriptor); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("stateDescriptor", stateDescriptor) + .add("namespace", namespace) + .add("namespaceSerializer", namespaceSerializer) + .toString(); + } + } + + /** + * A set which contains all state descriptors created in the global window. Used for cleanup on + * final watermark. + */ + private final Set> globalWindowStateDescriptors = new HashSet<>(); + + /** Watermark holds descriptors created for a specific window. */ + private final HashMultimap watermarkHoldsMap = + HashMultimap.create(); + + // Watermark holds for all keys/windows of this partition, allows efficient lookup of the minimum + private final TreeMultiset watermarkHolds = TreeMultiset.create(); + // State to persist combined watermark holds for all keys of this partition + private final MapStateDescriptor watermarkHoldStateDescriptor; + + private final boolean fasterCopy; + + public FlinkStateInternals( + KeyedStateBackend flinkStateBackend, + Coder keyCoder, + Coder windowCoder, + SerializablePipelineOptions pipelineOptions) + throws Exception { + this.flinkStateBackend = Objects.requireNonNull(flinkStateBackend); + this.keyCoder = Objects.requireNonNull(keyCoder); + this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + this.namespaceKeySerializer = new FlinkStateNamespaceKeySerializer(windowCoder); + + watermarkHoldStateDescriptor = + new MapStateDescriptor<>( + "watermark-holds", + StringSerializer.INSTANCE, + new CoderTypeSerializer<>(InstantCoder.of(), fasterCopy)); + restoreWatermarkHoldsView(); + } + + /** Returns the minimum over all watermark holds. 
*/ + public Long minWatermarkHoldMs() { + if (watermarkHolds.isEmpty()) { + return Long.MAX_VALUE; + } else { + return watermarkHolds.firstEntry().getElement(); + } + } + + @Override + public K getKey() { + FlinkKey keyBytes = flinkStateBackend.getCurrentKey(); + return FlinkKeyUtils.decodeKey(keyBytes.getSerializedKey(), keyCoder); + } + + @Override + public T state( + StateNamespace namespace, StateTag address, StateContext context) { + return address.getSpec().bind(address.getId(), new FlinkStateBinder(namespace, context)); + } + + /** + * Allows to clear all state for the global watermark when the maximum watermark arrives. We do + * not clean up the global window state via timers which would lead to an unbounded number of keys + * and cleanup timers. Instead, the cleanup code below should be run when we finally receive the + * max watermark. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public void clearGlobalState() { + try { + for (StateAndNamespaceDescriptor stateAndNamespace : globalWindowStateDescriptors) { + flinkStateBackend.applyToAllKeys( + stateAndNamespace.namespace, + stateAndNamespace.namespaceSerializer, + stateAndNamespace.stateDescriptor, + (key, state) -> state.clear()); + } + watermarkHoldsMap.values().forEach(FlinkWatermarkHoldState::clear); + // Clear set to avoid repeating the cleanup + globalWindowStateDescriptors.clear(); + watermarkHoldsMap.clear(); + } catch (Exception e) { + throw new RuntimeException("Failed to cleanup global state.", e); + } + } + + private class FlinkStateBinder implements StateBinder { + + private final StateNamespace namespace; + private final StateContext stateContext; + + private FlinkStateBinder(StateNamespace namespace, StateContext stateContext) { + this.namespace = namespace; + this.stateContext = stateContext; + } + + @Override + public ValueState bindValue( + String id, StateSpec> spec, Coder coder) { + FlinkValueState valueState = + new FlinkValueState<>( + flinkStateBackend, id, namespace, coder, namespaceKeySerializer, fasterCopy); + collectGlobalWindowStateDescriptor( + valueState.flinkStateDescriptor, valueState.namespace, namespaceKeySerializer); + return valueState; + } + + @Override + public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { + FlinkBagState bagState = + new FlinkBagState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); + collectGlobalWindowStateDescriptor( + bagState.flinkStateDescriptor, bagState.namespace, namespaceKeySerializer); + return bagState; + } + + @Override + public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { + FlinkSetState setState = + new FlinkSetState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); + collectGlobalWindowStateDescriptor( + setState.flinkStateDescriptor, setState.namespace, namespaceKeySerializer); + return setState; + } + + @Override + public MapState bindMap( + String id, + StateSpec> spec, + Coder mapKeyCoder, + Coder mapValueCoder) { + FlinkMapState mapState = + new FlinkMapState<>( + flinkStateBackend, + id, + namespace, + mapKeyCoder, + mapValueCoder, + namespaceKeySerializer, + fasterCopy); + collectGlobalWindowStateDescriptor( + mapState.flinkStateDescriptor, mapState.namespace, namespaceKeySerializer); + return mapState; + } + + @Override + public OrderedListState bindOrderedList( + String id, StateSpec> spec, Coder elemCoder) { + FlinkOrderedListState flinkOrderedListState = + new FlinkOrderedListState<>( + flinkStateBackend, id, namespace, 
elemCoder, namespaceKeySerializer, fasterCopy); + collectGlobalWindowStateDescriptor( + flinkOrderedListState.flinkStateDescriptor, + flinkOrderedListState.namespace, + namespaceKeySerializer); + return flinkOrderedListState; + } + + @Override + public MultimapState bindMultimap( + String id, + StateSpec> spec, + Coder keyCoder, + Coder valueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MultimapState.class.getSimpleName())); + } + + @Override + public CombiningState bindCombining( + String id, + StateSpec> spec, + Coder accumCoder, + Combine.CombineFn combineFn) { + FlinkCombiningState combiningState = + new FlinkCombiningState<>( + flinkStateBackend, + id, + combineFn, + namespace, + accumCoder, + namespaceKeySerializer, + fasterCopy); + collectGlobalWindowStateDescriptor( + combiningState.flinkStateDescriptor, combiningState.namespace, namespaceKeySerializer); + return combiningState; + } + + @Override + public + CombiningState bindCombiningWithContext( + String id, + StateSpec> spec, + Coder accumCoder, + CombineWithContext.CombineFnWithContext combineFn) { + FlinkCombiningStateWithContext combiningStateWithContext = + new FlinkCombiningStateWithContext<>( + flinkStateBackend, + id, + combineFn, + namespace, + accumCoder, + namespaceKeySerializer, + CombineContextFactory.createFromStateContext(stateContext), + fasterCopy); + collectGlobalWindowStateDescriptor( + combiningStateWithContext.flinkStateDescriptor, + combiningStateWithContext.namespace, + namespaceKeySerializer); + return combiningStateWithContext; + } + + @Override + public WatermarkHoldState bindWatermark( + String id, StateSpec spec, TimestampCombiner timestampCombiner) { + collectGlobalWindowStateDescriptor( + watermarkHoldStateDescriptor, VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE); + FlinkWatermarkHoldState state = + new FlinkWatermarkHoldState( + flinkStateBackend, watermarkHoldStateDescriptor, id, namespace, timestampCombiner); + collectWatermarkHolds(state); + return state; + } + + private void collectWatermarkHolds(FlinkWatermarkHoldState state) { + watermarkHoldsMap.put(namespace.stringKey(), state); + } + + /** Take note of state bound to the global window for cleanup in clearGlobalState(). 
*/ + private void collectGlobalWindowStateDescriptor( + StateDescriptor descriptor, T namespaceKey, TypeSerializer keySerializer) { + if (globalWindowNamespace.equals(namespace) || StateNamespaces.global().equals(namespace)) { + globalWindowStateDescriptors.add( + StateAndNamespaceDescriptor.of(descriptor, namespaceKey, keySerializer)); + } + } + } + + public static class FlinkStateNamespaceKeySerializer extends TypeSerializer { + + public Coder getCoder() { + return coder; + } + + private final Coder coder; + + public FlinkStateNamespaceKeySerializer(Coder coder) { + this.coder = coder; + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return this; + } + + @Override + public StateNamespace createInstance() { + return null; + } + + @Override + public StateNamespace copy(StateNamespace from) { + return from; + } + + @Override + public StateNamespace copy(StateNamespace from, StateNamespace reuse) { + return from; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(StateNamespace record, DataOutputView target) throws IOException { + StringSerializer.INSTANCE.serialize(record.stringKey(), target); + } + + @Override + public StateNamespace deserialize(DataInputView source) throws IOException { + return StateNamespaces.fromString(StringSerializer.INSTANCE.deserialize(source), coder); + } + + @Override + public StateNamespace deserialize(StateNamespace reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + throw new UnsupportedOperationException("copy is not supported for FlinkStateNamespace key"); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof FlinkStateNamespaceKeySerializer; + } + + @Override + public int hashCode() { + return Objects.hashCode(getClass()); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new FlinkStateNameSpaceSerializerSnapshot(this); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/ + @SuppressWarnings("WeakerAccess") + public static final class FlinkStateNameSpaceSerializerSnapshot + implements TypeSerializerSnapshot { + + @Nullable private Coder windowCoder; + + public FlinkStateNameSpaceSerializerSnapshot() {} + + FlinkStateNameSpaceSerializerSnapshot(FlinkStateNamespaceKeySerializer ser) { + this.windowCoder = ser.getCoder(); + } + + @Override + public int getCurrentVersion() { + return 0; + } + + @Override + public void writeSnapshot(DataOutputView out) throws IOException { + new JavaSerializer>().serialize(windowCoder, out); + } + + @Override + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) + throws IOException { + this.windowCoder = new JavaSerializer>().deserialize(in); + } + + @Override + public TypeSerializer restoreSerializer() { + return new FlinkStateNamespaceKeySerializer(windowCoder); + } + + @Override + public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( + TypeSerializerSnapshot oldSerializerSnapshot) { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } + } + } + + private static class FlinkValueState implements ValueState { + + private final StateNamespace namespace; + private final String stateId; + private final ValueStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkValueState( + KeyedStateBackend flinkStateBackend, + String stateId, + StateNamespace namespace, + Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { + + this.namespace = namespace; + this.stateId = stateId; + this.flinkStateBackend = flinkStateBackend; + this.namespaceSerializer = namespaceSerializer; + + flinkStateDescriptor = + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); + } + + @Override + public void write(T input) { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .update(input); + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + @Override + public ValueState readLater() { + return this; + } + + @Override + public T read() { + try { + return flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkValueState that = (FlinkValueState) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + private static class FlinkOrderedListState implements OrderedListState { + private final StateNamespace namespace; + private final ListStateDescriptor> flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkOrderedListState( + KeyedStateBackend flinkStateBackend, + String stateId, + 
StateNamespace namespace, + Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { + this.namespace = namespace; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = + new ListStateDescriptor<>( + stateId, new CoderTypeSerializer<>(TimestampedValueCoder.of(coder), fasterCopy)); + this.namespaceSerializer = namespaceSerializer; + } + + @Override + public Iterable> readRange(Instant minTimestamp, Instant limitTimestamp) { + return readAsMap().subMap(minTimestamp, limitTimestamp).values(); + } + + @Override + public void clearRange(Instant minTimestamp, Instant limitTimestamp) { + SortedMap> sortedMap = readAsMap(); + sortedMap.subMap(minTimestamp, limitTimestamp).clear(); + try { + ListState> partitionedState = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + partitionedState.update(Lists.newArrayList(sortedMap.values())); + } catch (Exception e) { + throw new RuntimeException("Error adding to bag state.", e); + } + } + + @Override + public OrderedListState readRangeLater(Instant minTimestamp, Instant limitTimestamp) { + return this; + } + + @Override + public void add(TimestampedValue value) { + try { + ListState> partitionedState = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + partitionedState.add(value); + } catch (Exception e) { + throw new RuntimeException("Error adding to bag state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + Iterable> result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .get(); + return result == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + @Nullable + public Iterable> read() { + return readAsMap().values(); + } + + private SortedMap> readAsMap() { + Iterable> listValues; + try { + ListState> partitionedState = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + listValues = MoreObjects.firstNonNull(partitionedState.get(), Collections.emptyList()); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + SortedMap> sortedMap = Maps.newTreeMap(); + for (TimestampedValue value : listValues) { + sortedMap.put(value.getTimestamp(), value); + } + return sortedMap; + } + + @Override + public GroupingState, Iterable>> readLater() { + return this; + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + } + + private static class FlinkBagState implements BagState { + + private final StateNamespace namespace; + private final String stateId; + private final ListStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final boolean storesVoidValues; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkBagState( + KeyedStateBackend flinkStateBackend, + String stateId, + StateNamespace namespace, + Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { + + this.namespace = namespace; + this.stateId = stateId; + this.flinkStateBackend = 
flinkStateBackend; + this.storesVoidValues = coder instanceof VoidCoder; + this.flinkStateDescriptor = + new ListStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); + this.namespaceSerializer = namespaceSerializer; + } + + @Override + public void add(T input) { + try { + ListState partitionedState = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + if (storesVoidValues) { + Preconditions.checkState(input == null, "Expected to a null value but was: %s", input); + // Flink does not allow storing null values + // If we have null values, we use the structural null value + input = (T) VoidCoder.of().structuralValue((Void) input); + } + partitionedState.add(input); + } catch (Exception e) { + throw new RuntimeException("Error adding to bag state.", e); + } + } + + @Override + public BagState readLater() { + return this; + } + + @Override + @Nonnull + public Iterable read() { + try { + ListState partitionedState = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + Iterable result = partitionedState.get(); + if (storesVoidValues) { + return () -> { + final Iterator underlying = result.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + return underlying.hasNext(); + } + + @Override + public T next() { + // Simply move the iterator forward but ignore the value. + // The value can be the structural null value or NULL itself, + // if this has been restored from serialized state. + underlying.next(); + return null; + } + }; + }; + } + return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + Iterable result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .get(); + return result == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBagState that = (FlinkBagState) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + private static class FlinkCombiningState + implements CombiningState { + + private final StateNamespace namespace; + private final String stateId; + private final Combine.CombineFn combineFn; + private final ValueStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkCombiningState( + KeyedStateBackend flinkStateBackend, + String stateId, + Combine.CombineFn combineFn, + StateNamespace namespace, + Coder accumCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { + + this.namespace = namespace; + 
this.stateId = stateId; + this.combineFn = combineFn; + this.flinkStateBackend = flinkStateBackend; + this.namespaceSerializer = namespaceSerializer; + + flinkStateDescriptor = + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); + } + + @Override + public CombiningState readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + org.apache.flink.api.common.state.ValueState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + current = combineFn.createAccumulator(); + } + current = combineFn.addInput(current, value); + state.update(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + org.apache.flink.api.common.state.ValueState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + state.update(accum); + } else { + current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum)); + state.update(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + AccumT accum = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .value(); + return accum != null ? accum : combineFn.createAccumulator(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT read() { + try { + org.apache.flink.api.common.state.ValueState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + + AccumT accum = state.value(); + if (accum != null) { + return combineFn.extractOutput(accum); + } else { + return combineFn.extractOutput(combineFn.createAccumulator()); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + return flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .value() + == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkCombiningState that = (FlinkCombiningState) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + private static class FlinkCombiningStateWithContext + implements CombiningState { + + private final StateNamespace namespace; + private final String stateId; + 
private final CombineWithContext.CombineFnWithContext combineFn; + private final ValueStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final CombineWithContext.Context context; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkCombiningStateWithContext( + KeyedStateBackend flinkStateBackend, + String stateId, + CombineWithContext.CombineFnWithContext combineFn, + StateNamespace namespace, + Coder accumCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + CombineWithContext.Context context, + boolean fasterCopy) { + + this.namespace = namespace; + this.stateId = stateId; + this.combineFn = combineFn; + this.flinkStateBackend = flinkStateBackend; + this.context = context; + this.namespaceSerializer = namespaceSerializer; + + flinkStateDescriptor = + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); + } + + @Override + public CombiningState readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + org.apache.flink.api.common.state.ValueState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + current = combineFn.createAccumulator(context); + } + current = combineFn.addInput(current, value, context); + state.update(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + org.apache.flink.api.common.state.ValueState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + state.update(accum); + } else { + current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum), context); + state.update(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + AccumT accum = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .value(); + return accum != null ? 
accum : combineFn.createAccumulator(context); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(accumulators, context); + } + + @Override + public OutputT read() { + try { + org.apache.flink.api.common.state.ValueState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + + AccumT accum = state.value(); + if (accum != null) { + return combineFn.extractOutput(accum, context); + } else { + return combineFn.extractOutput(combineFn.createAccumulator(context), context); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + return flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .value() + == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkCombiningStateWithContext that = + (FlinkCombiningStateWithContext) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + private class FlinkWatermarkHoldState implements WatermarkHoldState { + + private final TimestampCombiner timestampCombiner; + private final String namespaceString; + private org.apache.flink.api.common.state.MapState watermarkHoldsState; + + public FlinkWatermarkHoldState( + KeyedStateBackend flinkStateBackend, + MapStateDescriptor watermarkHoldStateDescriptor, + String stateId, + StateNamespace namespace, + TimestampCombiner timestampCombiner) { + this.timestampCombiner = timestampCombiner; + // Combines StateNamespace and stateId to generate a unique namespace for + // watermarkHoldsState. We do not want to use Flink's namespacing to be + // able to recover watermark holds efficiently during recovery. 
+ this.namespaceString = namespace.stringKey() + stateId; + try { + this.watermarkHoldsState = + flinkStateBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + watermarkHoldStateDescriptor); + } catch (Exception e) { + throw new RuntimeException("Could not access state for watermark partition view"); + } + } + + @Override + public TimestampCombiner getTimestampCombiner() { + return timestampCombiner; + } + + @Override + public WatermarkHoldState readLater() { + return this; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + return watermarkHoldsState.get(namespaceString) == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void add(Instant value) { + try { + Instant current = watermarkHoldsState.get(namespaceString); + if (current == null) { + addWatermarkHoldUsage(value); + watermarkHoldsState.put(namespaceString, value); + } else { + Instant combined = timestampCombiner.combine(current, value); + if (combined.getMillis() != current.getMillis()) { + removeWatermarkHoldUsage(current); + addWatermarkHoldUsage(combined); + watermarkHoldsState.put(namespaceString, combined); + } + } + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + @Override + public Instant read() { + try { + return watermarkHoldsState.get(namespaceString); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public void clear() { + Instant current = read(); + if (current != null) { + removeWatermarkHoldUsage(current); + } + try { + watermarkHoldsState.remove(namespaceString); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkWatermarkHoldState that = (FlinkWatermarkHoldState) o; + + if (!timestampCombiner.equals(that.timestampCombiner)) { + return false; + } + return namespaceString.equals(that.namespaceString); + } + + @Override + public int hashCode() { + int result = namespaceString.hashCode(); + result = 31 * result + timestampCombiner.hashCode(); + return result; + } + } + + private static class FlinkMapState implements MapState { + + private final StateNamespace namespace; + private final String stateId; + private final MapStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkMapState( + KeyedStateBackend flinkStateBackend, + String stateId, + StateNamespace namespace, + Coder mapKeyCoder, + Coder mapValueCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { + this.namespace = namespace; + this.stateId = stateId; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = + new MapStateDescriptor<>( + stateId, + new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), + new CoderTypeSerializer<>(mapValueCoder, fasterCopy)); + this.namespaceSerializer = namespaceSerializer; + } + + @Override + public ReadableState get(final KeyT input) { + return getOrDefault(input, null); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState getOrDefault( + KeyT key, @Nullable 
ValueT defaultValue) { + return new ReadableState() { + @Override + public @Nullable ValueT read() { + try { + ValueT value = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .get(key); + return (value != null) ? value : defaultValue; + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + return this; + } + }; + } + + @Override + public void put(KeyT key, ValueT value) { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .put(key, value); + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } + } + + @Override + public ReadableState computeIfAbsent( + final KeyT key, Function mappingFunction) { + try { + ValueT current = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .get(key); + + if (current == null) { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .put(key, mappingFunction.apply(key)); + } + return ReadableStates.immediate(current); + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } + } + + @Override + public void remove(KeyT key) { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .remove(key); + } catch (Exception e) { + throw new RuntimeException("Error remove map state key.", e); + } + } + + @Override + public ReadableState> keys() { + return new ReadableState>() { + @Override + public Iterable read() { + try { + Iterable result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .keys(); + return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error get map state keys.", e); + } + } + + @Override + public ReadableState> readLater() { + return this; + } + }; + } + + @Override + public ReadableState> values() { + return new ReadableState>() { + @Override + public Iterable read() { + try { + Iterable result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .values(); + return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error get map state values.", e); + } + } + + @Override + public ReadableState> readLater() { + return this; + } + }; + } + + @Override + public ReadableState>> entries() { + return new ReadableState>>() { + @Override + public Iterable> read() { + try { + Iterable> result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .entries(); + return result != null ? 
ImmutableList.copyOf(result) : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error get map state entries.", e); + } + } + + @Override + public ReadableState>> readLater() { + return this; + } + }; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + isEmpty() { + ReadableState> keys = this.keys(); + return new ReadableState() { + @Override + public @Nullable Boolean read() { + return Iterables.isEmpty(keys.read()); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + keys.readLater(); + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkMapState that = (FlinkMapState) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + private static class FlinkSetState implements SetState { + + private final StateNamespace namespace; + private final String stateId; + private final MapStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + FlinkSetState( + KeyedStateBackend flinkStateBackend, + String stateId, + StateNamespace namespace, + Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { + this.namespace = namespace; + this.stateId = stateId; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = + new MapStateDescriptor<>( + stateId, new CoderTypeSerializer<>(coder, fasterCopy), BooleanSerializer.INSTANCE); + this.namespaceSerializer = namespaceSerializer; + } + + @Override + public ReadableState contains(final T t) { + try { + Boolean result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .get(t); + return ReadableStates.immediate(result != null && result); + } catch (Exception e) { + throw new RuntimeException("Error contains value from state.", e); + } + } + + @Override + public ReadableState addIfAbsent(final T t) { + try { + org.apache.flink.api.common.state.MapState state = + flinkStateBackend.getPartitionedState( + namespace, namespaceSerializer, flinkStateDescriptor); + boolean alreadyContained = state.contains(t); + if (!alreadyContained) { + state.put(t, true); + } + return ReadableStates.immediate(!alreadyContained); + } catch (Exception e) { + throw new RuntimeException("Error addIfAbsent value to state.", e); + } + } + + @Override + public void remove(T t) { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .remove(t); + } catch (Exception e) { + throw new RuntimeException("Error remove value to state.", e); + } + } + + @Override + public SetState readLater() { + return this; + } + + @Override + public void add(T value) { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .put(value, true); + } catch (Exception e) { + throw new 
RuntimeException("Error add value to state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + Iterable result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .keys(); + return result == null || Iterables.isEmpty(result); + } catch (Exception e) { + throw new RuntimeException("Error isEmpty from state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public Iterable read() { + try { + Iterable result = + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .keys(); + return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error read from state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkSetState that = (FlinkSetState) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + public void addWatermarkHoldUsage(Instant watermarkHold) { + watermarkHolds.add(watermarkHold.getMillis()); + } + + public void removeWatermarkHoldUsage(Instant watermarkHold) { + watermarkHolds.remove(watermarkHold.getMillis()); + } + + /** Restores a view of the watermark holds of all keys of this partition. */ + private void restoreWatermarkHoldsView() throws Exception { + org.apache.flink.api.common.state.MapState mapState = + flinkStateBackend.getPartitionedState( + VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, watermarkHoldStateDescriptor); + try (Stream keys = + flinkStateBackend.getKeys(watermarkHoldStateDescriptor.getName(), VoidNamespace.INSTANCE)) { + Iterator iterator = keys.iterator(); + while (iterator.hasNext()) { + flinkStateBackend.setCurrentKey(iterator.next()); + mapState.values().forEach(this::addWatermarkHoldUsage); + } + } + } + + /** Eagerly create user state to work around https://jira.apache.org/jira/browse/FLINK-12653. 
*/ + public static class EarlyBinder implements StateBinder { + + private final KeyedStateBackend keyedStateBackend; + private final Boolean fasterCopy; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; + + public EarlyBinder( + KeyedStateBackend keyedStateBackend, + SerializablePipelineOptions pipelineOptions, + Coder windowCoder) { + this.keyedStateBackend = keyedStateBackend; + this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + this.namespaceSerializer = new FlinkStateNamespaceKeySerializer(windowCoder); + } + + @Override + public ValueState bindValue(String id, StateSpec> spec, Coder coder) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return null; + } + + @Override + public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return null; + } + + @Override + public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new MapStateDescriptor<>( + id, new CoderTypeSerializer<>(elemCoder, fasterCopy), BooleanSerializer.INSTANCE)); + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + } + + @Override + public org.apache.beam.sdk.state.MapState bindMap( + String id, + StateSpec> spec, + Coder mapKeyCoder, + Coder mapValueCoder) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new MapStateDescriptor<>( + id, + new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), + new CoderTypeSerializer<>(mapValueCoder, fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + } + + @Override + public OrderedListState bindOrderedList( + String id, StateSpec> spec, Coder elemCoder) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new ListStateDescriptor<>( + id, new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return null; + } + + @Override + public MultimapState bindMultimap( + String id, + StateSpec> spec, + Coder keyCoder, + Coder valueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MultimapState.class.getSimpleName())); + } + + @Override + public CombiningState bindCombining( + String id, + StateSpec> spec, + Coder accumCoder, + Combine.CombineFn combineFn) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + } + + @Override + public + CombiningState bindCombiningWithContext( + String id, + StateSpec> spec, + Coder accumCoder, + CombineWithContext.CombineFnWithContext combineFn) { + try { + keyedStateBackend.getOrCreateKeyedState( + namespaceSerializer, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + } + + @Override + public WatermarkHoldState bindWatermark( + String id, StateSpec spec, 
TimestampCombiner timestampCombiner) { + try { + keyedStateBackend.getOrCreateKeyedState( + VoidNamespaceSerializer.INSTANCE, + new MapStateDescriptor<>( + "watermark-holds", + StringSerializer.INSTANCE, + new CoderTypeSerializer<>(InstantCoder.of(), fasterCopy))); + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/EncodedValueComparatorTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/EncodedValueComparatorTest.java new file mode 100644 index 000000000000..2aad3903f848 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/EncodedValueComparatorTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.flink.translation.types.EncodedValueComparator; +import org.apache.beam.runners.flink.translation.types.EncodedValueTypeInformation; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.serialization.SerializerConfigImpl; +import org.apache.flink.api.common.typeutils.ComparatorTestBase; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.junit.Assert; + +/** Test for {@link EncodedValueComparator}. 
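+ * Runs Flink's {@code ComparatorTestBase} contract against UTF-8 encoded strings; the entries
+ * returned by {@code getSortedTestData()} are listed in the ascending order the encoded-value
+ * comparator is expected to produce.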
*/ +public class EncodedValueComparatorTest extends ComparatorTestBase { + + @Override + protected TypeComparator createComparator(boolean ascending) { + return new EncodedValueTypeInformation().createComparator(ascending, new ExecutionConfig()); + } + + @Override + protected TypeSerializer createSerializer() { + return new EncodedValueTypeInformation().createSerializer(new SerializerConfigImpl()); + } + + @Override + protected void deepEquals(String message, byte[] should, byte[] is) { + Assert.assertArrayEquals(message, should, is); + } + + @Override + protected byte[][] getSortedTestData() { + StringUtf8Coder coder = StringUtf8Coder.of(); + + try { + return new byte[][] { + CoderUtils.encodeToByteArray(coder, ""), + CoderUtils.encodeToByteArray(coder, "Lorem Ipsum Dolor Omit Longer"), + CoderUtils.encodeToByteArray(coder, "aaaa"), + CoderUtils.encodeToByteArray(coder, "abcd"), + CoderUtils.encodeToByteArray(coder, "abce"), + CoderUtils.encodeToByteArray(coder, "abdd"), + CoderUtils.encodeToByteArray(coder, "accd"), + CoderUtils.encodeToByteArray(coder, "bbcd") + }; + } catch (CoderException e) { + throw new RuntimeException("Could not encode values.", e); + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkExecutionEnvironmentsTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkExecutionEnvironmentsTest.java new file mode 100644 index 000000000000..83b8719811e0 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkExecutionEnvironmentsTest.java @@ -0,0 +1,582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.StateBackendOptions; +import org.apache.flink.streaming.api.environment.LocalStreamEnvironment; +import org.apache.flink.streaming.api.environment.RemoteStreamEnvironment; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.powermock.reflect.Whitebox; + +/** Tests for {@link FlinkExecutionEnvironments}. */ +@RunWith(Parameterized.class) +public class FlinkExecutionEnvironmentsTest { + + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Parameterized.Parameter public boolean useDataStreamForBatch; + + @Parameterized.Parameters(name = "UseDataStreamForBatch = {0}") + public static Collection useDataStreamForBatchJobValues() { + return Arrays.asList(new Object[][] {{false}, {true}}); + } + + private FlinkPipelineOptions getDefaultPipelineOptions() { + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setUseDataStreamForBatch(useDataStreamForBatch); + return options; + } + + @Test + public void shouldSetParallelismBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setParallelism(42); + + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + assertThat(options.getParallelism(), is(42)); + assertThat(bev.getParallelism(), is(42)); + } + + @Test + public void shouldSetParallelismStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setParallelism(42); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertThat(options.getParallelism(), is(42)); + assertThat(sev.getParallelism(), is(42)); + } + + @Test + public void shouldSetMaxParallelismStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setMaxParallelism(42); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertThat(options.getMaxParallelism(), is(42)); + assertThat(sev.getMaxParallelism(), is(42)); + } + + @Test + public void shouldInferParallelismFromEnvironmentBatch() throws IOException { + String flinkConfDir = extractFlinkConfig(); + + FlinkPipelineOptions options = 
getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("host:80"); + + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment( + options, Collections.emptyList(), flinkConfDir); + + assertThat(options.getParallelism(), is(23)); + assertThat(bev.getParallelism(), is(23)); + } + + @Test + public void shouldInferParallelismFromEnvironmentStreaming() throws IOException { + String confDir = extractFlinkConfig(); + + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("host:80"); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment( + options, Collections.emptyList(), confDir); + + assertThat(options.getParallelism(), is(23)); + assertThat(sev.getParallelism(), is(23)); + } + + @Test + public void shouldFallbackToDefaultParallelismBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("host:80"); + + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + assertThat(options.getParallelism(), is(1)); + assertThat(bev.getParallelism(), is(1)); + } + + @Test + public void shouldFallbackToDefaultParallelismStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("host:80"); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertThat(options.getParallelism(), is(1)); + assertThat(sev.getParallelism(), is(1)); + } + + @Test + public void useDefaultParallelismFromContextBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + assertThat(bev, instanceOf(LocalStreamEnvironment.class)); + assertThat(options.getParallelism(), is(LocalStreamEnvironment.getDefaultLocalParallelism())); + assertThat(bev.getParallelism(), is(LocalStreamEnvironment.getDefaultLocalParallelism())); + } + + @Test + public void useDefaultParallelismFromContextStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertThat(sev, instanceOf(LocalStreamEnvironment.class)); + assertThat(options.getParallelism(), is(LocalStreamEnvironment.getDefaultLocalParallelism())); + assertThat(sev.getParallelism(), is(LocalStreamEnvironment.getDefaultLocalParallelism())); + } + + @Test + public void shouldParsePortForRemoteEnvironmentBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setFlinkMaster("host:1234"); + + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + assertThat(bev, instanceOf(RemoteStreamEnvironment.class)); + checkHostAndPort(bev, "host", 1234); + } + + @Test + public void shouldParsePortForRemoteEnvironmentStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setFlinkMaster("host:1234"); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + 
assertThat(sev, instanceOf(RemoteStreamEnvironment.class)); + checkHostAndPort(sev, "host", 1234); + } + + @Test + public void shouldAllowPortOmissionForRemoteEnvironmentBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setFlinkMaster("host"); + + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + assertThat(bev, instanceOf(RemoteStreamEnvironment.class)); + checkHostAndPort(bev, "host", RestOptions.PORT.defaultValue()); + } + + @Test + public void shouldAllowPortOmissionForRemoteEnvironmentStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setFlinkMaster("host"); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertThat(sev, instanceOf(RemoteStreamEnvironment.class)); + checkHostAndPort(sev, "host", RestOptions.PORT.defaultValue()); + } + + @Test + public void shouldTreatAutoAndEmptyHostTheSameBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + options.setFlinkMaster("[auto]"); + + StreamExecutionEnvironment sev2 = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + + assertEquals(sev.getClass(), sev2.getClass()); + } + + @Test + public void shouldTreatAutoAndEmptyHostTheSameStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + options.setFlinkMaster("[auto]"); + + StreamExecutionEnvironment sev2 = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertEquals(sev.getClass(), sev2.getClass()); + } + + @Test + public void shouldDetectMalformedPortBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setFlinkMaster("host:p0rt"); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Unparseable port number"); + + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + } + + @Test + public void shouldDetectMalformedPortStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setFlinkMaster("host:p0rt"); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Unparseable port number"); + + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + } + + @Test + public void shouldSupportIPv4Batch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + options.setFlinkMaster("192.168.1.1:1234"); + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort(bev, "192.168.1.1", 1234); + + options.setFlinkMaster("192.168.1.1"); + bev = FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort(bev, "192.168.1.1", RestOptions.PORT.defaultValue()); + } + + @Test + public void shouldSupportIPv4Streaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + options.setFlinkMaster("192.168.1.1:1234"); + 
StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort(bev, "192.168.1.1", 1234); + + options.setFlinkMaster("192.168.1.1"); + bev = FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort(bev, "192.168.1.1", RestOptions.PORT.defaultValue()); + } + + @Test + public void shouldSupportIPv6Batch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + options.setFlinkMaster("[FE80:CD00:0000:0CDE:1257:0000:211E:729C]:1234"); + StreamExecutionEnvironment bev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort(bev, "FE80:CD00:0000:0CDE:1257:0000:211E:729C", 1234); + + options.setFlinkMaster("FE80:CD00:0000:0CDE:1257:0000:211E:729C"); + bev = FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort( + bev, "FE80:CD00:0000:0CDE:1257:0000:211E:729C", RestOptions.PORT.defaultValue()); + } + + @Test + public void shouldSupportIPv6Streaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + options.setFlinkMaster("[FE80:CD00:0000:0CDE:1257:0000:211E:729C]:1234"); + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + checkHostAndPort(sev, "FE80:CD00:0000:0CDE:1257:0000:211E:729C", 1234); + + options.setFlinkMaster("FE80:CD00:0000:0CDE:1257:0000:211E:729C"); + sev = FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + checkHostAndPort( + sev, "FE80:CD00:0000:0CDE:1257:0000:211E:729C", RestOptions.PORT.defaultValue()); + } + + @Test + public void shouldRemoveHttpProtocolFromHostBatch() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + for (String flinkMaster : + new String[] { + "http://host:1234", " http://host:1234", "https://host:1234", " https://host:1234" + }) { + options.setFlinkMaster(flinkMaster); + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createBatchExecutionEnvironment(options); + checkHostAndPort(sev, "host", 1234); + } + } + + @Test + public void shouldRemoveHttpProtocolFromHostStreaming() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + + for (String flinkMaster : + new String[] { + "http://host:1234", " http://host:1234", "https://host:1234", " https://host:1234" + }) { + options.setFlinkMaster(flinkMaster); + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + checkHostAndPort(sev, "host", 1234); + } + } + + private String extractFlinkConfig() throws IOException { + InputStream inputStream = getClass().getResourceAsStream("/flink-test-config.yaml"); + File root = temporaryFolder.getRoot(); + Files.copy(inputStream, new File(root, "config.yaml").toPath()); + return root.getAbsolutePath(); + } + + @Test + public void shouldAutoSetIdleSourcesFlagWithoutCheckpointing() { + // Checkpointing disabled, shut down sources immediately + FlinkPipelineOptions options = getDefaultPipelineOptions(); + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + assertThat(options.getShutdownSourcesAfterIdleMs(), is(0L)); + } + + @Test + public void shouldAutoSetIdleSourcesFlagWithCheckpointing() { + // Checkpointing is enabled, never shut down sources + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setCheckpointingInterval(1000L); 
+ FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + assertThat(options.getShutdownSourcesAfterIdleMs(), is(Long.MAX_VALUE)); + } + + @Test + public void shouldAcceptExplicitlySetIdleSourcesFlagWithoutCheckpointing() { + // Checkpointing disabled, accept flag + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setShutdownSourcesAfterIdleMs(42L); + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + assertThat(options.getShutdownSourcesAfterIdleMs(), is(42L)); + } + + @Test + public void shouldAcceptExplicitlySetIdleSourcesFlagWithCheckpointing() { + // Checkpointing enable, still accept flag + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setCheckpointingInterval(1000L); + options.setShutdownSourcesAfterIdleMs(42L); + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + assertThat(options.getShutdownSourcesAfterIdleMs(), is(42L)); + } + + @Test + public void shouldSetSavepointRestoreForRemoteStreaming() { + String path = "fakePath"; + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("host:80"); + options.setSavepointPath(path); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + // subject to change with https://issues.apache.org/jira/browse/FLINK-11048 + assertThat(sev, instanceOf(RemoteStreamEnvironment.class)); + assertThat(getSavepointPath(sev), is(path)); + } + + @Test + public void shouldFailOnUnknownStateBackend() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setStreaming(true); + options.setStateBackend("unknown"); + options.setStateBackendStoragePath("/path"); + + assertThrows( + "State backend was set to 'unknown' but no storage path was provided.", + IllegalArgumentException.class, + () -> FlinkExecutionEnvironments.createStreamExecutionEnvironment(options)); + } + + @Test + public void shouldFailOnNoStoragePathProvided() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setStreaming(true); + options.setStateBackend("unknown"); + + assertThrows( + "State backend was set to 'unknown' but no storage path was provided.", + IllegalArgumentException.class, + () -> FlinkExecutionEnvironments.createStreamExecutionEnvironment(options)); + } + + @Test + public void shouldCreateFileSystemStateBackend() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setStreaming(true); + options.setStateBackend("fileSystem"); + options.setStateBackendStoragePath(temporaryFolder.getRoot().toURI().toString()); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertEquals("hashmap", sev.getConfiguration().get(StateBackendOptions.STATE_BACKEND)); + } + + @Test + public void shouldCreateRocksDbStateBackend() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setStreaming(true); + options.setStateBackend("rocksDB"); + options.setStateBackendStoragePath(temporaryFolder.getRoot().toURI().toString()); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + assertEquals("rocksdb", sev.getConfiguration().get(StateBackendOptions.STATE_BACKEND)); + } + + /** Test interface. 
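+ * Registered in {@code shouldSetWebUIOptions} to check that user-defined options
+ * ({@code key1}, {@code key2}, {@code key3}) are forwarded to the global job parameters that
+ * back the Flink web UI.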
*/ + public interface TestOptions extends PipelineOptions { + String getKey1(); + + void setKey1(String value); + + Boolean getKey2(); + + void setKey2(Boolean value); + + String getKey3(); + + void setKey3(String value); + } + + @Test + public void shouldSetWebUIOptions() { + PipelineOptionsFactory.register(TestOptions.class); + PipelineOptionsFactory.register(FlinkPipelineOptions.class); + + FlinkPipelineOptions options = + PipelineOptionsFactory.fromArgs( + "--key1=value1", + "--key2", + "--key3=", + "--parallelism=10", + "--checkpointTimeoutMillis=500") + .as(FlinkPipelineOptions.class); + + StreamExecutionEnvironment sev = + FlinkExecutionEnvironments.createStreamExecutionEnvironment(options); + + Map actualMap = sev.getConfig().getGlobalJobParameters().toMap(); + + Map expectedMap = new HashMap<>(); + expectedMap.put("key1", "value1"); + expectedMap.put("key2", "true"); + expectedMap.put("key3", ""); + expectedMap.put("checkpointTimeoutMillis", "500"); + expectedMap.put("parallelism", "10"); + + Map filteredMap = + expectedMap.entrySet().stream() + .filter( + kv -> + actualMap.containsKey(kv.getKey()) + && kv.getValue().equals(actualMap.get(kv.getKey()))) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + + assertTrue(expectedMap.size() == filteredMap.size()); + } + + private void checkHostAndPort(Object env, String expectedHost, int expectedPort) { + String host = + ((Configuration) Whitebox.getInternalState(env, "configuration")).get(RestOptions.ADDRESS); + int port = + ((Configuration) Whitebox.getInternalState(env, "configuration")).get(RestOptions.PORT); + assertThat( + new InetSocketAddress(host, port), is(new InetSocketAddress(expectedHost, expectedPort))); + } + + private String getSavepointPath(Object env) { + // pre Flink 1.20 config + String path = + ((Configuration) Whitebox.getInternalState(env, "configuration")) + .getString("execution.savepoint.path", null); + if (path == null) { + // Flink 1.20+ + path = + ((Configuration) Whitebox.getInternalState(env, "configuration")) + .getString("execution.state-recovery.path", null); + } + return path; + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java new file mode 100644 index 000000000000..64ea685e8950 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static org.apache.beam.sdk.testing.RegexMatcher.matches; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.core.Every.everyItem; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.io.Serializable; +import java.lang.reflect.Method; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.construction.PTransformMatchers; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.resources.PipelineResources; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.environment.RemoteStreamEnvironment; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.powermock.reflect.Whitebox; + +/** Tests for {@link FlinkPipelineExecutionEnvironment}. 
*/ +@RunWith(JUnit4.class) +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class FlinkPipelineExecutionEnvironmentTest implements Serializable { + + @Rule public transient TemporaryFolder tmpFolder = new TemporaryFolder(); + + private FlinkPipelineOptions getDefaultPipelineOptions() { + return FlinkPipelineOptions.defaults(); + } + + @Test + public void shouldRecognizeAndTranslateStreamingPipeline() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("[auto]"); + + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline pipeline = Pipeline.create(); + + pipeline + .apply(GenerateSequence.from(0).withRate(1, Duration.standardSeconds(1))) + .apply( + ParDo.of( + new DoFn() { + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output(Long.toString(c.element())); + } + })) + .apply(Window.into(FixedWindows.of(Duration.standardHours(1)))) + .apply(TextIO.write().withNumShards(1).withWindowedWrites().to("/dummy/path")); + + flinkEnv.translate(pipeline); + + // no exception should be thrown + } + + @Test + public void shouldPrepareFilesToStageWhenFlinkMasterIsSetExplicitly() throws IOException { + FlinkPipelineOptions options = testPreparingResourcesToStage("localhost:8081", true, false); + + assertThat(options.getFilesToStage().size(), is(2)); + assertThat(options.getFilesToStage().get(0), matches(".*\\.jar")); + } + + @Test + public void shouldFailWhenFileDoesNotExistAndFlinkMasterIsSetExplicitly() { + assertThrows( + "To-be-staged file does not exist: ", + IllegalStateException.class, + () -> testPreparingResourcesToStage("localhost:8081", true, true)); + } + + @Test + public void shouldNotPrepareFilesToStageWhenFlinkMasterIsSetToAuto() throws IOException { + FlinkPipelineOptions options = testPreparingResourcesToStage("[auto]"); + + assertThat(options.getFilesToStage().size(), is(2)); + assertThat(options.getFilesToStage(), everyItem(not(matches(".*\\.jar")))); + } + + @Test + public void shouldNotPrepareFilesToStageWhenFlinkMasterIsSetToLocal() throws IOException { + FlinkPipelineOptions options = testPreparingResourcesToStage("[local]"); + + assertThat(options.getFilesToStage().size(), is(2)); + assertThat(options.getFilesToStage(), everyItem(not(matches(".*\\.jar")))); + } + + @Test + public void shouldUseDefaultTempLocationIfNoneSet() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("clusterAddress"); + + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + + Pipeline pipeline = Pipeline.create(options); + flinkEnv.translate(pipeline); + + String defaultTmpDir = System.getProperty("java.io.tmpdir"); + + assertThat(options.getFilesToStage(), hasItem(startsWith(defaultTmpDir))); + } + + @Test + public void shouldUsePreparedFilesOnRemoteEnvironment() throws Exception { + shouldUsePreparedFilesOnRemoteStreamEnvironment(true); + shouldUsePreparedFilesOnRemoteStreamEnvironment(false); + } + + public void shouldUsePreparedFilesOnRemoteStreamEnvironment(boolean streamingMode) + throws Exception { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("clusterAddress"); + options.setStreaming(streamingMode); + + FlinkPipelineExecutionEnvironment flinkEnv = new 
FlinkPipelineExecutionEnvironment(options); + + Pipeline pipeline = Pipeline.create(options); + flinkEnv.translate(pipeline); + + List jarFiles; + + StreamExecutionEnvironment streamExecutionEnvironment = + flinkEnv.getStreamExecutionEnvironment(); + assertThat(streamExecutionEnvironment, instanceOf(RemoteStreamEnvironment.class)); + jarFiles = getJars(streamExecutionEnvironment); + List urlConvertedStagedFiles = convertFilesToURLs(options.getFilesToStage()); + + assertThat(jarFiles, is(urlConvertedStagedFiles)); + } + + @Test + public void shouldUseTransformOverrides() { + boolean[] testParameters = {true, false}; + for (boolean streaming : testParameters) { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setStreaming(streaming); + options.setRunner(FlinkRunner.class); + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline p = Mockito.spy(Pipeline.create(options)); + + flinkEnv.translate(p); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ImmutableList.class); + Mockito.verify(p).replaceAll(captor.capture()); + ImmutableList overridesList = captor.getValue(); + + assertThat(overridesList.isEmpty(), is(false)); + assertThat( + overridesList.size(), is(FlinkTransformOverrides.getDefaultOverrides(options).size())); + } + } + + @Test + public void shouldProvideParallelismToTransformOverrides() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setStreaming(true); + options.setRunner(FlinkRunner.class); + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline p = Pipeline.create(options); + // Create a transform applicable for PTransformMatchers.writeWithRunnerDeterminedSharding() + // which requires parallelism + p.apply(Create.of("test")).apply(TextIO.write().to("/tmp")); + p = Mockito.spy(p); + + // If this succeeds we're ok + flinkEnv.translate(p); + + // Verify we were using desired replacement transform + ArgumentCaptor captor = ArgumentCaptor.forClass(ImmutableList.class); + Mockito.verify(p).replaceAll(captor.capture()); + ImmutableList overridesList = captor.getValue(); + assertThat( + overridesList, + hasItem( + new BaseMatcher() { + @Override + public void describeTo(Description description) {} + + @Override + public boolean matches(Object actual) { + if (actual instanceof PTransformOverride) { + PTransformOverrideFactory overrideFactory = + ((PTransformOverride) actual).getOverrideFactory(); + if (overrideFactory + instanceof FlinkStreamingPipelineTranslator.StreamingShardedWriteFactory) { + FlinkStreamingPipelineTranslator.StreamingShardedWriteFactory factory = + (FlinkStreamingPipelineTranslator.StreamingShardedWriteFactory) + overrideFactory; + return factory.options.getParallelism() > 0; + } + } + return false; + } + })); + } + + @Test + public void shouldUseStreamingTransformOverridesWithUnboundedSources() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + // no explicit streaming mode set + options.setRunner(FlinkRunner.class); + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline p = Mockito.spy(Pipeline.create(options)); + + // Add unbounded source which will set the streaming mode to true + p.apply(GenerateSequence.from(0)); + + flinkEnv.translate(p); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ImmutableList.class); + Mockito.verify(p).replaceAll(captor.capture()); + ImmutableList overridesList = captor.getValue(); + + assertThat( + overridesList, + 
hasItem( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), + CreateStreamingFlinkView.Factory.INSTANCE))); + } + + @Test + public void testTranslationModeOverrideWithUnboundedSources() { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setStreaming(false); + + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline pipeline = Pipeline.create(options); + pipeline.apply(GenerateSequence.from(0)); + flinkEnv.translate(pipeline); + + assertThat(options.isStreaming(), Matchers.is(true)); + } + + @Test + public void testTranslationModeNoOverrideWithoutUnboundedSources() { + boolean[] testArgs = new boolean[] {true, false}; + for (boolean streaming : testArgs) { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(FlinkRunner.class); + options.setStreaming(streaming); + + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline pipeline = Pipeline.create(options); + pipeline.apply(GenerateSequence.from(0).to(10)); + flinkEnv.translate(pipeline); + + assertThat(options.isStreaming(), Matchers.is(streaming)); + } + } + + @Test + public void shouldLogWarningWhenCheckpointingIsDisabled() { + Pipeline pipeline = Pipeline.create(); + pipeline.getOptions().setRunner(TestFlinkRunner.class); + + pipeline + // Add an UnboundedSource to check for the warning if checkpointing is disabled + .apply(GenerateSequence.from(0)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext ctx) { + throw new RuntimeException("Failing here is ok."); + } + })); + + final PrintStream oldErr = System.err; + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + PrintStream replacementStdErr = new PrintStream(byteArrayOutputStream); + try { + System.setErr(replacementStdErr); + // Run pipeline and fail during execution + pipeline.run(); + fail("Should have failed"); + } catch (Exception e) { + // We want to fail here + } finally { + System.setErr(oldErr); + } + replacementStdErr.flush(); + assertThat( + new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8), + containsString( + "UnboundedSources present which rely on checkpointing, but checkpointing is disabled.")); + } + + private FlinkPipelineOptions testPreparingResourcesToStage(String flinkMaster) + throws IOException { + return testPreparingResourcesToStage(flinkMaster, false, true); + } + + private FlinkPipelineOptions testPreparingResourcesToStage( + String flinkMaster, boolean includeIndividualFile, boolean includeNonExisting) + throws IOException { + Pipeline pipeline = Pipeline.create(); + String tempLocation = tmpFolder.newFolder().getAbsolutePath(); + + List filesToStage = new ArrayList<>(); + + File stagingDir = tmpFolder.newFolder(); + File stageFile = new File(stagingDir, "stage"); + stageFile.createNewFile(); + filesToStage.add(stagingDir.getAbsolutePath()); + + if (includeIndividualFile) { + String temporaryLocation = tmpFolder.newFolder().getAbsolutePath(); + List filesToZip = new ArrayList<>(); + filesToZip.add(stagingDir.getAbsolutePath()); + File individualStagingFile = + new File(PipelineResources.prepareFilesForStaging(filesToZip, temporaryLocation).get(0)); + filesToStage.add(individualStagingFile.getAbsolutePath()); + } + + if (includeNonExisting) { + filesToStage.add("/path/to/not/existing/dir"); + } + + FlinkPipelineOptions 
options = setPipelineOptions(flinkMaster, tempLocation, filesToStage); + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + flinkEnv.translate(pipeline); + return options; + } + + private FlinkPipelineOptions setPipelineOptions( + String flinkMaster, String tempLocation, List filesToStage) { + FlinkPipelineOptions options = getDefaultPipelineOptions(); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster(flinkMaster); + options.setTempLocation(tempLocation); + options.setFilesToStage(filesToStage); + return options; + } + + private static List convertFilesToURLs(List filePaths) { + return filePaths.stream() + .map( + file -> { + try { + return new File(file).getAbsoluteFile().toURI().toURL(); + } catch (MalformedURLException e) { + throw new RuntimeException("Failed to convert to URL", e); + } + }) + .collect(Collectors.toList()); + } + + private List getJars(Object env) throws Exception { + Configuration config = Whitebox.getInternalState(env, "configuration"); + Class accesorClass = Class.forName("org.apache.flink.client.cli.ExecutionConfigAccessor"); + Method fromConfigurationMethod = + accesorClass.getDeclaredMethod("fromConfiguration", Configuration.class); + Object accesor = fromConfigurationMethod.invoke(null, config); + + Method getJarsMethod = accesorClass.getDeclaredMethod("getJars"); + return (List) getJarsMethod.invoke(accesor); + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java new file mode 100644 index 000000000000..f1e35fafe83b --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.hamcrest.core.IsNull.nullValue; + +import java.util.Collections; +import java.util.HashMap; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.SerializationUtils; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.serialization.SerializerConfigImpl; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.streaming.api.CheckpointingMode; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for serialization and deserialization of {@link PipelineOptions} in {@link DoFnOperator}. + */ +public class FlinkPipelineOptionsTest { + + /** Pipeline options. */ + public interface MyOptions extends FlinkPipelineOptions { + @Description("Bla bla bla") + @Default.String("Hello") + String getTestOption(); + + void setTestOption(String value); + } + + private static MyOptions options = + PipelineOptionsFactory.fromArgs("--testOption=nothing").as(MyOptions.class); + + /** These defaults should only be changed with a very good reason. 
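+ * A changed default silently changes behavior for every pipeline that does not set the option
+ * explicitly, so this test pins the current values.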
*/ + @Test + public void testDefaults() { + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + assertThat(options.getParallelism(), is(-1)); + assertThat(options.getMaxParallelism(), is(-1)); + assertThat(options.getFlinkMaster(), is("[auto]")); + assertThat(options.getFilesToStage(), is(nullValue())); + assertThat(options.getLatencyTrackingInterval(), is(0L)); + assertThat(options.getShutdownSourcesAfterIdleMs(), is(-1L)); + assertThat(options.getObjectReuse(), is(false)); + assertThat(options.getCheckpointingMode(), is(CheckpointingMode.EXACTLY_ONCE.name())); + assertThat(options.getMinPauseBetweenCheckpoints(), is(-1L)); + assertThat(options.getCheckpointingInterval(), is(-1L)); + assertThat(options.getCheckpointTimeoutMillis(), is(-1L)); + assertThat(options.getNumConcurrentCheckpoints(), is(1)); + assertThat(options.getTolerableCheckpointFailureNumber(), is(0)); + assertThat(options.getFinishBundleBeforeCheckpointing(), is(false)); + assertThat(options.getNumberOfExecutionRetries(), is(-1)); + assertThat(options.getExecutionRetryDelay(), is(-1L)); + assertThat(options.getRetainExternalizedCheckpointsOnCancellation(), is(false)); + assertThat(options.getStateBackendFactory(), is(nullValue())); + assertThat(options.getStateBackend(), is(nullValue())); + assertThat(options.getStateBackendStoragePath(), is(nullValue())); + assertThat(options.getExecutionModeForBatch(), is(FlinkPipelineOptions.PIPELINED)); + assertThat(options.getUseDataStreamForBatch(), is(false)); + assertThat(options.getSavepointPath(), is(nullValue())); + assertThat(options.getAllowNonRestoredState(), is(false)); + assertThat(options.getDisableMetrics(), is(false)); + assertThat(options.getFasterCopy(), is(false)); + + assertThat(options.isStreaming(), is(false)); + assertThat(options.getMaxBundleSize(), is(5000L)); + assertThat(options.getMaxBundleTimeMills(), is(10000L)); + + // In streaming mode bundle size and bundle time are shorter + FlinkPipelineOptions optionsStreaming = FlinkPipelineOptions.defaults(); + optionsStreaming.setStreaming(true); + assertThat(optionsStreaming.getMaxBundleSize(), is(1000L)); + assertThat(optionsStreaming.getMaxBundleTimeMills(), is(1000L)); + } + + @Test(expected = Exception.class) + public void parDoBaseClassPipelineOptionsNullTest() { + TupleTag mainTag = new TupleTag<>("main-output"); + Coder> coder = WindowedValues.getValueOnlyCoder(StringUtf8Coder.of()); + new DoFnOperator<>( + new TestDoFn(), + "stepName", + coder, + Collections.emptyMap(), + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + WindowingStrategy.globalDefault(), + new HashMap<>(), + Collections.emptyList(), + null, + null, /* key coder */ + null /* key selector */, + DoFnSchemaInformation.create(), + Collections.emptyMap()); + } + + /** Tests that PipelineOptions are present after serialization. 
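+ * The operator is round-tripped through Java serialization via {@code SerializationUtils} and
+ * then driven in a {@code OneInputStreamOperatorTestHarness}; {@code TestDoFn} asserts that the
+ * deserialized copy still sees the {@code testOption} value configured above.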
*/ + @Test + public void parDoBaseClassPipelineOptionsSerializationTest() throws Exception { + + TupleTag mainTag = new TupleTag<>("main-output"); + + Coder> coder = WindowedValues.getValueOnlyCoder(StringUtf8Coder.of()); + DoFnOperator doFnOperator = + new DoFnOperator<>( + new TestDoFn(), + "stepName", + coder, + Collections.emptyMap(), + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + WindowingStrategy.globalDefault(), + new HashMap<>(), + Collections.emptyList(), + options, + null, /* key coder */ + null /* key selector */, + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + final byte[] serialized = SerializationUtils.serialize(doFnOperator); + + @SuppressWarnings("unchecked") + DoFnOperator deserialized = SerializationUtils.deserialize(serialized); + + TypeInformation> typeInformation = + TypeInformation.of(new TypeHint>() {}); + + OneInputStreamOperatorTestHarness, WindowedValue> testHarness = + new OneInputStreamOperatorTestHarness<>( + deserialized, typeInformation.createSerializer(new SerializerConfigImpl())); + testHarness.open(); + + // execute once to access options + testHarness.processElement( + new StreamRecord<>( + WindowedValues.of( + new Object(), Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING))); + + testHarness.close(); + } + + private static class TestDoFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + Assert.assertNotNull(c.getPipelineOptions()); + Assert.assertEquals( + options.getTestOption(), c.getPipelineOptions().as(MyOptions.class).getTestOption()); + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java new file mode 100644 index 000000000000..b382cfeb6d22 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRequiresStableInputTest.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static org.apache.beam.sdk.testing.FileChecksumMatcher.fileContentsHaveChecksum; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import java.util.Collections; +import java.util.Date; +import java.util.Optional; +import java.util.concurrent.Executors; +import org.apache.beam.model.jobmanagement.v1.JobApi; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.jobsubmission.JobInvocation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.RequiresStableInputIT; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.fs.ResolveOptions; +import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PortablePipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.testing.CrashingRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.util.FilePatternMatchingShardedFile; +import org.apache.beam.sdk.util.construction.Environments; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListeningExecutorService; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.joda.time.Instant; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +/** Tests {@link DoFn.RequiresStableInput} with Flink. */ +public class FlinkRequiresStableInputTest { + + @ClassRule public static TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final String VALUE = "value"; + // SHA-1 hash of string "value" + private static final String VALUE_CHECKSUM = "f32b67c7e26342af42efabc674d441dca0a281c5"; + + private static ListeningExecutorService flinkJobExecutor; + private static final int PARALLELISM = 1; + private static final long CHECKPOINT_INTERVAL = 2000L; + private static final long FINISH_SOURCE_INTERVAL = 3 * CHECKPOINT_INTERVAL; + + @BeforeClass + public static void setup() { + // Restrict this to only one thread to avoid multiple Flink clusters up at the same time + // which is not suitable for memory-constraint environments, i.e. Jenkins. + flinkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); + } + + /** + * Test for the support of {@link DoFn.RequiresStableInput} in both {@link ParDo.SingleOutput} and + * {@link ParDo.MultiOutput}. + * + *
<p>
In each test, a singleton string value is paired with a random key. In the following + * transform, the value is written to a file whose path is specified by the random key, and then + * the transform fails. When the pipeline retries, the latter transform should receive the same + * input from the former transform, because its {@link DoFn} is annotated with {@link + * DoFn.RequiresStableInput}, and it will not fail due to the presence of the file. Therefore, only + * one file for each transform is expected. + * + *
<p>
A Savepoint is taken until the desired state in the operators has been reached. We then + * restore the savepoint to check if we produce impotent results. + */ + @Test(timeout = 30_000) + public void testParDoRequiresStableInput() throws Exception { + runTest(false); + } + + // Currently failing with duplicated "value" emitted (3 times) + @Ignore("https://github.com/apache/beam/issues/21333") + @Test(timeout = 30_000) + public void testParDoRequiresStableInputPortable() throws Exception { + runTest(true); + } + + @Test(timeout = 30_000) + public void testParDoRequiresStableInputStateful() throws Exception { + testParDoRequiresStableInputStateful(false); + } + + @Test(timeout = 30_000) + public void testParDoRequiresStableInputStatefulPortable() throws Exception { + testParDoRequiresStableInputStateful(true); + } + + private void testParDoRequiresStableInputStateful(boolean portable) throws Exception { + FlinkPipelineOptions opts = getFlinkOptions(portable); + opts.as(FlinkPipelineOptions.class).setShutdownSourcesAfterIdleMs(FINISH_SOURCE_INTERVAL); + opts.as(FlinkPipelineOptions.class).setNumberOfExecutionRetries(0); + Pipeline pipeline = Pipeline.create(opts); + PCollection result = + pipeline + .apply(Create.of(1, 2, 3, 4)) + .apply(WithKeys.of((Void) null)) + .apply(ParDo.of(new StableDoFn())); + PAssert.that(result).containsInAnyOrder(1, 2, 3, 4); + executePipeline(pipeline, portable); + } + + private void runTest(boolean portable) throws Exception { + FlinkPipelineOptions options = getFlinkOptions(portable); + + ResourceId outputDir = + FileSystems.matchNewResource(tempFolder.getRoot().getAbsolutePath(), true) + .resolve( + String.format("requires-stable-input-%tF-% sideEffect = + ign -> { + throw new IllegalStateException("Failing job to test @RequiresStableInput"); + }; + PCollection impulse = p.apply("CreatePCollectionOfOneValue", Create.of(VALUE)); + impulse + .apply( + "Single-PairWithRandomKey", + MapElements.via(new RequiresStableInputIT.PairWithRandomKeyFn())) + .apply( + "Single-MakeSideEffectAndThenFail", + ParDo.of( + new RequiresStableInputIT.MakeSideEffectAndThenFailFn( + singleOutputPrefix, sideEffect))); + impulse + .apply( + "Multi-PairWithRandomKey", + MapElements.via(new RequiresStableInputIT.PairWithRandomKeyFn())) + .apply( + "Multi-MakeSideEffectAndThenFail", + ParDo.of( + new RequiresStableInputIT.MakeSideEffectAndThenFailFn( + multiOutputPrefix, sideEffect)) + .withOutputTags(new TupleTag<>(), TupleTagList.empty())); + + return p; + } + + private FlinkPipelineOptions getFlinkOptions(boolean portable) { + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setParallelism(PARALLELISM); + options.setCheckpointingInterval(CHECKPOINT_INTERVAL); + options.setShutdownSourcesAfterIdleMs(FINISH_SOURCE_INTERVAL); + options.setFinishBundleBeforeCheckpointing(true); + options.setMaxBundleTimeMills(100L); + options.setStreaming(true); + if (portable) { + options.setRunner(CrashingRunner.class); + options + .as(PortablePipelineOptions.class) + .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED); + } else { + options.setRunner(FlinkRunner.class); + } + return options; + } + + private static class StableDoFn extends DoFn, Integer> { + + @StateId("state") + final StateSpec> stateSpec = StateSpecs.bag(); + + @TimerId("flush") + final TimerSpec flushSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + @RequiresStableInput + public void process( + @Element KV input, + @StateId("state") BagState buffer, + @TimerId("flush") Timer 
flush, + OutputReceiver output) { + + // Timers do not work with a stateful @RequiresStableInput DoFn, + // see https://github.com/apache/beam/issues/24662 + // Once this is resolved, flush the buffer on the timer + // flush.set(GlobalWindow.INSTANCE.maxTimestamp()); + // buffer.add(input.getValue()); + output.output(input.getValue()); + } + + @OnTimer("flush") + public void flush( + @Timestamp Instant ts, + @StateId("state") BagState buffer, + OutputReceiver output) { + + Optional.ofNullable(buffer.read()) + .ifPresent(b -> b.forEach(e -> output.outputWithTimestamp(e, ts))); + buffer.clear(); + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java new file mode 100644 index 000000000000..78a94b47244d --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkRunnerTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import static org.hamcrest.CoreMatchers.allOf; +import static org.junit.Assert.assertThrows; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.flink.client.program.PackagedProgram; +import org.apache.flink.client.program.PackagedProgramUtils; +import org.apache.flink.client.program.ProgramInvocationException; +import org.apache.flink.configuration.Configuration; +import org.hamcrest.MatcherAssert; +import org.hamcrest.core.StringContains; +import org.junit.Test; +
+/** Test for {@link FlinkRunner}.
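+ * Verifies that {@code System.out} and {@code System.err} are restored after Flink captures them while extracting the pipeline from a packaged program.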
*/ +public class FlinkRunnerTest { + + @Test + public void testEnsureStdoutStdErrIsRestored() throws Exception { + PackagedProgram packagedProgram = + PackagedProgram.newBuilder().setEntryPointClassName(getClass().getName()).build(); + int parallelism = Runtime.getRuntime().availableProcessors(); + // OptimizerPlanEnvironment Removed in Flink 2 + // OptimizerPlanEnvironment env = + // new OptimizerPlanEnvironment(new Configuration(), getClass().getClassLoader(), + // parallelism); + Exception e = + assertThrows( + ProgramInvocationException.class, + () -> { + // Flink will throw an error because no job graph will be generated by the main method + PackagedProgramUtils.getPipelineFromProgram( + packagedProgram, new Configuration(), parallelism, true); + }); + // Test that Flink wasn't able to intercept the stdout/stderr and we printed to the regular + // output instead + MatcherAssert.assertThat( + e.getMessage(), + allOf( + StringContains.containsString("System.out: "), + StringContains.containsString("System.err: "))); + } + + /** Main method for {@code testEnsureStdoutStdErrIsRestored()}. */ + public static void main(String[] args) { + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setRunner(NotExecutingFlinkRunner.class); + Pipeline p = Pipeline.create(options); + p.apply(GenerateSequence.from(0)); + + // This will call Workarounds.restoreOriginalStdOutAndStdErr() through the constructor of + // FlinkRunner + p.run(); + } + + private static class NotExecutingFlinkRunner extends FlinkRunner { + + protected NotExecutingFlinkRunner(FlinkPipelineOptions options) { + // Stdout/Stderr is restored here + super(options); + } + + @SuppressWarnings("unused") + public static NotExecutingFlinkRunner fromOptions(PipelineOptions options) { + return new NotExecutingFlinkRunner(options.as(FlinkPipelineOptions.class)); + } + + @Override + public PipelineResult run(Pipeline pipeline) { + // Do not execute to test the stdout printing + return null; + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSavepointTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSavepointTest.java new file mode 100644 index 000000000000..bcca529a64b9 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSavepointTest.java @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import static org.hamcrest.MatcherAssert.assertThat; + +import java.io.Serializable; +import java.net.URI; +import java.util.Collections; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.jobsubmission.JobInvocation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PortablePipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.InferableFunction; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.util.construction.Environments; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListeningExecutorService; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.JobStatus; +import org.apache.flink.configuration.CheckpointingOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.StateBackendOptions; +import org.apache.flink.core.execution.SavepointFormatType; +import org.apache.flink.runtime.client.JobStatusMessage; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.hamcrest.Matchers; +import org.hamcrest.core.IsIterableContaining; +import org.joda.time.Instant; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Tests that Flink's Savepoints work with the Flink Runner. This includes taking a savepoint of a + * running pipeline, shutting down the pipeline, and restarting the pipeline from the savepoint with + * a different parallelism. + */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + // TODO(https://github.com/apache/beam/issues/21230): Remove when new version of + // errorprone is released (2.11.0) + "unused" +}) +public class FlinkSavepointTest implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkSavepointTest.class); + + /** Flink cluster that runs over the lifespan of the tests. 
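+ * Started once in {@code beforeClass} and torn down in {@code afterClass}.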
*/ + private static transient MiniCluster flinkCluster; + + /** Static for synchronization between the pipeline state and the test. */ + private static volatile CountDownLatch oneShotLatch; + + /** Reusable executor for portable jobs. */ + private static ListeningExecutorService flinkJobExecutor; + + /** Temporary folder for savepoints. */ + @ClassRule public static transient TemporaryFolder tempFolder = new TemporaryFolder(); + + /** Each test has a timeout of 2 minutes (for safety). */ + @Rule public Timeout timeout = new Timeout(2, TimeUnit.MINUTES); + + @BeforeClass + public static void beforeClass() throws Exception { + flinkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); + + Configuration config = new Configuration(); + // Avoid port collision in parallel tests + config.set(RestOptions.PORT, 0); + config.set(StateBackendOptions.STATE_BACKEND, "hashmap"); + + String savepointPath = "file://" + tempFolder.getRoot().getAbsolutePath(); + LOG.info("Savepoints will be written to {}", savepointPath); + // It is necessary to configure the checkpoint directory for the state backend, + // even though we only create savepoints in this test. + config.set(CheckpointingOptions.CHECKPOINTS_DIRECTORY, savepointPath); + // Checkpoints will go into a subdirectory of this directory + config.set(CheckpointingOptions.SAVEPOINT_DIRECTORY, savepointPath); + + MiniClusterConfiguration clusterConfig = + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(2) + .setNumSlotsPerTaskManager(2) + .build(); + + flinkCluster = new MiniCluster(clusterConfig); + flinkCluster.start(); + } + + @AfterClass + public static void afterClass() throws Exception { + flinkCluster.close(); + flinkCluster = null; + + flinkJobExecutor.shutdown(); + flinkJobExecutor.awaitTermination(10, TimeUnit.SECONDS); + if (!flinkJobExecutor.isShutdown()) { + LOG.warn("Could not shutdown Flink job executor"); + } + flinkJobExecutor = null; + } + + @After + public void afterTest() throws Exception { + for (JobStatusMessage jobStatusMessage : flinkCluster.listJobs().get()) { + if (jobStatusMessage.getJobState().name().equals("RUNNING")) { + flinkCluster.cancelJob(jobStatusMessage.getJobId()).get(); + } + } + ensureNoJobRunning(); + } + + @Test + public void testSavepointRestoreLegacy() throws Exception { + runSavepointAndRestore(false); + } + + @Test + public void testSavepointRestorePortable() throws Exception { + runSavepointAndRestore(true); + } + + private void runSavepointAndRestore(boolean isPortablePipeline) throws Exception { + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setStreaming(true); + // Initial parallelism + options.setParallelism(2); + options.setRunner(FlinkRunner.class); + // Avoid any task shutting down, which would prevent savepointing + options.setShutdownSourcesAfterIdleMs(Long.MAX_VALUE); + + oneShotLatch = new CountDownLatch(1); + options.setJobName("initial-" + UUID.randomUUID()); + Pipeline pipeline = Pipeline.create(options); + createStreamingJob(pipeline, false, isPortablePipeline); + + final JobID jobID; + if (isPortablePipeline) { + jobID = executePortable(pipeline); + } else { + jobID = executeLegacy(pipeline); + } + oneShotLatch.await(); + String savepointDir = takeSavepoint(jobID); + flinkCluster.cancelJob(jobID).get(); + ensureNoJobRunning(); + + oneShotLatch = new CountDownLatch(1); + // Increase parallelism + options.setParallelism(4); + options.setJobName("restored-" + UUID.randomUUID()); + pipeline =
Pipeline.create(options); + createStreamingJob(pipeline, true, isPortablePipeline); + + if (isPortablePipeline) { + restoreFromSavepointPortable(pipeline, savepointDir); + } else { + restoreFromSavepointLegacy(pipeline, savepointDir); + } + oneShotLatch.await(); + } + + private JobID executeLegacy(Pipeline pipeline) throws Exception { + JobGraph jobGraph = getJobGraph(pipeline); + flinkCluster.submitJob(jobGraph).get(); + return waitForJobToBeReady(pipeline.getOptions().getJobName()); + } + + private JobID executePortable(Pipeline pipeline) throws Exception { + pipeline + .getOptions() + .as(PortablePipelineOptions.class) + .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED); + pipeline.getOptions().as(FlinkPipelineOptions.class).setFlinkMaster(getFlinkMaster()); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline); + + FlinkPipelineOptions pipelineOptions = pipeline.getOptions().as(FlinkPipelineOptions.class); + JobInvocation jobInvocation = + FlinkJobInvoker.create(null) + .createJobInvocation( + "id", + "none", + flinkJobExecutor, + pipelineProto, + pipelineOptions, + new FlinkPipelineRunner(pipelineOptions, null, Collections.emptyList())); + + jobInvocation.start(); + + return waitForJobToBeReady(pipeline.getOptions().getJobName()); + } + + private String getFlinkMaster() throws Exception { + URI uri = flinkCluster.getRestAddress().get(); + return uri.getHost() + ":" + uri.getPort(); + } + + private void ensureNoJobRunning() throws Exception { + while (!flinkCluster.listJobs().get().stream() + .allMatch(job -> job.getJobState().isTerminalState())) { + Thread.sleep(50); + } + } + + private JobID waitForJobToBeReady(String jobName) + throws InterruptedException, ExecutionException { + while (true) { + Optional jobId = + flinkCluster.listJobs().get().stream() + .filter((status) -> status.getJobName().equals(jobName)) + .findAny(); + if (jobId.isPresent()) { + JobStatusMessage status = jobId.get(); + if (status.getJobState().equals(JobStatus.RUNNING)) { + return status.getJobId(); + } + LOG.info("Job '{}' is in state {}, waiting...", jobName, status.getJobState()); + } else { + LOG.info("Job '{}' does not yet exist, waiting...", jobName); + } + Thread.sleep(100); + } + } + + private String takeSavepoint(JobID jobID) throws Exception { + Exception exception = null; + // try multiple times because the job might not be ready yet + for (int i = 0; i < 10; i++) { + try { + CompletableFuture savepointFuture = + flinkCluster.triggerSavepoint(jobID, null, false, SavepointFormatType.DEFAULT); + return savepointFuture.get(); + } catch (Exception e) { + exception = e; + LOG.debug("Exception while triggerSavepoint, trying again", e); + Thread.sleep(100); + } + } + throw exception; + } + + private void restoreFromSavepointLegacy(Pipeline pipeline, String savepointDir) + throws ExecutionException, InterruptedException { + JobGraph jobGraph = getJobGraph(pipeline); + SavepointRestoreSettings savepointSettings = SavepointRestoreSettings.forPath(savepointDir); + jobGraph.setSavepointRestoreSettings(savepointSettings); + flinkCluster.submitJob(jobGraph).get(); + } + + private void restoreFromSavepointPortable(Pipeline pipeline, String savepointDir) + throws Exception { + FlinkPipelineOptions flinkOptions = pipeline.getOptions().as(FlinkPipelineOptions.class); + flinkOptions.setSavepointPath(savepointDir); + executePortable(pipeline); + } + + private JobGraph getJobGraph(Pipeline pipeline) { + FlinkRunner flinkRunner = FlinkRunner.fromOptions(pipeline.getOptions()); + return 
flinkRunner.getJobGraph(pipeline); + } + + private static PCollection createStreamingJob( + Pipeline pipeline, boolean restored, boolean isPortablePipeline) { + final PCollection> key; + if (isPortablePipeline) { + key = + pipeline + .apply("ImpulseStage", Impulse.create()) + .apply( + "KvMapperStage", + MapElements.via( + new InferableFunction>() { + @Override + public KV apply(byte[] input) { + // This only writes data to one of the two initial partitions. + // We want to test this due to + // https://jira.apache.org/jira/browse/BEAM-7144 + return KV.of("key", null); + } + })) + .apply( + "TimerStage", + ParDo.of( + new DoFn, KV>() { + + @StateId("nextInteger") + private final StateSpec> valueStateSpec = + StateSpecs.value(); + + @TimerId("timer") + private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + public void processElement( + ProcessContext context, @TimerId("timer") Timer timer) { + timer.set(new Instant(0)); + } + + @OnTimer("timer") + public void onTimer( + OnTimerContext context, + @StateId("nextInteger") ValueState nextInteger, + @TimerId("timer") Timer timer) { + Long current = nextInteger.read(); + current = current != null ? current : 0L; + context.output(KV.of("key", current)); + LOG.debug("triggering timer {}", current); + nextInteger.write(current + 1); + // Trigger timer again and continue to hold back the watermark + timer.withOutputTimestamp(new Instant(0)).set(context.fireTimestamp()); + } + })); + } else { + key = + pipeline + .apply("IdGeneratorStage", GenerateSequence.from(0)) + .apply( + "KvMapperStage", + ParDo.of( + new DoFn>() { + @ProcessElement + public void processElement(ProcessContext context) { + context.output(KV.of("key", context.element())); + } + })); + } + if (restored) { + return key.apply( + "VerificationStage", + ParDo.of( + new DoFn, String>() { + + @StateId("valueState") + private final StateSpec> valueStateSpec = StateSpecs.value(); + + @StateId("bagState") + private final StateSpec> bagStateSpec = StateSpecs.bag(); + + @ProcessElement + public void processElement( + ProcessContext context, + @StateId("valueState") ValueState intValueState, + @StateId("bagState") BagState intBagState) { + assertThat(intValueState.read(), Matchers.is(42)); + assertThat(intBagState.read(), IsIterableContaining.hasItems(40, 1, 1)); + oneShotLatch.countDown(); + } + })); + } else { + return key.apply( + "VerificationStage", + ParDo.of( + new DoFn, String>() { + + @StateId("valueState") + private final StateSpec> valueStateSpec = StateSpecs.value(); + + @StateId("bagState") + private final StateSpec> bagStateSpec = StateSpecs.bag(); + + @ProcessElement + public void processElement( + ProcessContext context, + @StateId("valueState") ValueState intValueState, + @StateId("bagState") BagState intBagState) { + long value = Objects.requireNonNull(context.element().getValue()); + LOG.debug("value: {} timestamp: {}", value, context.timestamp().getMillis()); + if (value == 0L) { + intValueState.write(42); + intBagState.add(40); + intBagState.add(1); + intBagState.add(1); + } else if (value >= 1) { + oneShotLatch.countDown(); + } + } + })); + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java new file mode 100644 index 000000000000..66079f855a77 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import java.io.File; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.security.Permission; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.construction.resources.PipelineResources; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.flink.client.cli.CliFrontend; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.runtime.client.JobStatusMessage; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.minicluster.RpcServiceSharing; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** End-to-end submission test of Beam jobs on a Flink cluster. */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class FlinkSubmissionTest { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkSubmissionTest.class); + + @ClassRule public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder(); + private static final Map ENV = System.getenv(); + private static final SecurityManager SECURITY_MANAGER = System.getSecurityManager(); + + /** Flink cluster that runs over the lifespan of the tests. */ + private static transient RemoteMiniCluster flinkCluster; + + /** Each test has a timeout of 60 seconds (for safety). */ + @Rule public Timeout timeout = new Timeout(60, TimeUnit.SECONDS); + + /** Counter which keeps track of the number of jobs submitted. 
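+ * Incremented once per {@code runSubmission} call and compared against the cluster's job list in {@code waitUntilJobIsCompleted}.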
*/ + private static int expectedNumberOfJobs; + + @BeforeClass + public static void beforeClass() throws Exception { + Configuration config = new Configuration(); + // Avoid port collision in parallel tests on the same machine + config.set(RestOptions.PORT, 0); + + MiniClusterConfiguration clusterConfig = + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(1) + .setNumSlotsPerTaskManager(1) + // Create a shared actor system for all cluster services + .setRpcServiceSharing(RpcServiceSharing.SHARED) + .build(); + + flinkCluster = new RemoteMiniClusterImpl(clusterConfig); + flinkCluster.start(); + prepareEnvironment(); + } + + @AfterClass + public static void afterClass() throws Exception { + restoreEnvironment(); + flinkCluster.close(); + flinkCluster = null; + } + + @Test + public void testSubmissionBatch() throws Exception { + runSubmission(false, false); + } + + @Test + public void testSubmissionStreaming() throws Exception { + runSubmission(false, true); + } + + @Test + public void testDetachedSubmissionBatch() throws Exception { + runSubmission(true, false); + } + + @Test + public void testDetachedSubmissionStreaming() throws Exception { + runSubmission(true, true); + } + + private void runSubmission(boolean isDetached, boolean isStreaming) throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + options.as(FlinkPipelineOptions.class).setStreaming(isStreaming); + options.setTempLocation(TEMP_FOLDER.getRoot().getPath()); + String jarPath = + Iterables.getFirst( + PipelineResources.detectClassPathResourcesToStage(getClass().getClassLoader(), options), + null); + + try { + throwExceptionOnSystemExit(); + ImmutableList.Builder argsBuilder = ImmutableList.builder(); + argsBuilder.add("run").add("-c").add(getClass().getName()); + if (isDetached) { + argsBuilder.add("-d"); + } + argsBuilder.add(jarPath); + argsBuilder.add("--runner=flink"); + + if (isStreaming) { + argsBuilder.add("--streaming"); + } + + FlinkSubmissionTest.expectedNumberOfJobs++; + ImmutableList args = argsBuilder.build(); + // Run end-to-end test + CliFrontend.main(args.toArray(new String[0])); + } catch (SystemExitException e) { + // The CliFrontend exited and we can move on to check if the job has finished + } finally { + restoreDefaultSystemExitBehavior(); + } + + waitUntilJobIsCompleted(); + } + + private void waitUntilJobIsCompleted() throws Exception { + while (true) { + Collection allJobsStates = flinkCluster.listJobs().get(); + if (allJobsStates.size() == expectedNumberOfJobs + && allJobsStates.stream() + .allMatch(jobStatus -> jobStatus.getJobState().isTerminalState())) { + LOG.info( + "All job finished with statuses: {}", + allJobsStates.stream().map(j -> j.getJobState().name()).collect(Collectors.toList())); + return; + } + Thread.sleep(50); + } + } + + /** The Flink program which is executed by the CliFrontend. 
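+ * It runs a minimal bounded {@code GenerateSequence} pipeline on the {@link FlinkRunner} so that submission finishes quickly.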
*/ + public static void main(String[] args) { + FlinkPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).withValidation().as(FlinkPipelineOptions.class); + options.setRunner(FlinkRunner.class); + options.setParallelism(1); + Pipeline p = Pipeline.create(options); + p.apply(GenerateSequence.from(0).to(1)); + p.run(); + } + + private static void prepareEnvironment() throws Exception { + // Write a Flink config + File file = TEMP_FOLDER.newFile("config.yaml"); + String config = + String.format( + "rest:\n port: '%d'\njobmanager:\n rpc:\n address: %s\n port: '%d'", + flinkCluster.getRestPort(), "localhost", flinkCluster.getClusterPort()); + + Files.write(file.toPath(), config.getBytes(StandardCharsets.UTF_8)); + + // Create a new environment with the location of the Flink config for CliFrontend + ImmutableMap newEnv = + ImmutableMap.builder() + .putAll(ENV.entrySet()) + .put(ConfigConstants.ENV_FLINK_CONF_DIR, file.getParent()) + .build(); + + modifyEnv(newEnv); + } + + private static void restoreEnvironment() throws Exception { + modifyEnv(ENV); + } + + /** + * We modify the JVM's environment variables here. This is necessary for the end-to-end test + * because Flink's CliFrontend requires a Flink configuration file for which the location can only + * be set using the {@code ConfigConstants.ENV_FLINK_CONF_DIR} environment variable. + */ + private static void modifyEnv(Map env) throws Exception { + Class processEnv = Class.forName("java.lang.ProcessEnvironment"); + Field envField = processEnv.getDeclaredField("theUnmodifiableEnvironment"); + + Field modifiersField = Field.class.getDeclaredField("modifiers"); + modifiersField.setAccessible(true); + modifiersField.setInt(envField, envField.getModifiers() & ~Modifier.FINAL); + + envField.setAccessible(true); + envField.set(null, env); + envField.setAccessible(false); + + modifiersField.setInt(envField, envField.getModifiers() & Modifier.FINAL); + modifiersField.setAccessible(false); + } + + /** Prevents the CliFrontend from calling System.exit. */ + private static void throwExceptionOnSystemExit() { + System.setSecurityManager( + new SecurityManager() { + @Override + public void checkPermission(Permission permission) { + if (permission.getName().startsWith("exitVM")) { + throw new SystemExitException(); + } + if (SECURITY_MANAGER != null) { + SECURITY_MANAGER.checkPermission(permission); + } + } + }); + } + + private static void restoreDefaultSystemExitBehavior() { + System.setSecurityManager(SECURITY_MANAGER); + } + + private static class SystemExitException extends SecurityException {} +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/ReadSourceTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/ReadSourceTest.java new file mode 100644 index 000000000000..b314718d4f75 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/ReadSourceTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import java.io.File; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; +import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.JobID; +import org.apache.flink.test.util.JavaProgramTestBase; +import org.apache.flink.test.util.TestBaseUtils; + +/** Reads from a bounded source in batch execution. */ +public class ReadSourceTest extends JavaProgramTestBase { + + protected String resultPath; + + public ReadSourceTest() {} + + private static final String[] EXPECTED_RESULT = + new String[] {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + + // need to create the dir, otherwise Beam sinks don't + // work for these tests + + if (!new File(new URI(resultPath)).mkdirs()) { + throw new RuntimeException("Could not create output dir."); + } + } + + @Override + protected void postSubmit() throws Exception { + TestBaseUtils.compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected JobExecutionResult testProgram() throws Exception { + return runProgram(resultPath); + } + + private static JobExecutionResult runProgram(String resultPath) throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection result = + p.apply(GenerateSequence.from(0).to(10)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output(c.element().toString()); + } + })); + + result.apply(TextIO.write().to(new URI(resultPath).getPath() + "/part")); + Instant now = Instant.now(); + p.run(); + return new JobExecutionResult( + new JobID(p.getOptions().getJobName().getBytes(StandardCharsets.UTF_8)), + Duration.between(now, Instant.now()).toMillis(), + null); + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapterTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapterTest.java new file mode 100644 index 000000000000..3883aa5d10d4 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapterTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.adapter; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.util.Map; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.WithTimestamps; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; + +public class BeamFlinkDataStreamAdapterTest { + + private static PTransform, PCollection> withPrefix( + String prefix) { + return ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(@Element String word, OutputReceiver out) { + out.output(prefix + word); + } + }); + } + + @Test + public void testApplySimpleTransform() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input = env.fromCollection(ImmutableList.of("a", "b", "c")); + DataStream result = + new BeamFlinkDataStreamAdapter().applyBeamPTransform(input, withPrefix("x")); + + assertThat( + ImmutableList.copyOf(result.executeAndCollect()), containsInAnyOrder("xa", "xb", "xc")); + } + + @Test + public void testApplyCompositeTransform() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input = env.fromCollection(ImmutableList.of("a", "b", "c")); + DataStream result = + new BeamFlinkDataStreamAdapter() + .applyBeamPTransform( + input, + new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return input.apply(withPrefix("x")).apply(withPrefix("y")); + } + }); + + assertThat( + ImmutableList.copyOf(result.executeAndCollect()), containsInAnyOrder("yxa", "yxb", "yxc")); + } + + @Test + public void testApplyMultiInputTransform() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input1 = env.fromCollection(ImmutableList.of("a", "b", "c")); + DataStream 
input2 = env.fromCollection(ImmutableList.of("d", "e", "f")); + DataStream result = + new BeamFlinkDataStreamAdapter() + .applyBeamPTransform( + ImmutableMap.of("x", input1, "y", input2), + new PTransform>() { + @Override + public PCollection expand(PCollectionTuple input) { + return PCollectionList.of(input.get("x").apply(withPrefix("x"))) + .and(input.get("y").apply(withPrefix("y"))) + .apply(Flatten.pCollections()); + } + }); + + assertThat( + ImmutableList.copyOf(result.executeAndCollect()), + containsInAnyOrder("xa", "xb", "xc", "yd", "ye", "yf")); + } + + @Test + public void testApplyMultiOutputTransform() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input = env.fromCollection(ImmutableList.of("a", "b", "c")); + Map> result = + new BeamFlinkDataStreamAdapter() + .applyMultiOutputBeamPTransform( + input, + new PTransform, PCollectionTuple>() { + @Override + public PCollectionTuple expand(PCollection input) { + return PCollectionTuple.of("x", input.apply(withPrefix("x"))) + .and("y", input.apply(withPrefix("y"))); + } + }); + + assertThat( + ImmutableList.copyOf(result.get("x").executeAndCollect()), + containsInAnyOrder("xa", "xb", "xc")); + assertThat( + ImmutableList.copyOf(result.get("y").executeAndCollect()), + containsInAnyOrder("ya", "yb", "yc")); + } + + @Test + public void testApplyGroupingTransform() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input = env.fromCollection(ImmutableList.of("a", "a", "b")); + DataStream> result = + new BeamFlinkDataStreamAdapter() + .applyBeamPTransform( + input, + new PTransform, PCollection>>() { + @Override + public PCollection> expand(PCollection input) { + return input + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Count.perElement()); + } + }); + + assertThat( + ImmutableList.copyOf(result.executeAndCollect()), + containsInAnyOrder(KV.of("a", 2L), KV.of("b", 1L))); + } + + @Test + public void testApplyPreservesInputTimestamps() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input = + env.fromCollection(ImmutableList.of(1L, 2L, 12L)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.forBoundedOutOfOrderness(java.time.Duration.ofMillis(100)) + .withTimestampAssigner( + (SerializableTimestampAssigner) + (element, recordTimestamp) -> element)); + DataStream result = + new BeamFlinkDataStreamAdapter() + .applyBeamPTransform( + input, + new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return input + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.longsGlobally().withoutDefaults()); + } + }); + + assertThat(ImmutableList.copyOf(result.executeAndCollect()), containsInAnyOrder(3L, 12L)); + } + + @Test + public void testApplyPreservesOutputTimestamps() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + + DataStream input = env.fromCollection(ImmutableList.of(1L, 2L, 12L)); + DataStream withTimestamps = + new BeamFlinkDataStreamAdapter() + .applyBeamPTransform( + input, + new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return input.apply(WithTimestamps.of(x -> Instant.ofEpochMilli(x))); + } + }); + + assertThat( + ImmutableList.copyOf( + withTimestamps + .windowAll(TumblingEventTimeWindows.of(java.time.Duration.ofMillis(10))) + 
.reduce((ReduceFunction) (a, b) -> a + b) + .executeAndCollect()), + containsInAnyOrder(3L, 12L)); + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/BoundedSourceRestoreTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/BoundedSourceRestoreTest.java new file mode 100644 index 000000000000..897e2e3467b8 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/BoundedSourceRestoreTest.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.streaming; + +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.TestCountingSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.construction.UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter; +import org.apache.beam.sdk.util.construction.UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter.Checkpoint; +import org.apache.beam.sdk.values.ValueWithRecordId; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.event.WatermarkEvent; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; +import org.apache.flink.util.OutputTag; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** Test for bounded source restore in streaming mode. 
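+ * A bounded source is wrapped as an unbounded source, partially consumed, snapshotted, and then restored into a fresh operator to verify that the remaining elements are still emitted.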
*/ +@RunWith(Parameterized.class) +public class BoundedSourceRestoreTest { + + private final int numTasks; + private final int numSplits; + + public BoundedSourceRestoreTest(int numTasks, int numSplits) { + this.numTasks = numTasks; + this.numSplits = numSplits; + } + + @Parameterized.Parameters + public static Collection data() { + /* Parameters for initializing the tests: {numTasks, numSplits} */ + return Arrays.asList( + new Object[][] { + {1, 1}, {1, 2}, {1, 4}, + }); + } + + @Test + public void testRestore() throws Exception { + final int numElements = 102; + final int firstBatchSize = 23; + final int secondBatchSize = numElements - firstBatchSize; + final Set emittedElements = new HashSet<>(); + final Object checkpointLock = new Object(); + PipelineOptions options = PipelineOptionsFactory.create(); + + // bounded source wrapped as unbounded source + BoundedSource source = CountingSource.upTo(numElements); + BoundedToUnboundedSourceAdapter unboundedSource = + new BoundedToUnboundedSourceAdapter<>(source); + UnboundedSourceWrapper> flinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, unboundedSource, numSplits); + + StreamSource< + WindowedValue>, UnboundedSourceWrapper>> + sourceOperator = new StreamSource<>(flinkWrapper); + + AbstractStreamOperatorTestHarness>> testHarness = + new AbstractStreamOperatorTestHarness<>( + sourceOperator, + numTasks /* max parallelism */, + numTasks /* parallelism */, + 0 /* subtask index */); + + // the first half of elements is read + boolean readFirstBatchOfElements = false; + try { + testHarness.open(); + StreamSources.run( + sourceOperator, checkpointLock, new PartialCollector<>(emittedElements, firstBatchSize)); + } catch (SuccessException e) { + // success + readFirstBatchOfElements = true; + } + assertTrue("Did not successfully read first batch of elements.", readFirstBatchOfElements); + + // draw a snapshot + OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); + + // finalize checkpoint + final ArrayList finalizeList = new ArrayList<>(); + TestCountingSource.setFinalizeTracker(finalizeList); + testHarness.notifyOfCompletedCheckpoint(0); + + // create a completely new source but restore from the snapshot + BoundedSource restoredSource = CountingSource.upTo(numElements); + BoundedToUnboundedSourceAdapter restoredUnboundedSource = + new BoundedToUnboundedSourceAdapter<>(restoredSource); + UnboundedSourceWrapper> restoredFlinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, restoredUnboundedSource, numSplits); + StreamSource< + WindowedValue>, UnboundedSourceWrapper>> + restoredSourceOperator = new StreamSource<>(restoredFlinkWrapper); + + // set parallelism to 1 to ensure that our testing operator gets all checkpointed state + AbstractStreamOperatorTestHarness>> restoredTestHarness = + new AbstractStreamOperatorTestHarness<>( + restoredSourceOperator, + numTasks /* max parallelism */, + 1 /* parallelism */, + 0 /* subtask index */); + + // restore snapshot + restoredTestHarness.initializeState(snapshot); + + // run again and verify that we see the other elements + boolean readSecondBatchOfElements = false; + try { + restoredTestHarness.open(); + StreamSources.run( + restoredSourceOperator, + checkpointLock, + new PartialCollector<>(emittedElements, secondBatchSize)); + } catch (SuccessException e) { + // success + readSecondBatchOfElements = true; + } + assertTrue("Did not successfully read second batch of elements.", readSecondBatchOfElements); + + // verify that we saw all NUM_ELEMENTS elements + 
assertTrue(emittedElements.size() == numElements); + } + + /** A special {@link RuntimeException} that we throw to signal that the test was successful. */ + private static class SuccessException extends RuntimeException {} + + /** A collector which consumes only specified number of elements. */ + private static class PartialCollector + implements StreamSources.OutputWrapper>>> { + + private final Set emittedElements; + private final int elementsToConsumeLimit; + + private int count = 0; + + private PartialCollector(Set emittedElements, int elementsToConsumeLimit) { + this.emittedElements = emittedElements; + this.elementsToConsumeLimit = elementsToConsumeLimit; + } + + @Override + public void emitWatermark(Watermark watermark) {} + + @Override + public void emitWatermark(WatermarkEvent event) {} + + @Override + public void collect(OutputTag outputTag, StreamRecord streamRecord) { + collect((StreamRecord) streamRecord); + } + + @Override + public void emitLatencyMarker(LatencyMarker latencyMarker) {} + + @Override + public void collect(StreamRecord>> record) { + emittedElements.add(record.getValue().getValue().getValue()); + count++; + if (count >= elementsToConsumeLimit) { + throw new SuccessException(); + } + } + + @Override + public void close() {} + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java new file mode 100644 index 000000000000..6d74c51d7d9b --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.UUID; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; +import org.apache.beam.runners.core.StateNamespaces; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.serialization.SerializerConfigImpl; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.ttl.TtlTimeProvider; +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link FlinkStateInternals}. This is based on {@link StateInternalsTest}. 
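+ * In addition to the inherited suite, it checks that watermark holds are tracked per key and window but aggregated across all keys, and that global state can be cleared.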
*/ +@RunWith(JUnit4.class) +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class FlinkStateInternalsTest extends StateInternalsTest { + + @Override + protected StateInternals createStateInternals() { + try { + KeyedStateBackend keyedStateBackend = createStateBackend(); + return new FlinkStateInternals<>( + keyedStateBackend, + StringUtf8Coder.of(), + IntervalWindow.getCoder(), + new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testWatermarkHoldsPersistence() throws Exception { + KeyedStateBackend keyedStateBackend = createStateBackend(); + FlinkStateInternals stateInternals = + new FlinkStateInternals<>( + keyedStateBackend, + StringUtf8Coder.of(), + IntervalWindow.getCoder(), + new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); + + StateTag stateTag = + StateTags.watermarkStateInternal("hold", TimestampCombiner.EARLIEST); + WatermarkHoldState globalWindow = stateInternals.state(StateNamespaces.global(), stateTag); + WatermarkHoldState fixedWindow = + stateInternals.state( + StateNamespaces.window( + IntervalWindow.getCoder(), new IntervalWindow(new Instant(0), new Instant(10))), + stateTag); + + Instant noHold = new Instant(Long.MAX_VALUE); + assertThat(stateInternals.minWatermarkHoldMs(), is(noHold.getMillis())); + + Instant high = new Instant(10); + globalWindow.add(high); + assertThat(stateInternals.minWatermarkHoldMs(), is(high.getMillis())); + + Instant middle = new Instant(5); + fixedWindow.add(middle); + assertThat(stateInternals.minWatermarkHoldMs(), is(middle.getMillis())); + + Instant low = new Instant(1); + globalWindow.add(low); + assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); + + // Try to overwrite with later hold (should not succeed) + globalWindow.add(high); + assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); + fixedWindow.add(high); + assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); + + // Watermark hold should be computed across all keys + FlinkKey firstKey = keyedStateBackend.getCurrentKey(); + changeKey(keyedStateBackend); + FlinkKey secondKey = keyedStateBackend.getCurrentKey(); + assertThat(firstKey, is(Matchers.not(secondKey))); + assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); + // ..but be tracked per key / window + assertThat(globalWindow.read(), is(Matchers.nullValue())); + assertThat(fixedWindow.read(), is(Matchers.nullValue())); + globalWindow.add(middle); + fixedWindow.add(high); + assertThat(globalWindow.read(), is(middle)); + assertThat(fixedWindow.read(), is(high)); + // Old key should give previous results + keyedStateBackend.setCurrentKey(firstKey); + assertThat(globalWindow.read(), is(low)); + assertThat(fixedWindow.read(), is(middle)); + + // Discard watermark view and recover it + stateInternals = + new FlinkStateInternals<>( + keyedStateBackend, + StringUtf8Coder.of(), + IntervalWindow.getCoder(), + new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); + globalWindow = stateInternals.state(StateNamespaces.global(), stateTag); + fixedWindow = + stateInternals.state( + StateNamespaces.window( + IntervalWindow.getCoder(), new IntervalWindow(new Instant(0), new Instant(10))), + stateTag); + + // Watermark hold across all keys should be unchanged + assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); + + // Check the holds for the second key and clear them + 
keyedStateBackend.setCurrentKey(secondKey); + assertThat(globalWindow.read(), is(middle)); + assertThat(fixedWindow.read(), is(high)); + globalWindow.clear(); + fixedWindow.clear(); + + // Check the holds for the first key and clear them + keyedStateBackend.setCurrentKey(firstKey); + assertThat(globalWindow.read(), is(low)); + assertThat(fixedWindow.read(), is(middle)); + + fixedWindow.clear(); + assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); + + globalWindow.clear(); + assertThat(stateInternals.minWatermarkHoldMs(), is(noHold.getMillis())); + } + + @Test + public void testGlobalWindowWatermarkHoldClear() throws Exception { + KeyedStateBackend keyedStateBackend = createStateBackend(); + FlinkStateInternals stateInternals = + new FlinkStateInternals<>( + keyedStateBackend, + StringUtf8Coder.of(), + IntervalWindow.getCoder(), + new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); + StateTag stateTag = + StateTags.watermarkStateInternal("hold", TimestampCombiner.EARLIEST); + Instant now = Instant.now(); + WatermarkHoldState state = stateInternals.state(StateNamespaces.global(), stateTag); + state.add(now); + stateInternals.clearGlobalState(); + assertThat(state.read(), is((Instant) null)); + } + + public static KeyedStateBackend createStateBackend() throws Exception { + AbstractKeyedStateBackend keyedStateBackend = + MemoryStateBackendWrapper.createKeyedStateBackend( + new DummyEnvironment("test", 1, 0), + new JobID(), + "test_op", + new ValueTypeInfo<>(FlinkKey.class).createSerializer(new SerializerConfigImpl()), + 2, + new KeyGroupRange(0, 1), + new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()), + TtlTimeProvider.DEFAULT, + null, + Collections.emptyList(), + new CloseableRegistry()); + + changeKey(keyedStateBackend); + + return keyedStateBackend; + } + + private static void changeKey(KeyedStateBackend keyedStateBackend) + throws CoderException { + keyedStateBackend.setCurrentKey( + FlinkKey.of( + ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), UUID.randomUUID().toString())))); + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java new file mode 100644 index 000000000000..d371e7f994e3 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import java.io.IOException; +import java.util.Collection; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackendParametersImpl; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackendParametersImpl; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.hashmap.HashMapStateBackend; +import org.apache.flink.runtime.state.ttl.TtlTimeProvider; + +class MemoryStateBackendWrapper { + static <K> AbstractKeyedStateBackend<K> createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer<K> keySerializer, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry, + TtlTimeProvider ttlTimeProvider, + MetricGroup metricGroup, + Collection<KeyedStateHandle> stateHandles, + CloseableRegistry cancelStreamRegistry) + throws IOException { + + HashMapStateBackend backend = new HashMapStateBackend(); + return backend.createKeyedStateBackend( + new KeyedStateBackendParametersImpl<>( + env, + jobID, + operatorIdentifier, + keySerializer, + numberOfKeyGroups, + keyGroupRange, + kvStateRegistry, + ttlTimeProvider, + metricGroup, + stateHandles, + cancelStreamRegistry)); + } + + static OperatorStateBackend createOperatorStateBackend( + Environment env, + String operatorIdentifier, + Collection<OperatorStateHandle> stateHandles, + CloseableRegistry cancelStreamRegistry) + throws Exception { + HashMapStateBackend backend = new HashMapStateBackend(); + return backend.createOperatorStateBackend( + new OperatorStateBackendParametersImpl( + env, operatorIdentifier, stateHandles, cancelStreamRegistry)); + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java new file mode 100644 index 000000000000..a39af17766fc --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.runtime.streamrecord.RecordAttributes; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorChain; +import org.apache.flink.streaming.runtime.tasks.RegularOperatorChain; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; + +/** {@link StreamSource} utilities that bridge incompatibilities between Flink releases. */ +public class StreamSources { + + public static <OutT, SrcT extends SourceFunction<OutT>> void run( + StreamSource<OutT, SrcT> streamSource, + Object lockingObject, + Output<StreamRecord<OutT>> collector) + throws Exception { + streamSource.run(lockingObject, collector, createOperatorChain(streamSource)); + } + + private static OperatorChain<?, ?> createOperatorChain(AbstractStreamOperator<?> operator) { + return new RegularOperatorChain<>( + operator.getContainingTask(), + StreamTask.createRecordWriterDelegate( + operator.getOperatorConfig(), new MockEnvironmentBuilder().build())); + } + + /** The emitWatermarkStatus method was added in Flink 1.14, so we need to wrap Output. */ + public interface OutputWrapper<T> extends Output<T> { + @Override + default void emitWatermarkStatus(WatermarkStatus watermarkStatus) {} + + /** In Flink 1.19 the {@code emitRecordAttributes} method was added. */ + @Override + default void emitRecordAttributes(RecordAttributes recordAttributes) { + throw new UnsupportedOperationException("emitRecordAttributes not implemented"); + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunctionTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunctionTest.java new file mode 100644 index 000000000000..611434a13930 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunctionTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.Collections; +import java.util.Map; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.functions.DefaultOpenContext; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.junit.Test; +import org.mockito.Mockito; +import org.powermock.reflect.Whitebox; + +/** Tests for {@link FlinkDoFnFunction}. */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class FlinkDoFnFunctionTest { + + @Test + public void testAccumulatorRegistrationOnOperatorClose() throws Exception { + FlinkDoFnFunction doFnFunction = + new TestDoFnFunction( + "step", + WindowingStrategy.globalDefault(), + Collections.emptyMap(), + PipelineOptionsFactory.create(), + Collections.emptyMap(), + new TupleTag<>(), + null, + Collections.emptyMap(), + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + doFnFunction.open(new DefaultOpenContext()); + + String metricContainerFieldName = "metricContainer"; + FlinkMetricContainer monitoredContainer = + Mockito.spy( + (FlinkMetricContainer) + Whitebox.getInternalState(doFnFunction, metricContainerFieldName)); + Whitebox.setInternalState(doFnFunction, metricContainerFieldName, monitoredContainer); + + doFnFunction.close(); + Mockito.verify(monitoredContainer).registerMetricsForPipelineResult(); + } + + private static class TestDoFnFunction extends FlinkDoFnFunction { + + public TestDoFnFunction( + String stepName, + WindowingStrategy windowingStrategy, + Map sideInputs, + PipelineOptions options, + Map outputMap, + TupleTag mainOutputTag, + Coder inputCoder, + Map outputCoderMap, + DoFnSchemaInformation doFnSchemaInformation, + Map sideInputMapping) { + super( + new IdentityFn(), + stepName, + windowingStrategy, + sideInputs, + options, + outputMap, + mainOutputTag, + inputCoder, + outputCoderMap, + doFnSchemaInformation, + sideInputMapping); + } + + @Override + public RuntimeContext getRuntimeContext() { + return Mockito.mock(RuntimeContext.class); + } + + private static class IdentityFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element()); + } + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java new file mode 100644 index 000000000000..73ea7f96260c --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import static org.apache.beam.sdk.util.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.Components; +import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload; +import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler; +import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler; +import org.apache.beam.runners.fnexecution.control.BundleProgressHandler; +import org.apache.beam.runners.fnexecution.control.ExecutableStageContext; +import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler; +import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory; +import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors; +import org.apache.beam.runners.fnexecution.control.RemoteBundle; +import org.apache.beam.runners.fnexecution.control.StageBundleFactory; +import org.apache.beam.runners.fnexecution.control.TimerReceiverFactory; +import org.apache.beam.runners.fnexecution.provisioning.JobInfo; +import org.apache.beam.runners.fnexecution.state.StateRequestHandler; +import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.util.construction.Timer; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.Struct; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.flink.api.common.cache.DistributedCache; +import org.apache.flink.api.common.functions.DefaultOpenContext; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.util.Collector; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.powermock.reflect.Whitebox; + +/** Tests for {@link FlinkExecutableStageFunction}. 
*/ +@RunWith(Parameterized.class) +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class FlinkExecutableStageFunctionTest { + + @Parameterized.Parameters + public static Object[] data() { + return new Object[] {true, false}; + } + + @Parameterized.Parameter public boolean isStateful; + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Mock private RuntimeContext runtimeContext; + @Mock private DistributedCache distributedCache; + @Mock private Collector collector; + @Mock private ExecutableStageContext stageContext; + @Mock private StageBundleFactory stageBundleFactory; + @Mock private StateRequestHandler stateRequestHandler; + @Mock private ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor; + + // NOTE: ExecutableStage.fromPayload expects exactly one input, so we provide one here. These unit + // tests in general ignore the executable stage itself and mock around it. + private final ExecutableStagePayload stagePayload = + ExecutableStagePayload.newBuilder() + .setInput("input") + .setComponents( + Components.newBuilder() + .putTransforms( + "transform", + RunnerApi.PTransform.newBuilder() + .putInputs("bla", "input") + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(PAR_DO_TRANSFORM_URN)) + .build()) + .putPcollections("input", PCollection.getDefaultInstance()) + .build()) + .addUserStates( + ExecutableStagePayload.UserStateId.newBuilder().setTransformId("transform").build()) + .build(); + private final JobInfo jobInfo = + JobInfo.create("job-id", "job-name", "retrieval-token", Struct.getDefaultInstance()); + + @Before + public void setUpMocks() throws Exception { + MockitoAnnotations.initMocks(this); + when(runtimeContext.getDistributedCache()).thenReturn(distributedCache); + when(stageContext.getStageBundleFactory(any())).thenReturn(stageBundleFactory); + RemoteBundle remoteBundle = Mockito.mock(RemoteBundle.class); + when(stageBundleFactory.getBundle( + any(), + any(StateRequestHandler.class), + any(BundleProgressHandler.class), + any(BundleFinalizationHandler.class), + any(BundleCheckpointHandler.class))) + .thenReturn(remoteBundle); + when(stageBundleFactory.getBundle( + any(), + any(TimerReceiverFactory.class), + any(StateRequestHandler.class), + any(BundleProgressHandler.class))) + .thenReturn(remoteBundle); + ImmutableMap input = + ImmutableMap.builder().put("input", Mockito.mock(FnDataReceiver.class)).build(); + when(remoteBundle.getInputReceivers()).thenReturn(input); + when(processBundleDescriptor.getTimerSpecs()).thenReturn(Collections.emptyMap()); + } + + @Test + public void sdkErrorsSurfaceOnClose() throws Exception { + FlinkExecutableStageFunction function = getFunction(Collections.emptyMap()); + function.open(new DefaultOpenContext()); + + @SuppressWarnings("unchecked") + RemoteBundle bundle = Mockito.mock(RemoteBundle.class); + when(stageBundleFactory.getBundle( + any(), + any(StateRequestHandler.class), + any(BundleProgressHandler.class), + any(BundleFinalizationHandler.class), + any(BundleCheckpointHandler.class))) + .thenReturn(bundle); + + @SuppressWarnings("unchecked") + FnDataReceiver> receiver = Mockito.mock(FnDataReceiver.class); + when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", receiver)); + + Exception expected = new Exception(); + doThrow(expected).when(bundle).close(); + thrown.expect(is(expected)); + function.mapPartition(Collections.emptyList(), collector); + } + + @Test + public void expectedInputsAreSent() throws Exception { + 
FlinkExecutableStageFunction function = getFunction(Collections.emptyMap()); + function.open(new DefaultOpenContext()); + + @SuppressWarnings("unchecked") + RemoteBundle bundle = Mockito.mock(RemoteBundle.class); + when(stageBundleFactory.getBundle( + any(), + any(StateRequestHandler.class), + any(BundleProgressHandler.class), + any(BundleFinalizationHandler.class), + any(BundleCheckpointHandler.class))) + .thenReturn(bundle); + + @SuppressWarnings("unchecked") + FnDataReceiver> receiver = Mockito.mock(FnDataReceiver.class); + when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", receiver)); + + WindowedValue one = WindowedValues.valueInGlobalWindow(1); + WindowedValue two = WindowedValues.valueInGlobalWindow(2); + WindowedValue three = WindowedValues.valueInGlobalWindow(3); + function.mapPartition(Arrays.asList(one, two, three), collector); + + verify(receiver).accept(one); + verify(receiver).accept(two); + verify(receiver).accept(three); + verifyNoMoreInteractions(receiver); + } + + @Test + public void outputsAreTaggedCorrectly() throws Exception { + WindowedValue three = WindowedValues.valueInGlobalWindow(3); + WindowedValue four = WindowedValues.valueInGlobalWindow(4); + WindowedValue five = WindowedValues.valueInGlobalWindow(5); + Map outputTagMap = + ImmutableMap.of( + "one", 1, + "two", 2, + "three", 3); + + // We use a real StageBundleFactory here in order to exercise the output receiver factory. + StageBundleFactory stageBundleFactory = + new StageBundleFactory() { + + private boolean once; + + @Override + public RemoteBundle getBundle( + OutputReceiverFactory receiverFactory, + TimerReceiverFactory timerReceiverFactory, + StateRequestHandler stateRequestHandler, + BundleProgressHandler progressHandler, + BundleFinalizationHandler finalizationHandler, + BundleCheckpointHandler checkpointHandler) { + return new RemoteBundle() { + @Override + public String getId() { + return "bundle-id"; + } + + @Override + public Map getInputReceivers() { + return ImmutableMap.of( + "input", + input -> { + /* Ignore input*/ + }); + } + + @Override + public Map, FnDataReceiver> getTimerReceivers() { + return Collections.emptyMap(); + } + + @Override + public void requestProgress() { + throw new UnsupportedOperationException(); + } + + @Override + public void split(double fractionOfRemainder) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() throws Exception { + if (once) { + return; + } + // Emit all values to the runner when the bundle is closed. + receiverFactory.create("one").accept(three); + receiverFactory.create("two").accept(four); + receiverFactory.create("three").accept(five); + once = true; + } + }; + } + + @Override + public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor + getProcessBundleDescriptor() { + return processBundleDescriptor; + } + + @Override + public InstructionRequestHandler getInstructionRequestHandler() { + return null; + } + + @Override + public void close() throws Exception {} + }; + // Wire the stage bundle factory into our context. + when(stageContext.getStageBundleFactory(any())).thenReturn(stageBundleFactory); + + FlinkExecutableStageFunction function = getFunction(outputTagMap); + function.open(new DefaultOpenContext()); + + if (isStateful) { + function.reduce(Collections.emptyList(), collector); + } else { + function.mapPartition(Collections.emptyList(), collector); + } + // Ensure that the tagged values sent to the collector have the correct union tags as specified + // in the output map. 
+ verify(collector).collect(new RawUnionValue(1, three)); + verify(collector).collect(new RawUnionValue(2, four)); + verify(collector).collect(new RawUnionValue(3, five)); + verifyNoMoreInteractions(collector); + } + + @Test + public void testStageBundleClosed() throws Exception { + FlinkExecutableStageFunction function = getFunction(Collections.emptyMap()); + function.open(new DefaultOpenContext()); + function.close(); + verify(stageBundleFactory).getProcessBundleDescriptor(); + verify(stageBundleFactory).close(); + verifyNoMoreInteractions(stageBundleFactory); + } + + @Test + public void testAccumulatorRegistrationOnOperatorClose() throws Exception { + FlinkExecutableStageFunction function = getFunction(Collections.emptyMap()); + function.open(new DefaultOpenContext()); + + String metricContainerFieldName = "metricContainer"; + FlinkMetricContainer monitoredContainer = + Mockito.spy( + (FlinkMetricContainer) Whitebox.getInternalState(function, metricContainerFieldName)); + Whitebox.setInternalState(function, metricContainerFieldName, monitoredContainer); + + function.close(); + Mockito.verify(monitoredContainer).registerMetricsForPipelineResult(); + } + + /** + * Creates a {@link FlinkExecutableStageFunction}. Sets the runtime context to {@link + * #runtimeContext}. The context factory is mocked to return {@link #stageContext} every time. The + * behavior of the stage context itself is unchanged. + */ + private FlinkExecutableStageFunction getFunction(Map outputMap) { + FlinkExecutableStageContextFactory contextFactory = + Mockito.mock(FlinkExecutableStageContextFactory.class); + when(contextFactory.get(any())).thenReturn(stageContext); + FlinkExecutableStageFunction function = + new FlinkExecutableStageFunction<>( + "step", + PipelineOptionsFactory.create(), + stagePayload, + jobInfo, + outputMap, + contextFactory, + null, + null); + function.setRuntimeContext(runtimeContext); + Whitebox.setInternalState(function, "stateRequestHandler", stateRequestHandler); + return function; + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunctionTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunctionTest.java new file mode 100644 index 000000000000..f76a2e39eb4b --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunctionTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.Collections; +import java.util.Map; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.flink.api.common.functions.DefaultOpenContext; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.junit.Test; +import org.mockito.Mockito; +import org.powermock.reflect.Whitebox; + +/** Tests for {@link FlinkStatefulDoFnFunction}. */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class FlinkStatefulDoFnFunctionTest { + + @Test + public void testAccumulatorRegistrationOnOperatorClose() throws Exception { + FlinkStatefulDoFnFunction doFnFunction = + new TestDoFnFunction( + "step", + WindowingStrategy.globalDefault(), + Collections.emptyMap(), + PipelineOptionsFactory.create(), + Collections.emptyMap(), + new TupleTag<>(), + null, + Collections.emptyMap(), + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + doFnFunction.open(new DefaultOpenContext()); + + String metricContainerFieldName = "metricContainer"; + FlinkMetricContainer monitoredContainer = + Mockito.spy( + (FlinkMetricContainer) + Whitebox.getInternalState(doFnFunction, metricContainerFieldName)); + Whitebox.setInternalState(doFnFunction, metricContainerFieldName, monitoredContainer); + + doFnFunction.close(); + Mockito.verify(monitoredContainer).registerMetricsForPipelineResult(); + } + + private static class TestDoFnFunction extends FlinkStatefulDoFnFunction { + + public TestDoFnFunction( + String stepName, + WindowingStrategy windowingStrategy, + Map sideInputs, + PipelineOptions options, + Map outputMap, + TupleTag mainOutputTag, + Coder inputCoder, + Map outputCoderMap, + DoFnSchemaInformation doFnSchemaInformation, + Map sideInputMapping) { + super( + new IdentityFn(), + stepName, + windowingStrategy, + sideInputs, + options, + outputMap, + mainOutputTag, + inputCoder, + outputCoderMap, + doFnSchemaInformation, + sideInputMapping); + } + + @Override + public RuntimeContext getRuntimeContext() { + return Mockito.mock(RuntimeContext.class); + } + + private static class IdentityFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element()); + } + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunctionTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunctionTest.java new file mode 100644 index 000000000000..a425b8798aac --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/functions/ImpulseSourceFunctionTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.mockito.ArgumentMatcher; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Tests for {@link ImpulseSourceFunction}. */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +public class ImpulseSourceFunctionTest { + + private static final Logger LOG = LoggerFactory.getLogger(ImpulseSourceFunctionTest.class); + + @Rule public TestName testName = new TestName(); + + private final SourceFunction.SourceContext> sourceContext; + private final ImpulseElementMatcher elementMatcher = new ImpulseElementMatcher(); + + public ImpulseSourceFunctionTest() { + this.sourceContext = Mockito.mock(SourceFunction.SourceContext.class); + when(sourceContext.getCheckpointLock()).thenReturn(new Object()); + } + + @Test + public void testInstanceOfSourceFunction() { + // should be a non-parallel source function + assertThat(new ImpulseSourceFunction(0), instanceOf(SourceFunction.class)); + } + + @Test(timeout = 10_000) + public void testImpulseInitial() throws Exception { + ImpulseSourceFunction source = new ImpulseSourceFunction(0); + // No state available from previous runs + ListState mockListState = getMockListState(Collections.emptyList()); + source.initializeState(getInitializationContext(mockListState)); + + // 1) Should finish + source.run(sourceContext); + // 2) Should use checkpoint lock + verify(sourceContext).getCheckpointLock(); + // 3) Should emit impulse element and the final watermark + verify(sourceContext).collect(argThat(elementMatcher)); + verify(sourceContext).emitWatermark(Watermark.MAX_WATERMARK); + verifyNoMoreInteractions(sourceContext); + // 4) Should modify checkpoint state + verify(mockListState).get(); + verify(mockListState).add(true); + verifyNoMoreInteractions(mockListState); + } + + @Test(timeout = 10_000) + public void testImpulseRestored() throws Exception { + ImpulseSourceFunction source = new ImpulseSourceFunction(0); + // Previous state available + ListState mockListState = getMockListState(Collections.singletonList(true)); + 
source.initializeState(getInitializationContext(mockListState)); + + // 1) Should finish + source.run(sourceContext); + // 2) Should keep checkpoint state + verify(mockListState).get(); + verifyNoMoreInteractions(mockListState); + // 3) Should always emit the final watermark + verify(sourceContext).emitWatermark(Watermark.MAX_WATERMARK); + // 4) Should _not_ emit impulse element + verifyNoMoreInteractions(sourceContext); + } + + @Test(timeout = 10_000) + public void testKeepAlive() throws Exception { + ImpulseSourceFunction source = new ImpulseSourceFunction(Long.MAX_VALUE); + + // No previous state available (=impulse should be emitted) + ListState mockListState = getMockListState(Collections.emptyList()); + source.initializeState(getInitializationContext(mockListState)); + + Thread sourceThread = + new Thread( + () -> { + try { + source.run(sourceContext); + // should not finish + } catch (Exception e) { + LOG.error("Exception while executing ImpulseSourceFunction", e); + } + }); + try { + sourceThread.start(); + source.cancel(); + // should finish + sourceThread.join(); + } finally { + sourceThread.interrupt(); + sourceThread.join(); + } + verify(sourceContext).collect(argThat(elementMatcher)); + verify(sourceContext).emitWatermark(Watermark.MAX_WATERMARK); + verify(mockListState).add(true); + verify(mockListState).get(); + verifyNoMoreInteractions(mockListState); + } + + @Test(timeout = 10_000) + public void testKeepAliveDuringInterrupt() throws Exception { + ImpulseSourceFunction source = new ImpulseSourceFunction(Long.MAX_VALUE); + + // Previous state available (=impulse should not be emitted) + ListState mockListState = getMockListState(Collections.singletonList(true)); + source.initializeState(getInitializationContext(mockListState)); + + Thread sourceThread = + new Thread( + () -> { + try { + source.run(sourceContext); + // should not finish + } catch (Exception e) { + LOG.error("Exception while executing ImpulseSourceFunction", e); + } + }); + + sourceThread.start(); + sourceThread.interrupt(); + Thread.sleep(200); + assertThat(sourceThread.isAlive(), is(true)); + + // should quit + source.cancel(); + sourceThread.interrupt(); + sourceThread.join(); + + // Should always emit the final watermark + verify(sourceContext).emitWatermark(Watermark.MAX_WATERMARK); + // no element should have been emitted because the impulse was emitted before restore + verifyNoMoreInteractions(sourceContext); + } + + private static FunctionInitializationContext getInitializationContext(ListState listState) + throws Exception { + FunctionInitializationContext mock = Mockito.mock(FunctionInitializationContext.class); + OperatorStateStore mockOperatorState = getMockOperatorState(listState); + when(mock.getOperatorStateStore()).thenReturn(mockOperatorState); + return mock; + } + + private static OperatorStateStore getMockOperatorState(ListState listState) + throws Exception { + OperatorStateStore mock = Mockito.mock(OperatorStateStore.class); + when(mock.getListState(ArgumentMatchers.any(ListStateDescriptor.class))).thenReturn(listState); + return mock; + } + + private static ListState getMockListState(List initialState) throws Exception { + ListState mock = Mockito.mock(ListState.class); + when(mock.get()).thenReturn(initialState); + return mock; + } + + private static class ImpulseElementMatcher implements ArgumentMatcher<WindowedValue<byte[]>> { + + @Override + public boolean matches(WindowedValue<byte[]> o) { + return o instanceof WindowedValue + && Arrays.equals((byte[]) ((WindowedValue) o).getValue(), new byte[] {}); + } + } +}
diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapperTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapperTest.java new file mode 100644 index 000000000000..48939b0cbbf1 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapperTest.java @@ -0,0 +1,1027 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.when; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.stream.LongStream; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; +import org.apache.beam.runners.flink.streaming.StreamSources; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.construction.UnboundedReadFromBoundedSource; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.ValueWithRecordId; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; +import org.apache.flink.api.common.TaskInfo; +import org.apache.flink.api.common.functions.DefaultOpenContext; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.event.WatermarkEvent; +import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction; +import org.apache.flink.streaming.api.operators.StreamSource; +import 
org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService; +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.OutputTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.mockito.Mockito; +import org.powermock.reflect.Whitebox; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Tests for {@link UnboundedSourceWrapper}. */ +@RunWith(Enclosed.class) +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class UnboundedSourceWrapperTest { + + private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceWrapperTest.class); + + /** Parameterized tests. */ + @RunWith(Parameterized.class) + public static class ParameterizedUnboundedSourceWrapperTest { + private final int numTasks; + private final int numSplits; + + public ParameterizedUnboundedSourceWrapperTest(int numTasks, int numSplits) { + this.numTasks = numTasks; + this.numSplits = numSplits; + } + + @Parameterized.Parameters(name = "numTasks = {0}; numSplits={1}") + public static Collection data() { + /* + * Parameters for initializing the tests: + * {numTasks, numSplits} + * The test currently assumes powers of two for some assertions. + */ + return Arrays.asList( + new Object[][] { + {1, 1}, {1, 2}, {1, 4}, + {2, 1}, {2, 2}, {2, 4}, + {4, 1}, {4, 2}, {4, 4} + }); + } + + /** + * Creates a {@link UnboundedSourceWrapper} that has one or multiple readers per source. If + * numSplits > numTasks the source has one source will manage multiple readers. + */ + @Test(timeout = 30_000) + public void testValueEmission() throws Exception { + final int numElementsPerShard = 20; + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + + final long[] numElementsReceived = {0L}; + final int[] numWatermarksReceived = {0}; + + // this source will emit exactly NUM_ELEMENTS for each parallel reader, + // afterwards it will stall. We check whether we also receive NUM_ELEMENTS + // elements later. 
+ TestCountingSource source = + new TestCountingSource(numElementsPerShard).withFixedNumSplits(numSplits); + + for (int subtaskIndex = 0; subtaskIndex < numTasks; subtaskIndex++) { + UnboundedSourceWrapper, TestCountingSource.CounterMark> flinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, source, numTasks); + + // the source wrapper will only request as many splits as there are tasks and the source + // will create at most numSplits splits + assertEquals(numSplits, flinkWrapper.getSplitSources().size()); + + StreamSource< + WindowedValue>>, + UnboundedSourceWrapper, TestCountingSource.CounterMark>> + sourceOperator = new StreamSource<>(flinkWrapper); + + AbstractStreamOperatorTestHarness>>> + testHarness = + new AbstractStreamOperatorTestHarness<>( + sourceOperator, + numTasks /* max parallelism */, + numTasks /* parallelism */, + subtaskIndex /* subtask index */); + + // The testing timer service is synchronous, so we must configure a watermark interval + // > 0, otherwise we can get loop infinitely due to a timer always becoming ready after + // it has been set. + testHarness.getExecutionConfig().setAutoWatermarkInterval(10L); + testHarness.setProcessingTime(System.currentTimeMillis()); + // event time is default for Flink 2 and no need to configure + // testHarness.setTimeCharacteristic(TimeCharacteristic.EventTime); + + Thread processingTimeUpdateThread = startProcessingTimeUpdateThread(testHarness); + + try { + testHarness.open(); + StreamSources.run( + sourceOperator, + testHarness.getCheckpointLock(), + new StreamSources.OutputWrapper< + StreamRecord>>>>() { + private boolean hasSeenMaxWatermark = false; + + @Override + public void emitWatermark(Watermark watermark) { + // we get this when there is no more data + // it can happen that we get the max watermark several times, so guard against + // this + if (!hasSeenMaxWatermark + && watermark.getTimestamp() + >= BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) { + numWatermarksReceived[0]++; + hasSeenMaxWatermark = true; + } + } + + @Override + public void emitWatermark(WatermarkEvent event) {} + + @Override + public void collect(OutputTag outputTag, StreamRecord streamRecord) { + collect((StreamRecord) streamRecord); + } + + @Override + public void emitLatencyMarker(LatencyMarker latencyMarker) {} + + @Override + public void collect( + StreamRecord>>> + windowedValueStreamRecord) { + numElementsReceived[0]++; + } + + @Override + public void close() {} + }); + } finally { + processingTimeUpdateThread.interrupt(); + processingTimeUpdateThread.join(); + } + } + // verify that we get the expected count across all subtasks + assertEquals(numElementsPerShard * numSplits, numElementsReceived[0]); + // and that we get as many final watermarks as there are subtasks + assertEquals(numTasks, numWatermarksReceived[0]); + } + + /** + * Creates a {@link UnboundedSourceWrapper} that has one or multiple readers per source. If + * numSplits > numTasks the source will manage multiple readers. + * + *

This test verifies that watermarks are correctly forwarded. + */ + @Test(timeout = 30_000) + public void testWatermarkEmission() throws Exception { + final int numElements = 500; + PipelineOptions options = PipelineOptionsFactory.create(); + + // this source will emit exactly NUM_ELEMENTS across all parallel readers, + // afterwards it will stall. We check whether we also receive NUM_ELEMENTS + // elements later. + TestCountingSource source = new TestCountingSource(numElements); + UnboundedSourceWrapper, TestCountingSource.CounterMark> flinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, source, numSplits); + + assertEquals(numSplits, flinkWrapper.getSplitSources().size()); + + final StreamSource< + WindowedValue>>, + UnboundedSourceWrapper, TestCountingSource.CounterMark>> + sourceOperator = new StreamSource<>(flinkWrapper); + + final AbstractStreamOperatorTestHarness< + WindowedValue>>> + testHarness = + new AbstractStreamOperatorTestHarness<>( + sourceOperator, + numTasks /* max parallelism */, + numTasks /* parallelism */, + 0 /* subtask index */); + testHarness.getExecutionConfig().setLatencyTrackingInterval(0); + testHarness.getExecutionConfig().setAutoWatermarkInterval(1); + + testHarness.setProcessingTime(Long.MIN_VALUE); + // testHarness.setTimeCharacteristicsetTimeCharacteristic(TimeCharacteristic.EventTime); + + final ConcurrentLinkedQueue caughtExceptions = new ConcurrentLinkedQueue<>(); + + // We test emission of two watermarks here, one intermediate, one final + final CountDownLatch seenWatermarks = new CountDownLatch(2); + final int minElementsPerReader = numElements / numSplits; + final CountDownLatch minElementsCountdown = new CountDownLatch(minElementsPerReader); + + // first halt the source to test auto watermark emission + source.haltEmission(); + testHarness.open(); + + Thread sourceThread = + new Thread( + () -> { + try { + StreamSources.run( + sourceOperator, + testHarness.getCheckpointLock(), + new StreamSources.OutputWrapper< + StreamRecord>>>>() { + + @Override + public void emitWatermark(Watermark watermark) { + seenWatermarks.countDown(); + } + + @Override + public void emitWatermark(WatermarkEvent event) {} + + @Override + public void collect( + OutputTag outputTag, StreamRecord streamRecord) {} + + @Override + public void emitLatencyMarker(LatencyMarker latencyMarker) {} + + @Override + public void collect( + StreamRecord>>> + windowedValueStreamRecord) { + minElementsCountdown.countDown(); + } + + @Override + public void close() {} + }); + } catch (Exception e) { + LOG.info("Caught exception:", e); + caughtExceptions.add(e); + } + }); + + sourceThread.start(); + + while (flinkWrapper.getLocalReaders().stream() + .anyMatch(reader -> reader.getWatermark().getMillis() == 0)) { + // readers haven't been initialized + Thread.sleep(50); + } + + // Need to advance this so that the watermark timers in the source wrapper fire + // Synchronize is necessary because this can interfere with updating the PriorityQueue + // of the ProcessingTimeService which is also accessed through UnboundedSourceWrapper. 
+ synchronized (testHarness.getCheckpointLock()) { + testHarness.setProcessingTime(0); + } + + // now read the elements + source.continueEmission(); + // ..and await elements + minElementsCountdown.await(); + + // Need to advance this so that the watermark timers in the source wrapper fire + // Synchronize is necessary because this can interfere with updating the PriorityQueue + // of the ProcessingTimeService which is also accessed through UnboundedSourceWrapper. + synchronized (testHarness.getCheckpointLock()) { + testHarness.setProcessingTime(Long.MAX_VALUE); + } + + seenWatermarks.await(); + + if (!caughtExceptions.isEmpty()) { + fail("Caught exception(s): " + Joiner.on(",").join(caughtExceptions)); + } + + sourceOperator.cancel(); + sourceThread.join(); + } + + /** + * Verify that snapshot/restore work as expected. We bring up a source and cancel after seeing a + * certain number of elements. Then we snapshot that source, bring up a completely new source + * that we restore from the snapshot and verify that we see all expected elements in the end. + */ + @Test + public void testRestore() throws Exception { + final int numElements = 20; + final Object checkpointLock = new Object(); + PipelineOptions options = PipelineOptionsFactory.create(); + + // this source will emit exactly NUM_ELEMENTS across all parallel readers, + // afterwards it will stall. We check whether we also receive NUM_ELEMENTS + // elements later. + TestCountingSource source = new TestCountingSource(numElements); + UnboundedSourceWrapper, TestCountingSource.CounterMark> flinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, source, numSplits); + + assertEquals(numSplits, flinkWrapper.getSplitSources().size()); + + StreamSource< + WindowedValue>>, + UnboundedSourceWrapper, TestCountingSource.CounterMark>> + sourceOperator = new StreamSource<>(flinkWrapper); + + AbstractStreamOperatorTestHarness>>> + testHarness = + new AbstractStreamOperatorTestHarness<>( + sourceOperator, + numTasks /* max parallelism */, + numTasks /* parallelism */, + 0 /* subtask index */); + + // testHarness.setTimeCharacteristic(TimeCharacteristic.EventTime); + + final Set> emittedElements = new HashSet<>(); + + boolean readFirstBatchOfElements = false; + + try { + testHarness.open(); + StreamSources.run( + sourceOperator, + checkpointLock, + new StreamSources.OutputWrapper< + StreamRecord>>>>() { + private int count = 0; + + @Override + public void emitWatermark(Watermark watermark) {} + + @Override + public void emitWatermark(WatermarkEvent event) {} + + @Override + public void collect(OutputTag outputTag, StreamRecord streamRecord) { + collect((StreamRecord) streamRecord); + } + + @Override + public void emitLatencyMarker(LatencyMarker latencyMarker) {} + + @Override + public void collect( + StreamRecord>>> + windowedValueStreamRecord) { + + emittedElements.add(windowedValueStreamRecord.getValue().getValue().getValue()); + count++; + if (count >= numElements / 2) { + throw new SuccessException(); + } + } + + @Override + public void close() {} + }); + } catch (SuccessException e) { + // success + readFirstBatchOfElements = true; + } + + assertTrue("Did not successfully read first batch of elements.", readFirstBatchOfElements); + + // simulate pipeline stop/drain scenario, where sources are closed first. 
+ sourceOperator.cancel(); + + // draw a snapshot + OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); + + // test that finalizeCheckpoint on CheckpointMark is called + final ArrayList finalizeList = new ArrayList<>(); + TestCountingSource.setFinalizeTracker(finalizeList); + testHarness.notifyOfCompletedCheckpoint(0); + assertEquals(flinkWrapper.getLocalSplitSources().size(), finalizeList.size()); + + // stop the pipeline + testHarness.close(); + + // create a completely new source but restore from the snapshot + TestCountingSource restoredSource = new TestCountingSource(numElements); + UnboundedSourceWrapper, TestCountingSource.CounterMark> + restoredFlinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, restoredSource, numSplits); + + assertEquals(numSplits, restoredFlinkWrapper.getSplitSources().size()); + + StreamSource< + WindowedValue>>, + UnboundedSourceWrapper, TestCountingSource.CounterMark>> + restoredSourceOperator = new StreamSource<>(restoredFlinkWrapper); + + // set parallelism to 1 to ensure that our testing operator gets all checkpointed state + AbstractStreamOperatorTestHarness>>> + restoredTestHarness = + new AbstractStreamOperatorTestHarness<>( + restoredSourceOperator, + numTasks /* max parallelism */, + 1 /* parallelism */, + 0 /* subtask index */); + + // restoredTestHarness.setTimeCharacteristic(TimeCharacteristic.EventTime); + + // restore snapshot + restoredTestHarness.initializeState(snapshot); + + boolean readSecondBatchOfElements = false; + + // run again and verify that we see the other elements + try { + restoredTestHarness.open(); + StreamSources.run( + restoredSourceOperator, + checkpointLock, + new StreamSources.OutputWrapper< + StreamRecord>>>>() { + private int count = 0; + + @Override + public void emitWatermark(Watermark watermark) {} + + @Override + public void emitWatermark(WatermarkEvent event) {} + + @Override + public void collect(OutputTag outputTag, StreamRecord streamRecord) { + collect((StreamRecord) streamRecord); + } + + @Override + public void emitLatencyMarker(LatencyMarker latencyMarker) {} + + @Override + public void collect( + StreamRecord>>> + windowedValueStreamRecord) { + emittedElements.add(windowedValueStreamRecord.getValue().getValue().getValue()); + count++; + if (count >= numElements / 2) { + throw new SuccessException(); + } + } + + @Override + public void close() {} + }); + } catch (SuccessException e) { + // success + readSecondBatchOfElements = true; + } + + assertEquals( + Math.max(1, numSplits / numTasks), restoredFlinkWrapper.getLocalSplitSources().size()); + + assertTrue("Did not successfully read second batch of elements.", readSecondBatchOfElements); + + // verify that we saw all NUM_ELEMENTS elements + assertTrue(emittedElements.size() == numElements); + } + + @Test + public void testNullCheckpoint() throws Exception { + final int numElements = 20; + PipelineOptions options = PipelineOptionsFactory.create(); + + TestCountingSource source = + new TestCountingSource(numElements) { + @Override + public Coder getCheckpointMarkCoder() { + return null; + } + }; + + UnboundedSourceWrapper, TestCountingSource.CounterMark> flinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, source, numSplits); + + StreamSource< + WindowedValue>>, + UnboundedSourceWrapper, TestCountingSource.CounterMark>> + sourceOperator = new StreamSource<>(flinkWrapper); + + AbstractStreamOperatorTestHarness>>> + testHarness = + new AbstractStreamOperatorTestHarness<>( + sourceOperator, + numTasks /* max parallelism */, + 
numTasks /* parallelism */, + 0 /* subtask index */); + + // testHarness.setTimeCharacteristic(TimeCharacteristic.EventTime); + + testHarness.open(); + + OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); + + UnboundedSourceWrapper, TestCountingSource.CounterMark> + restoredFlinkWrapper = + new UnboundedSourceWrapper<>( + "stepName", options, new TestCountingSource(numElements), numSplits); + + StreamSource< + WindowedValue>>, + UnboundedSourceWrapper, TestCountingSource.CounterMark>> + restoredSourceOperator = new StreamSource<>(restoredFlinkWrapper); + + // set parallelism to 1 to ensure that our testing operator gets all checkpointed state + AbstractStreamOperatorTestHarness>>> + restoredTestHarness = + new AbstractStreamOperatorTestHarness<>( + restoredSourceOperator, + numTasks /* max parallelism */, + 1 /* parallelism */, + 0 /* subtask index */); + + restoredTestHarness.setup(); + restoredTestHarness.initializeState(snapshot); + restoredTestHarness.open(); + + // when the source checkpointed a null we don't re-initialize the splits, that is we + // will have no splits. + assertEquals(0, restoredFlinkWrapper.getLocalSplitSources().size()); + } + + /** A special {@link RuntimeException} that we throw to signal that the test was successful. */ + private static class SuccessException extends RuntimeException {} + } + + /** Not parameterized tests. */ + @RunWith(JUnit4.class) + public static class BasicTest { + + /** Check serialization a {@link UnboundedSourceWrapper}. */ + @Test + public void testSerialization() throws Exception { + final int parallelism = 1; + final int numElements = 20; + PipelineOptions options = PipelineOptionsFactory.create(); + + TestCountingSource source = new TestCountingSource(numElements); + UnboundedSourceWrapper, TestCountingSource.CounterMark> flinkWrapper = + new UnboundedSourceWrapper<>("stepName", options, source, parallelism); + + InstantiationUtil.serializeObject(flinkWrapper); + } + + @Test(timeout = 10_000) + public void testSourceWithNoReaderDoesNotShutdown() throws Exception { + testSourceDoesNotShutdown(false); + } + + @Test(timeout = 10_000) + public void testSourceWithReadersDoesNotShutdown() throws Exception { + testSourceDoesNotShutdown(true); + } + + private static void testSourceDoesNotShutdown(boolean shouldHaveReaders) throws Exception { + final int parallelism = 2; + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + // Make sure we do not shut down + options.setShutdownSourcesAfterIdleMs(Long.MAX_VALUE); + + TestCountingSource source = new TestCountingSource(20).withoutSplitting(); + + UnboundedSourceWrapper, TestCountingSource.CounterMark> sourceWrapper = + new UnboundedSourceWrapper<>("noReader", options, source, parallelism); + + StreamingRuntimeContext mock = Mockito.mock(StreamingRuntimeContext.class); + TaskInfo mockTaskInfo = Mockito.mock(TaskInfo.class); + if (shouldHaveReaders) { + // Since the source can't be split, the first subtask index will read everything + Mockito.when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0); + } else { + // Set up the RuntimeContext such that this instance won't receive any readers + Mockito.when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(parallelism - 1); + } + + Mockito.when(mockTaskInfo.getNumberOfParallelSubtasks()).thenReturn(parallelism); + Mockito.when(mock.getTaskInfo()).thenReturn(mockTaskInfo); + ProcessingTimeService timerService = Mockito.mock(ProcessingTimeService.class); + 
Mockito.when(timerService.getCurrentProcessingTime()).thenReturn(Long.MAX_VALUE); + Mockito.when(mock.getProcessingTimeService()).thenReturn(timerService); + Mockito.when(mock.getJobConfiguration()).thenReturn(new Configuration()); + Mockito.when(mock.getMetricGroup()) + .thenReturn(UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup()); + sourceWrapper.setRuntimeContext(mock); + sourceWrapper.open(new DefaultOpenContext()); + + SourceFunction.SourceContext sourceContext = Mockito.mock(SourceFunction.SourceContext.class); + Object checkpointLock = new Object(); + Mockito.when(sourceContext.getCheckpointLock()).thenReturn(checkpointLock); + // Initialize source context early to avoid concurrency issues with its initialization in the + // run + // method and the onProcessingTime call on the wrapper. + sourceWrapper.setSourceContext(sourceContext); + + sourceWrapper.open(new DefaultOpenContext()); + assertThat(sourceWrapper.getLocalReaders().isEmpty(), is(!shouldHaveReaders)); + + Thread thread = + new Thread( + () -> { + try { + sourceWrapper.run(sourceContext); + } catch (Exception e) { + LOG.error("Error while running UnboundedSourceWrapper", e); + } + }); + + try { + thread.start(); + // Wait to see if the wrapper shuts down immediately in case it doesn't have readers + if (!shouldHaveReaders) { + // The expected state is for finalizeSource to sleep instead of exiting + while (true) { + StackTraceElement[] callStack = thread.getStackTrace(); + if (callStack.length >= 2 + && "sleep".equals(callStack[0].getMethodName()) + && "finalizeSource".equals(callStack[1].getMethodName())) { + break; + } + Thread.sleep(10); + } + } + // Source should still be running even if there are no readers + assertThat(sourceWrapper.isRunning(), is(true)); + synchronized (checkpointLock) { + // Trigger emission of the watermark by updating processing time. + // The actual processing time value does not matter. 
+ sourceWrapper.onProcessingTime(42); + } + // Source should still be running even when watermark is at max + assertThat(sourceWrapper.isRunning(), is(true)); + assertThat(thread.isAlive(), is(true)); + sourceWrapper.cancel(); + } finally { + thread.interrupt(); + // try to join but also don't mask exceptions with test timeout + thread.join(1000); + } + assertThat(thread.isAlive(), is(false)); + } + + @Test + public void testSequentialReadingFromBoundedSource() throws Exception { + UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter source = + new UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter<>( + CountingSource.upTo(1000)); + + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + + UnboundedSourceWrapper< + Long, UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter.Checkpoint> + sourceWrapper = new UnboundedSourceWrapper<>("sequentialRead", options, source, 4); + + StreamingRuntimeContext runtimeContextMock = Mockito.mock(StreamingRuntimeContext.class); + TaskInfo mockTaskInfo = Mockito.mock(TaskInfo.class); + Mockito.when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0); + when(mockTaskInfo.getNumberOfParallelSubtasks()).thenReturn(2); + Mockito.when(runtimeContextMock.getTaskInfo()).thenReturn(mockTaskInfo); + + TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); + processingTimeService.setCurrentTime(0); + when(runtimeContextMock.getProcessingTimeService()).thenReturn(processingTimeService); + when(runtimeContextMock.getJobConfiguration()).thenReturn(new Configuration()); + when(runtimeContextMock.getMetricGroup()) + .thenReturn(UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup()); + + sourceWrapper.setRuntimeContext(runtimeContextMock); + + sourceWrapper.open(new DefaultOpenContext()); + assertThat(sourceWrapper.getLocalReaders().size(), is(2)); + + List integers = new ArrayList<>(); + sourceWrapper.run( + new SourceFunction.SourceContext>>() { + private final Object checkpointLock = new Object(); + + @Override + public void collect(WindowedValue> element) { + integers.add(element.getValue().getValue()); + } + + @Override + public void collectWithTimestamp( + WindowedValue> element, long timestamp) { + throw new IllegalStateException("Should not collect with timestamp"); + } + + @Override + public void emitWatermark(Watermark mark) {} + + @Override + public void markAsTemporarilyIdle() {} + + @Override + public Object getCheckpointLock() { + return checkpointLock; + } + + @Override + public void close() {} + }); + + // The source is effectively split into two parts: The initial splitting is performed with a + // parallelism of 4, but there are 2 parallel subtasks. This instances taskes 2 out of 4 + // partitions. 
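To make the expected output of this test concrete: CountingSource.upTo(1000) is split into 4 equal ranges of 250 elements, while the mocked runtime reports 2 parallel subtasks, so subtask 0 should end up with every other split. A minimal illustrative sketch of that arithmetic (assuming the round-robin split assignment implied by the assertion below; this helper is not part of the wrapper's API):

    // Which of the 4 splits of CountingSource.upTo(1000) land on subtask 0 of 2,
    // assuming round-robin assignment. Split i covers [i * 250, (i + 1) * 250).
    int numSplits = 4;
    int numSubtasks = 2;
    int subtaskIndex = 0;
    for (int split = 0; split < numSplits; split++) {
      if (split % numSubtasks == subtaskIndex) {
        long begin = split * 250L;
        long end = begin + 250L;
        System.out.println("split " + split + " -> [" + begin + ", " + end + ")");
      }
    }
    // Prints split 0 -> [0, 250) and split 2 -> [500, 750): 500 elements in total,
    // which is exactly what the contains(...) assertion below verifies.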
+ assertThat(integers.size(), is(500)); + assertThat( + integers, + contains( + LongStream.concat(LongStream.range(0, 250), LongStream.range(500, 750)) + .boxed() + .toArray())); + } + + @Test + public void testAccumulatorRegistrationOnOperatorClose() throws Exception { + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + + TestCountingSource source = new TestCountingSource(20).withoutSplitting(); + + UnboundedSourceWrapper, TestCountingSource.CounterMark> sourceWrapper = + new UnboundedSourceWrapper<>("noReader", options, source, 2); + + StreamingRuntimeContext mock = Mockito.mock(StreamingRuntimeContext.class); + TaskInfo mockTaskInfo = Mockito.mock(TaskInfo.class); + Mockito.when(mockTaskInfo.getNumberOfParallelSubtasks()).thenReturn(1); + Mockito.when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0); + Mockito.when(mock.getTaskInfo()).thenReturn(mockTaskInfo); + sourceWrapper.setRuntimeContext(mock); + + sourceWrapper.open(new DefaultOpenContext()); + + String metricContainerFieldName = "metricContainer"; + FlinkMetricContainer monitoredContainer = + Mockito.spy( + (FlinkMetricContainer) + Whitebox.getInternalState(sourceWrapper, metricContainerFieldName)); + Whitebox.setInternalState(sourceWrapper, metricContainerFieldName, monitoredContainer); + + sourceWrapper.close(); + Mockito.verify(monitoredContainer).registerMetricsForPipelineResult(); + } + } + + @RunWith(JUnit4.class) + public static class IntegrationTests { + + /** Tests that idle readers are polled for more data after having returned no data. */ + @Test(timeout = 30_000) + public void testPollingOfIdleReaders() throws Exception { + IdlingUnboundedSource source = + new IdlingUnboundedSource<>( + Arrays.asList("first", "second", "third"), StringUtf8Coder.of()); + + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setShutdownSourcesAfterIdleMs(0L); + options.setParallelism(4); + + UnboundedSourceWrapper wrappedSource = + new UnboundedSourceWrapper<>("sequentialRead", options, source, 4); + + StreamSource< + WindowedValue>, + UnboundedSourceWrapper> + sourceOperator = new StreamSource<>(wrappedSource); + AbstractStreamOperatorTestHarness>> testHarness = + new AbstractStreamOperatorTestHarness<>(sourceOperator, 4, 4, 0); + // testHarness.setTimeCharacteristic(TimeCharacteristic.EventTime); + testHarness.getExecutionConfig().setAutoWatermarkInterval(10L); + + testHarness.open(); + ArrayList output = new ArrayList<>(); + + Thread processingTimeUpdateThread = startProcessingTimeUpdateThread(testHarness); + + StreamSources.run( + sourceOperator, + testHarness.getCheckpointLock(), + new StreamSources.OutputWrapper< + StreamRecord>>>() { + @Override + public void emitWatermark(Watermark mark) {} + + @Override + public void emitWatermark(WatermarkEvent watermark) {} + + @Override + public void emitLatencyMarker(LatencyMarker latencyMarker) {} + + @Override + public void collect(OutputTag outputTag, StreamRecord record) { + throw new IllegalStateException(); + } + + @Override + public void collect(StreamRecord>> record) { + output.add(record.getValue().getValue().getValue()); + } + + @Override + public void close() {} + }); + + // Two idles in between elements + one after end of input. 
+ assertThat(source.getNumIdles(), is(3)); + assertThat(output, contains("first", "second", "third")); + + processingTimeUpdateThread.interrupt(); + processingTimeUpdateThread.join(); + } + } + + private static Thread startProcessingTimeUpdateThread( + AbstractStreamOperatorTestHarness testHarness) { + // start a thread that advances processing time, so that we eventually get the final + // watermark which is only updated via a processing-time trigger + Thread processingTimeUpdateThread = + new Thread() { + @Override + public void run() { + while (true) { + try { + // Need to advance this so that the watermark timers in the source wrapper fire + // Synchronize is necessary because this can interfere with updating the + // PriorityQueue of the ProcessingTimeService which is accessed when setting + // timers in UnboundedSourceWrapper. + synchronized (testHarness.getCheckpointLock()) { + testHarness.setProcessingTime(System.currentTimeMillis()); + } + Thread.sleep(10); + } catch (InterruptedException e) { + // this is ok + break; + } catch (Exception e) { + LOG.error("Unexpected error advancing processing time", e); + break; + } + } + } + }; + processingTimeUpdateThread.start(); + return processingTimeUpdateThread; + } + + /** + * Source that advances on every second call to {@link UnboundedReader#advance()}. + * + * @param Type of elements. + */ + private static class IdlingUnboundedSource + extends UnboundedSource { + + private final ConcurrentHashMap numIdles = new ConcurrentHashMap<>(); + + private final String uuid = UUID.randomUUID().toString(); + + private final List data; + private final Coder outputCoder; + + public IdlingUnboundedSource(List data, Coder outputCoder) { + this.data = data; + this.outputCoder = outputCoder; + } + + @Override + public List> split( + int desiredNumSplits, PipelineOptions options) { + return Collections.singletonList(this); + } + + @Override + public UnboundedReader createReader( + PipelineOptions options, @Nullable CheckpointMark checkpointMark) { + return new UnboundedReader() { + + private int currentIdx = -1; + private boolean lastAdvanced = false; + + @Override + public boolean start() { + return advance(); + } + + @Override + public boolean advance() { + if (lastAdvanced) { + // Idle for this call. + numIdles.merge(uuid, 1, Integer::sum); + lastAdvanced = false; + return false; + } + if (currentIdx < data.size() - 1) { + currentIdx++; + lastAdvanced = true; + return true; + } + return false; + } + + @Override + public Instant getWatermark() { + if (currentIdx >= data.size() - 1) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + return new Instant(currentIdx); + } + + @Override + public CheckpointMark getCheckpointMark() { + return CheckpointMark.NOOP_CHECKPOINT_MARK; + } + + @Override + public UnboundedSource getCurrentSource() { + return IdlingUnboundedSource.this; + } + + @Override + public T getCurrent() throws NoSuchElementException { + if (currentIdx >= 0 && currentIdx < data.size()) { + return data.get(currentIdx); + } + throw new NoSuchElementException(); + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + if (currentIdx >= 0 && currentIdx < data.size()) { + return new Instant(currentIdx); + } + throw new NoSuchElementException(); + } + + @Override + public void close() { + // No-op. 
+ } + }; + } + + @Override + public Coder getCheckpointMarkCoder() { + return null; + } + + @Override + public Coder getOutputCoder() { + return outputCoder; + } + + int getNumIdles() { + return numIdles.getOrDefault(uuid, 0); + } + } +} diff --git a/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunnerTest.java b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunnerTest.java new file mode 100644 index 000000000000..e0e6d5c00286 --- /dev/null +++ b/runners/flink/2.0/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunnerTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.stableinput; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +import java.util.Collections; +import java.util.List; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +/** + * Tests for {@link BufferingDoFnRunner}. + * + *

For more tests see: + * + *

- {@link org.apache.beam.runners.flink.FlinkRequiresStableInputTest} + * + *

- {@link org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperatorTest} + *

- {@link BufferedElementsTest} + */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +public class BufferingDoFnRunnerTest { + + @Test + public void testRestoreWithoutConcurrentCheckpoints() throws Exception { + BufferingDoFnRunner bufferingDoFnRunner = createBufferingDoFnRunner(1, Collections.emptyList()); + assertThat(bufferingDoFnRunner.currentStateIndex, is(0)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(2)); + } + + @Test + public void testRestoreWithoutConcurrentCheckpointsWithPendingCheckpoint() throws Exception { + BufferingDoFnRunner bufferingDoFnRunner; + + bufferingDoFnRunner = + createBufferingDoFnRunner( + 1, Collections.singletonList(new BufferingDoFnRunner.CheckpointIdentifier(0, 1000))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(1)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(2)); + + bufferingDoFnRunner = + createBufferingDoFnRunner( + 1, Collections.singletonList(new BufferingDoFnRunner.CheckpointIdentifier(1, 1000))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(0)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(2)); + } + + @Test + public void + testRestoreWithoutConcurrentCheckpointsWithPendingCheckpointFromConcurrentCheckpointing() + throws Exception { + BufferingDoFnRunner bufferingDoFnRunner = + createBufferingDoFnRunner( + 1, Collections.singletonList(new BufferingDoFnRunner.CheckpointIdentifier(5, 42))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(0)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(6)); + } + + @Test + public void testRestoreWithConcurrentCheckpoints() throws Exception { + BufferingDoFnRunner bufferingDoFnRunner = createBufferingDoFnRunner(2, Collections.emptyList()); + assertThat(bufferingDoFnRunner.currentStateIndex, is(0)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(3)); + } + + @Test + public void testRestoreWithConcurrentCheckpointsFromPendingCheckpoint() throws Exception { + BufferingDoFnRunner bufferingDoFnRunner; + + bufferingDoFnRunner = + createBufferingDoFnRunner( + 2, Collections.singletonList(new BufferingDoFnRunner.CheckpointIdentifier(0, 1000))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(1)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(3)); + + bufferingDoFnRunner = + createBufferingDoFnRunner( + 2, Collections.singletonList(new BufferingDoFnRunner.CheckpointIdentifier(2, 1000))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(0)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(3)); + } + + @Test + public void testRestoreWithConcurrentCheckpointsFromPendingCheckpoints() throws Exception { + BufferingDoFnRunner bufferingDoFnRunner; + + bufferingDoFnRunner = + createBufferingDoFnRunner( + 3, + ImmutableList.of( + new BufferingDoFnRunner.CheckpointIdentifier(0, 42), + new BufferingDoFnRunner.CheckpointIdentifier(1, 43))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(2)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(4)); + + bufferingDoFnRunner = + createBufferingDoFnRunner( + 3, + ImmutableList.of( + new BufferingDoFnRunner.CheckpointIdentifier(2, 42), + new BufferingDoFnRunner.CheckpointIdentifier(3, 43))); + assertThat(bufferingDoFnRunner.currentStateIndex, is(0)); + assertThat(bufferingDoFnRunner.numCheckpointBuffers, is(4)); + } + + @Test + public void testRejectConcurrentCheckpointingBoundaries() { + Assert.assertThrows( + IllegalArgumentException.class, + () -> { + createBufferingDoFnRunner(0, Collections.emptyList()); + 
}); + Assert.assertThrows( + IllegalArgumentException.class, + () -> { + createBufferingDoFnRunner(Short.MAX_VALUE, Collections.emptyList()); + }); + } + + private static BufferingDoFnRunner createBufferingDoFnRunner( + int concurrentCheckpoints, + List notYetAcknowledgeCheckpoints) + throws Exception { + DoFnRunner doFnRunner = Mockito.mock(DoFnRunner.class); + OperatorStateBackend operatorStateBackend = Mockito.mock(OperatorStateBackend.class); + + // Setup not yet acknowledged checkpoint union list state + ListState unionListState = Mockito.mock(ListState.class); + Mockito.when(operatorStateBackend.getUnionListState(Mockito.any())) + .thenReturn(unionListState); + Mockito.when(unionListState.get()).thenReturn(notYetAcknowledgeCheckpoints); + + // Setup buffer list state + Mockito.when(operatorStateBackend.getListState(Mockito.any())) + .thenReturn(Mockito.mock(ListState.class)); + + return BufferingDoFnRunner.create( + doFnRunner, + "stable-input", + StringUtf8Coder.of(), + WindowedValues.getFullCoder(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE), + operatorStateBackend, + null, + concurrentCheckpoints, + new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); + } +} diff --git a/runners/flink/2.0/src/test/resources/flink-test-config.yaml b/runners/flink/2.0/src/test/resources/flink-test-config.yaml new file mode 100644 index 000000000000..d34134695dd6 --- /dev/null +++ b/runners/flink/2.0/src/test/resources/flink-test-config.yaml @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +taskmanager: + memory: + network: + max: 2gb + fraction: '0.2' + managed: + size: 1gb +parallelism: + default: '23' diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index 52f9631f455f..b6b0430509fb 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -27,8 +27,9 @@ import groovy.json.JsonOutput def base_path = ".." -def overrides(versions, type, base_path) { - versions.collect { "${base_path}/${it}/src/${type}/java" } + ["./src/${type}/java"] +def overrides(versions, type, base_path, group='java') { + // order is important + ["${base_path}/src/${type}/${group}"] + versions.collect { "${base_path}/${it}/src/${type}/${group}" } + ["./src/${type}/${group}"] } def all_versions = flink_versions.split(",") @@ -38,8 +39,8 @@ def previous_versions = all_versions.findAll { it < flink_major } // Version specific code overrides. 
def main_source_overrides = overrides(previous_versions, "main", base_path) def test_source_overrides = overrides(previous_versions, "test", base_path) -def main_resources_overrides = [] -def test_resources_overrides = [] +def main_resources_overrides = overrides(previous_versions, "main", base_path, "resources") +def test_resources_overrides = overrides(previous_versions, "test", base_path, "resources") def archivesBaseName = "beam-runners-flink-${flink_major}" @@ -49,7 +50,8 @@ applyJavaNature( automaticModuleName: 'org.apache.beam.runners.flink', archivesBaseName: archivesBaseName, // flink runner jars are in same package name. Publish javadoc once. - exportJavadoc: project.ext.flink_version.startsWith(all_versions.first()) + exportJavadoc: project.ext.flink_version.startsWith(all_versions.first()), + requireJavaVersion: project.ext.flink_major.startsWith('2') ? JavaVersion.VERSION_11 : null ) description = "Apache Beam :: Runners :: Flink $flink_version" @@ -68,10 +70,16 @@ evaluationDependsOn(":examples:java") */ def sourceOverridesBase = project.layout.buildDirectory.dir('source-overrides/src').get() -def copySourceOverrides = tasks.register('copySourceOverrides', Copy) { - it.from main_source_overrides - it.into "${sourceOverridesBase}/main/java" - it.duplicatesStrategy DuplicatesStrategy.INCLUDE +def copySourceOverrides = tasks.register('copySourceOverrides', Copy) { copyTask -> + copyTask.from main_source_overrides + copyTask.into "${sourceOverridesBase}/main/java" + copyTask.duplicatesStrategy DuplicatesStrategy.INCLUDE + + if (project.ext.has('excluded_files') && project.ext.excluded_files.containsKey('main')) { + project.ext.excluded_files.main.each { file -> + copyTask.exclude "**/${file}" + } + } } def copyResourcesOverrides = tasks.register('copyResourcesOverrides', Copy) { @@ -80,10 +88,16 @@ def copyResourcesOverrides = tasks.register('copyResourcesOverrides', Copy) { it.duplicatesStrategy DuplicatesStrategy.INCLUDE } -def copyTestSourceOverrides = tasks.register('copyTestSourceOverrides', Copy) { - it.from test_source_overrides - it.into "${sourceOverridesBase}/test/java" - it.duplicatesStrategy DuplicatesStrategy.INCLUDE +def copyTestSourceOverrides = tasks.register('copyTestSourceOverrides', Copy) { copyTask -> + copyTask.from test_source_overrides + copyTask.into "${sourceOverridesBase}/test/java" + copyTask.duplicatesStrategy DuplicatesStrategy.INCLUDE + + if (project.ext.has('excluded_files') && project.ext.excluded_files.containsKey('test')) { + project.ext.excluded_files.test.each { file -> + copyTask.exclude "**/${file}" + } + } } def copyTestResourcesOverrides = tasks.register('copyTestResourcesOverrides', Copy) { @@ -119,18 +133,18 @@ def sourceBase = "${project.projectDir}/../src" sourceSets { main { java { - srcDirs = ["${sourceBase}/main/java", "${sourceOverridesBase}/main/java"] + srcDirs = ["${sourceOverridesBase}/main/java"] } resources { - srcDirs = ["${sourceBase}/main/resources", "${sourceOverridesBase}/main/resources"] + srcDirs = ["${sourceOverridesBase}/main/resources"] } } test { java { - srcDirs = ["${sourceBase}/test/java", "${sourceOverridesBase}/test/java"] + srcDirs = ["${sourceOverridesBase}/test/java"] } resources { - srcDirs = ["${sourceBase}/test/resources", "${sourceOverridesBase}/test/resources"] + srcDirs = ["${sourceOverridesBase}/test/resources"] } } } @@ -196,7 +210,10 @@ dependencies { implementation "org.apache.flink:flink-core:$flink_version" implementation "org.apache.flink:flink-metrics-core:$flink_version" - implementation 
"org.apache.flink:flink-java:$flink_version" + if (project.ext.flink_major.startsWith('1')) { + // FLINK-36336: dataset API removed in Flink 2 + implementation "org.apache.flink:flink-java:$flink_version" + } implementation "org.apache.flink:flink-runtime:$flink_version" implementation "org.apache.flink:flink-metrics-core:$flink_version" diff --git a/runners/flink/job-server-container/flink_job_server_container.gradle b/runners/flink/job-server-container/flink_job_server_container.gradle index 3f30a1aac1fb..65f962428142 100644 --- a/runners/flink/job-server-container/flink_job_server_container.gradle +++ b/runners/flink/job-server-container/flink_job_server_container.gradle @@ -53,15 +53,19 @@ task copyDockerfileDependencies(type: Copy) { } def pushContainers = project.rootProject.hasProperty(["isRelease"]) || project.rootProject.hasProperty("push-containers") +def containerName = project.parent.name.startsWith("2") ? "flink_job_server" : "flink${project.parent.name}_job_server" +def containerTag = project.rootProject.hasProperty(["docker-tag"]) ? project.rootProject["docker-tag"] : project.sdk_version +if (project.parent.name.startsWith("2")) { + containerTag += "-flink" + project.parent.name +} docker { name containerImageName( - name: project.docker_image_default_repo_prefix + "flink${project.parent.name}_job_server", + name: project.docker_image_default_repo_prefix + containerName, root: project.rootProject.hasProperty(["docker-repository-root"]) ? project.rootProject["docker-repository-root"] : project.docker_image_default_repo_root, - tag: project.rootProject.hasProperty(["docker-tag"]) ? - project.rootProject["docker-tag"] : project.sdk_version) + tag: containerTag) // tags used by dockerTag task tags containerImageTags() files "./build/" diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index d8a818ff84c4..8788f9c7c2e2 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -29,6 +29,9 @@ apply plugin: 'application' // we need to set mainClassName before applying shadow plugin mainClassName = "org.apache.beam.runners.flink.FlinkJobServerDriver" +// Resolve the Flink project name (and version) the job-server is based on +def flinkRunnerProject = "${project.path.replace(":job-server", "")}" + applyJavaNature( automaticModuleName: 'org.apache.beam.runners.flink.jobserver', archivesBaseName: project.hasProperty('archives_base_name') ? archives_base_name : archivesBaseName, @@ -37,11 +40,9 @@ applyJavaNature( shadowClosure: { append "reference.conf" }, + requireJavaVersion: project(flinkRunnerProject).ext.flink_major.startsWith('2') ? JavaVersion.VERSION_11 : null ) -// Resolve the Flink project name (and version) the job-server is based on -def flinkRunnerProject = "${project.path.replace(":job-server", "")}" - description = project(flinkRunnerProject).description + " :: Job Server" /* diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPipelineTranslator.java index b415c9b10559..626f5fe81110 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPipelineTranslator.java @@ -119,7 +119,7 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { } /** A translator of a {@link PTransform}. 
*/ - public interface BatchTransformTranslator { + public interface BatchTransformTranslator> { default boolean canTranslate(TransformT transform, FlinkBatchTranslationContext context) { return true; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java index 901207a91f00..7e493ea1a98f 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java @@ -303,11 +303,13 @@ public Long create(PipelineOptions options) { void setAutoWatermarkInterval(Long interval); + /** ExecutionMode is only effective for DataSet API and has been removed in Flink 2.0. */ + @Deprecated() @Description( "Flink mode for data exchange of batch pipelines. " + "Reference {@link org.apache.flink.api.common.ExecutionMode}. " + "Set this to BATCH_FORCED if pipelines get blocked, see " - + "https://issues.apache.org/jira/browse/FLINK-10672") + + "https://issues.apache.org/jira/browse/FLINK-10672.") @Default.String(PIPELINED) String getExecutionModeForBatch(); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java index c9559a392704..11175129d7ef 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java @@ -84,7 +84,9 @@ public PortablePipelineResult run(final Pipeline pipeline, JobInfo jobInfo) thro SdkHarnessOptions.getConfiguredLoggerFromOptions(pipelineOptions.as(SdkHarnessOptions.class)); FlinkPortablePipelineTranslator translator; - if (!pipelineOptions.isStreaming() && !hasUnboundedPCollections(pipeline)) { + if (!pipelineOptions.getUseDataStreamForBatch() + && !pipelineOptions.isStreaming() + && !hasUnboundedPCollections(pipeline)) { // TODO: Do we need to inspect for unbounded sources before fusing? 
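The added getUseDataStreamForBatch() check routes batch pipelines to the streaming (DataStream) translator, which is the only path left once the DataSet API is removed (FLINK-36336); the deprecated executionModeForBatch option only ever applied to the DataSet path. A hedged usage sketch, assuming the setter names that follow from the option names referenced in this patch (setUseDataStreamForBatch, setStateBackend, setTolerableCheckpointFailureNumber):

    import org.apache.beam.runners.flink.FlinkPipelineOptions;
    import org.apache.beam.runners.flink.FlinkRunner;
    import org.apache.beam.sdk.Pipeline;

    public class Flink2BatchOptionsSketch {
      public static void main(String[] args) {
        FlinkPipelineOptions options = FlinkPipelineOptions.defaults();
        options.setRunner(FlinkRunner.class);
        options.setStreaming(false);
        // Batch now executes on the DataStream API; executionModeForBatch is deprecated
        // and has no effect on Flink 2.x.
        options.setUseDataStreamForBatch(true);
        // 'hashmap' is the current name of the old 'filesystem' state backend.
        options.setStateBackend("hashmap");
        // Replaces the removed failOnCheckpointingErrors flag: tolerate this many
        // checkpoint failures instead of failing the task on the first one.
        options.setTolerableCheckpointFailureNumber(2);
        Pipeline pipeline = Pipeline.create(options);
        // ... apply transforms, then pipeline.run().
      }
    }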
translator = FlinkBatchPortablePipelineTranslator.createTranslator(); } else { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 79a90c554027..b3b40d2874a7 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -103,13 +103,13 @@ import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.state.CheckpointListener; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainerBase.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainerBase.java index a9a6db47c814..e54d5575a479 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainerBase.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainerBase.java @@ -50,7 +50,9 @@ abstract class FlinkMetricContainerBase { private static final String METRIC_KEY_SEPARATOR = - GlobalConfiguration.loadConfiguration().getString(MetricOptions.SCOPE_DELIMITER); + GlobalConfiguration.loadConfiguration() + .getOptional(MetricOptions.SCOPE_DELIMITER) + .orElseGet(MetricOptions.SCOPE_DELIMITER::defaultValue); protected final MetricsContainerStepMap metricsContainers; private final Map flinkCounterCache; diff --git a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java similarity index 100% rename from runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java rename to runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java index 7811f1f85a67..1d50fd72d465 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -45,11 +45,11 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; import 
org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.state.CheckpointListener; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.DefaultOperatorStateBackend; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; diff --git a/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java similarity index 100% rename from runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java rename to runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java diff --git a/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/settings.gradle.kts b/settings.gradle.kts index bea48565bfc4..de3fd7146501 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -264,6 +264,8 @@ include(":sdks:java:javadoc") include(":sdks:java:maven-archetypes:examples") include(":sdks:java:maven-archetypes:gcp-bom-examples") include(":sdks:java:maven-archetypes:starter") +include("sdks:java:ml:inference:remote") +include("sdks:java:ml:inference:openai") include(":sdks:java:testing:nexmark") include(":sdks:java:testing:expansion-service") include(":sdks:java:testing:jpms-tests") @@ -281,44 +283,37 @@ include(":sdks:python") include(":sdks:python:apache_beam:testing:load_tests") include(":sdks:python:apache_beam:testing:benchmarks:nexmark") include(":sdks:python:container") -include(":sdks:python:container:py39") include(":sdks:python:container:py310") include(":sdks:python:container:py311") include(":sdks:python:container:py312") include(":sdks:python:container:py313") include(":sdks:python:container:distroless") -include(":sdks:python:container:distroless:py39") include(":sdks:python:container:distroless:py310") include(":sdks:python:container:distroless:py311") include(":sdks:python:container:distroless:py312") include(":sdks:python:container:distroless:py313") include(":sdks:python:container:ml") -include(":sdks:python:container:ml:py39") include(":sdks:python:container:ml:py310") include(":sdks:python:container:ml:py311") include(":sdks:python:container:ml:py312") include(":sdks:python:container:ml:py313") include(":sdks:python:expansion-service-container") include(":sdks:python:test-suites:dataflow") -include(":sdks:python:test-suites:dataflow:py39") include(":sdks:python:test-suites:dataflow:py310") include(":sdks:python:test-suites:dataflow:py311") include(":sdks:python:test-suites:dataflow:py312") include(":sdks:python:test-suites:dataflow:py313") include(":sdks:python:test-suites:direct") -include(":sdks:python:test-suites:direct:py39") include(":sdks:python:test-suites:direct:py310") 
include(":sdks:python:test-suites:direct:py311") include(":sdks:python:test-suites:direct:py312") include(":sdks:python:test-suites:direct:py313") include(":sdks:python:test-suites:direct:xlang") -include(":sdks:python:test-suites:portable:py39") include(":sdks:python:test-suites:portable:py310") include(":sdks:python:test-suites:portable:py311") include(":sdks:python:test-suites:portable:py312") include(":sdks:python:test-suites:portable:py313") include(":sdks:python:test-suites:tox:pycommon") -include(":sdks:python:test-suites:tox:py39") include(":sdks:python:test-suites:tox:py310") include(":sdks:python:test-suites:tox:py311") include(":sdks:python:test-suites:tox:py312") @@ -376,6 +371,3 @@ include("sdks:java:extensions:sql:iceberg") findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg" include("examples:java:iceberg") findProject(":examples:java:iceberg")?.name = "iceberg" - -include("sdks:java:ml:inference:remote") -include("sdks:java:ml:inference:openai") diff --git a/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html b/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html index a3526d7d0d28..34d6c5243776 100644 --- a/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html +++ b/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html @@ -64,7 +64,7 @@ executionModeForBatch - Flink mode for data exchange of batch pipelines. Reference {@link org.apache.flink.api.common.ExecutionMode}. Set this to BATCH_FORCED if pipelines get blocked, see https://issues.apache.org/jira/browse/FLINK-10672 + Flink mode for data exchange of batch pipelines. Reference {@link org.apache.flink.api.common.ExecutionMode}. Set this to BATCH_FORCED if pipelines get blocked, see https://issues.apache.org/jira/browse/FLINK-10672. Default: PIPELINED @@ -77,11 +77,6 @@ Enables or disables externalized checkpoints. Works in conjunction with CheckpointingInterval Default: false - - failOnCheckpointingErrors - Sets the expected behaviour for tasks in case that they encounter an error in their checkpointing procedure. If this is set to true, the task will fail on checkpointing error. If this is set to false, the task will only decline the checkpoint and continue running. - Default: true - fasterCopy Remove unneeded deep copy between operators. See https://issues.apache.org/jira/browse/BEAM-11146 @@ -172,11 +167,6 @@ The degree of parallelism to be used when distributing operations onto workers. If the parallelism is not set, the configured Flink default is used, or 1 if none can be found. Default: -1 - - reIterableGroupByKeyResult - Flag indicating whether result of GBK needs to be re-iterable. Re-iterable result implies that all values for a single key must fit in memory as we currently do not support spilling to disk. - Default: false - reportCheckpointDuration If not null, reports the checkpoint duration of each ParDo stage in the provided metric namespace. @@ -199,7 +189,7 @@ stateBackend - State backend to store Beam's state. Use 'rocksdb' or 'filesystem'. + State backend to store Beam's state. Use 'rocksdb' or 'hashmap' (same as 'filesystem'). @@ -212,6 +202,11 @@ State backend path to persist state backend data. Used to initialize state backend. + + tolerableCheckpointFailureNumber + Sets the expected behaviour for tasks in case that they encounter an error in their checkpointing procedure. To tolerate a specific number of failures, set it to a positive number. 
+ Default: 0 + unalignedCheckpointEnabled If set, Unaligned checkpoints contain in-flight data (i.e., data stored in buffers) as part of the checkpoint state, allowing checkpoint barriers to overtake these buffers. Thus, the checkpoint duration becomes independent of the current throughput as checkpoint barriers are effectively not embedded into the stream of data anymore diff --git a/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html b/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html index 183dacfd5a09..e3fe24216a54 100644 --- a/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html +++ b/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html @@ -64,7 +64,7 @@ execution_mode_for_batch - Flink mode for data exchange of batch pipelines. Reference {@link org.apache.flink.api.common.ExecutionMode}. Set this to BATCH_FORCED if pipelines get blocked, see https://issues.apache.org/jira/browse/FLINK-10672 + Flink mode for data exchange of batch pipelines. Reference {@link org.apache.flink.api.common.ExecutionMode}. Set this to BATCH_FORCED if pipelines get blocked, see https://issues.apache.org/jira/browse/FLINK-10672. Default: PIPELINED @@ -77,11 +77,6 @@ Enables or disables externalized checkpoints. Works in conjunction with CheckpointingInterval Default: false - - fail_on_checkpointing_errors - Sets the expected behaviour for tasks in case that they encounter an error in their checkpointing procedure. If this is set to true, the task will fail on checkpointing error. If this is set to false, the task will only decline the checkpoint and continue running. - Default: true - faster_copy Remove unneeded deep copy between operators. See https://issues.apache.org/jira/browse/BEAM-11146 @@ -172,11 +167,6 @@ The degree of parallelism to be used when distributing operations onto workers. If the parallelism is not set, the configured Flink default is used, or 1 if none can be found. Default: -1 - - re_iterable_group_by_key_result - Flag indicating whether result of GBK needs to be re-iterable. Re-iterable result implies that all values for a single key must fit in memory as we currently do not support spilling to disk. - Default: false - report_checkpoint_duration If not null, reports the checkpoint duration of each ParDo stage in the provided metric namespace. @@ -199,7 +189,7 @@ state_backend - State backend to store Beam's state. Use 'rocksdb' or 'filesystem'. + State backend to store Beam's state. Use 'rocksdb' or 'hashmap' (same as 'filesystem'). @@ -212,6 +202,11 @@ State backend path to persist state backend data. Used to initialize state backend. + + tolerable_checkpoint_failure_number + Sets the expected behaviour for tasks in case that they encounter an error in their checkpointing procedure. To tolerate a specific number of failures, set it to a positive number. + Default: 0 + unaligned_checkpoint_enabled If set, Unaligned checkpoints contain in-flight data (i.e., data stored in buffers) as part of the checkpoint state, allowing checkpoint barriers to overtake these buffers. Thus, the checkpoint duration becomes independent of the current throughput as checkpoint barriers are effectively not embedded into the stream of data anymore