Skip to content

Commit 9987114

Browse files
committed
Merge branch 'apache:master' into users/elia/stateless-smart-bucketing
2 parents fc66805 + b7577e7 commit 9987114

14 files changed

Lines changed: 471 additions & 223 deletions

.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@ env:
5050

5151
jobs:
5252
beam_PostCommit_Java_ValidatesRunner_Flink:
53-
name: ${{ matrix.job_name }} (${{ matrix.job_phrase }})
53+
name: ${{ matrix.job_name }} (${{ matrix.flink_version }})
5454
runs-on: [self-hosted, ubuntu-20.04, main]
5555
timeout-minutes: 100
5656
strategy:
5757
matrix:
5858
job_name: [beam_PostCommit_Java_ValidatesRunner_Flink]
5959
job_phrase: [Run Flink ValidatesRunner]
60+
# every major version
61+
flink_version: ['1.20', '2.0']
6062
if: |
6163
github.event_name == 'workflow_dispatch' ||
6264
github.event_name == 'pull_request_target' ||
@@ -69,7 +71,7 @@ jobs:
6971
with:
7072
comment_phrase: ${{ matrix.job_phrase }}
7173
github_token: ${{ secrets.GITHUB_TOKEN }}
72-
github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }})
74+
github_job: ${{ matrix.job_name }} (${{ matrix.flink_version }})
7375
- name: Setup environment
7476
uses: ./.github/actions/setup-environment-action
7577
with:
@@ -78,7 +80,7 @@ jobs:
7880
- name: run validatesRunner script
7981
uses: ./.github/actions/gradle-command-self-hosted-action
8082
with:
81-
gradle-command: :runners:flink:1.20:validatesRunner
83+
gradle-command: :runners:flink:${{ matrix.flink_version }}:validatesRunner
8284
- name: Archive JUnit Test Results
8385
uses: actions/upload-artifact@v4
8486
if: ${{ !success() }}

.github/workflows/run_rc_validation_go_wordcount.yml

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ name: Validate Go SDK Release Candidate
33
on:
44
workflow_dispatch:
55
inputs:
6-
rc_tag:
7-
description: 'Beam RC Tag (e.g., v2.59.0-RC1)'
6+
RELEASE_VER:
7+
description: 'Beam Release Version (e.g., 2.69.0)'
88
required: true
9-
type: string
10-
container_tag:
11-
description: 'Beam Go SDK Container Tag (e.g., 2.59.0rc1)'
9+
default: '2.69.0'
10+
RC_NUM:
11+
description: 'Release Candidate number (e.g., 1)'
1212
required: true
13-
type: string
13+
default: '1'
1414

1515
# This allows a subsequently queued workflow run to interrupt previous runs
1616
concurrency:
17-
group: '${{ github.workflow }} @ ${{ github.event.inputs.rc_tag }}' # Group by RC tag
17+
group: '${{ github.workflow }}'
1818
cancel-in-progress: true
1919

2020
# Setting explicit permissions (copied from Java Mobile Gaming workflow)
@@ -40,6 +40,8 @@ env:
4040
GCS_TEMP_LOCATION: gs://rc-validation-migration-tests/temp/
4141
GCS_STAGING_LOCATION: gs://rc-validation-migration-tests/staging/
4242
GCS_INPUT_PATH: gs://apache-beam-samples/shakespeare/kinglear.txt
43+
CONTAINER_TAG: "${{github.event.inputs.RELEASE_VER}}rc${{github.event.inputs.RC_NUM}}"
44+
RC_TAG: "v${{github.event.inputs.RELEASE_VER}}-RC${{github.event.inputs.RC_NUM}}"
4345

4446
jobs:
4547
validate-rc-package:
@@ -61,7 +63,7 @@ jobs:
6163
wget -O $TEMP_DIR/wordcount.go https://raw.githubusercontent.com/apache/beam/refs/heads/master/sdks/go/examples/wordcount/wordcount.go
6264
cd $TEMP_DIR
6365
go mod init rc-test
64-
go get github.com/apache/beam/sdks/v2/go/pkg/beam@${{ github.event.inputs.rc_tag }}
66+
go get github.com/apache/beam/sdks/v2/go/pkg/beam@${{ env.RC_TAG }}
6567
go mod tidy
6668
echo "work_dir=$TEMP_DIR" >> $GITHUB_OUTPUT # Output relative path
6769
@@ -97,7 +99,7 @@ jobs:
9799
working-directory: ./${{ steps.setup_go.outputs.work_dir }}
98100
env:
99101
# Define output path based on constant prefix and RC tag for uniqueness
100-
GCS_OUTPUT_PATH: ${{ env.GCS_OUTPUT_PREFIX }}/${{ github.event.inputs.rc_tag }}/dataflow/output
102+
GCS_OUTPUT_PATH: ${{ env.GCS_OUTPUT_PREFIX }}/${{ env.RC_TAG }}/dataflow/output
101103
run: |
102104
echo "Using output path: $GCS_OUTPUT_PATH"
103105
go run wordcount.go \
@@ -109,13 +111,13 @@ jobs:
109111
--temp_location=${{ env.GCS_TEMP_LOCATION }} \
110112
--staging_location=${{ env.GCS_STAGING_LOCATION }} \
111113
--environment_type=DOCKER \
112-
--environment_config=apache/beam_go_sdk:${{ github.event.inputs.container_tag }}
114+
--environment_config=apache/beam_go_sdk:${{ env.CONTAINER_TAG }}
113115
114116
- name: Check Dataflow Output in GCS
115117
working-directory: ./${{ steps.setup_go.outputs.work_dir }} # Added working directory for consistency, though not strictly needed for gsutil
116118
env:
117119
# Re-define the output path pattern for checking
118-
GCS_OUTPUT_PATH_PATTERN: ${{ env.GCS_OUTPUT_PREFIX }}/${{ github.event.inputs.rc_tag }}/dataflow/output*
120+
GCS_OUTPUT_PATH_PATTERN: ${{ env.GCS_OUTPUT_PREFIX }}/${{ env.RC_TAG }}/dataflow/output*
119121
run: |
120122
echo "Checking for Dataflow output files in GCS at: $GCS_OUTPUT_PATH_PATTERN"
121123
# Use gsutil stat. The -q flag suppresses errors for non-existent files,

CHANGES.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@
6161

6262
* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
6363
* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).
64+
* Flink 2.0 support for Java classic Flink runner ([#36947](https://github.com/apache/beam/issues/36947)).
65+
Also added initial, experimental support for the Portable Flink runner starting with this Beam version.
66+
6467

6568
## I/Os
6669

@@ -70,7 +73,6 @@
7073
## New Features / Improvements
7174

7275
* (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)).
73-
* (Python) Added `take(n)` convenience for PCollection: `beam.take(n)` and `pcoll.take(n)` to get the first N elements deterministically without Top.Of + FlatMap ([#X](https://github.com/apache/beam/issues/37429)).
7476
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
7577

7678
## Breaking Changes

sdks/go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ require (
5454
github.com/tetratelabs/wazero v1.11.0
5555
github.com/xitongsys/parquet-go v1.6.2
5656
github.com/xitongsys/parquet-go-source v0.0.0-20241021075129-b732d2ac9c9b
57-
go.mongodb.org/mongo-driver v1.17.7
57+
go.mongodb.org/mongo-driver v1.17.8
5858
golang.org/x/net v0.49.0
59-
golang.org/x/oauth2 v0.34.0
59+
golang.org/x/oauth2 v0.35.0
6060
golang.org/x/sync v0.19.0
6161
golang.org/x/sys v0.40.0
6262
golang.org/x/text v0.33.0

sdks/go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,8 +1476,8 @@ github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxt
14761476
go.einride.tech/aip v0.73.0 h1:bPo4oqBo2ZQeBKo4ZzLb1kxYXTY1ysJhpvQyfuGzvps=
14771477
go.einride.tech/aip v0.73.0/go.mod h1:Mj7rFbmXEgw0dq1dqJ7JGMvYCZZVxmGOR3S4ZcV5LvQ=
14781478
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
1479-
go.mongodb.org/mongo-driver v1.17.7 h1:a9w+U3Vt67eYzcfq3k/OAv284/uUUkL0uP75VE5rCOU=
1480-
go.mongodb.org/mongo-driver v1.17.7/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
1479+
go.mongodb.org/mongo-driver v1.17.8 h1:BDP3+U3Y8K0vTrpqDJIRaXNhb/bKyoVeg6tIJsW5EhM=
1480+
go.mongodb.org/mongo-driver v1.17.8/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ=
14811481
go.opencensus.io v0.15.0/go.mod h1:UffZAU+4sDEINUGP/B7UfBBkq4fqLu9zXAX7ke6CHW0=
14821482
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
14831483
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
@@ -1727,8 +1727,8 @@ golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec
17271727
golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I=
17281728
golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw=
17291729
golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
1730-
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
1731-
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
1730+
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
1731+
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
17321732
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
17331733
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
17341734
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 90 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,11 @@
6868
try:
6969
# pylint: disable=wrong-import-order, wrong-import-position
7070
import resource
71+
72+
from apache_beam.ml.inference.model_manager import ModelManager
7173
except ImportError:
7274
resource = None # type: ignore[assignment]
75+
ModelManager = None # type: ignore[assignment]
7376

7477
_NANOSECOND_TO_MILLISECOND = 1_000_000
7578
_NANOSECOND_TO_MICROSECOND = 1_000
@@ -533,11 +536,12 @@ def request(
533536
raise NotImplementedError(type(self))
534537

535538

536-
class _ModelManager:
539+
class _ModelHandlerManager:
537540
"""
538-
A class for efficiently managing copies of multiple models. Will load a
539-
single copy of each model into a multi_process_shared object and then
540-
return a lookup key for that object.
541+
A class for efficiently managing copies of multiple model handlers.
542+
Will load a single copy of each model from the model handler into a
543+
multi_process_shared object and then return a lookup key for that
544+
object. Used for KeyedModelHandler only.
541545
"""
542546
def __init__(self, mh_map: dict[str, ModelHandler]):
543547
"""
@@ -602,8 +606,9 @@ def load(self, key: str) -> _ModelLoadStats:
602606

603607
def increment_max_models(self, increment: int):
604608
"""
605-
Increments the number of models that this instance of a _ModelManager is
606-
able to hold. If it is never called, no limit is imposed.
609+
Increments the number of models that this instance of a
610+
_ModelHandlerManager is able to hold. If it is never called,
611+
no limit is imposed.
607612
Args:
608613
increment: the amount by which we are incrementing the number of models.
609614
"""
@@ -656,7 +661,7 @@ def __init__(
656661
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
657662
ModelHandler[tuple[KeyT, ExampleT],
658663
tuple[KeyT, PredictionT],
659-
Union[ModelT, _ModelManager]]):
664+
Union[ModelT, _ModelHandlerManager]]):
660665
def __init__(
661666
self,
662667
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
@@ -809,15 +814,15 @@ def __init__(
809814
'to exactly one model handler.')
810815
self._key_to_id_map[key] = keys[0]
811816

812-
def load_model(self) -> Union[ModelT, _ModelManager]:
817+
def load_model(self) -> Union[ModelT, _ModelHandlerManager]:
813818
if self._single_model:
814819
return self._unkeyed.load_model()
815-
return _ModelManager(self._id_to_mh_map)
820+
return _ModelHandlerManager(self._id_to_mh_map)
816821

817822
def run_inference(
818823
self,
819824
batch: Sequence[tuple[KeyT, ExampleT]],
820-
model: Union[ModelT, _ModelManager],
825+
model: Union[ModelT, _ModelHandlerManager],
821826
inference_args: Optional[dict[str, Any]] = None
822827
) -> Iterable[tuple[KeyT, PredictionT]]:
823828
if self._single_model:
@@ -919,7 +924,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
919924

920925
def update_model_paths(
921926
self,
922-
model: Union[ModelT, _ModelManager],
927+
model: Union[ModelT, _ModelHandlerManager],
923928
model_paths: list[KeyModelPathMapping[KeyT]] = None):
924929
# When there are many models, the keyed model handler is responsible for
925930
# reorganizing the model handlers into cohorts and telling the model
@@ -1338,6 +1343,8 @@ def __init__(
13381343
model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
13391344
watch_model_pattern: Optional[str] = None,
13401345
model_identifier: Optional[str] = None,
1346+
use_model_manager: bool = False,
1347+
model_manager_args: Optional[dict[str, Any]] = None,
13411348
**kwargs):
13421349
"""
13431350
A transform that takes a PCollection of examples (or features) for use
@@ -1378,6 +1385,8 @@ def __init__(
13781385
self._exception_handling_timeout = None
13791386
self._timeout = None
13801387
self._watch_model_pattern = watch_model_pattern
1388+
self._use_model_manager = use_model_manager
1389+
self._model_manager_args = model_manager_args
13811390
self._kwargs = kwargs
13821391
# Generate a random tag to use for shared.py and multi_process_shared.py to
13831392
# allow us to effectively disambiguate in multi-model settings. Only use
@@ -1490,7 +1499,9 @@ def expand(
14901499
self._clock,
14911500
self._metrics_namespace,
14921501
load_model_at_runtime,
1493-
self._model_tag),
1502+
self._model_tag,
1503+
self._use_model_manager,
1504+
self._model_manager_args),
14941505
self._inference_args,
14951506
beam.pvalue.AsSingleton(
14961507
self._model_metadata_pcoll,
@@ -1803,31 +1814,75 @@ def load_model_status(
18031814
return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)
18041815

18051816

1817+
class _ProxyLoader:
1818+
"""
1819+
A helper callable to wrap the loader for MultiProcessShared.
1820+
"""
1821+
def __init__(self, loader_func, model_tag):
1822+
self.loader_func = loader_func
1823+
self.model_tag = model_tag
1824+
1825+
def __call__(self):
1826+
# Generate a unique tag for the model being loaded so that
1827+
# we will have unique instances of the model in multi_process_shared
1828+
# space instead of reusing the same instance over. The instance will
1829+
# be initialized and left running as a separate process, which then
1830+
# can be grabbed again using the unique tag if needed during inference.
1831+
unique_tag = self.model_tag + '_' + uuid.uuid4().hex
1832+
# Ensure that each model loaded in a different process for parallelism
1833+
multi_process_shared.MultiProcessShared(
1834+
self.loader_func, tag=unique_tag, always_proxy=True,
1835+
spawn_process=True).acquire()
1836+
# Only return the tag to avoid pickling issues with the model itself.
1837+
return unique_tag
1838+
1839+
18061840
class _SharedModelWrapper():
18071841
"""A router class to map incoming calls to the correct model.
18081842
18091843
This allows us to round robin calls to models sitting in different
18101844
processes so that we can more efficiently use resources (e.g. GPUs).
18111845
"""
1812-
def __init__(self, models: list[Any], model_tag: str):
1846+
def __init__(
1847+
self,
1848+
models: Union[list[Any], ModelManager],
1849+
model_tag: str,
1850+
loader_func: Optional[Callable[[], Any]] = None):
18131851
self.models = models
1814-
if len(models) > 1:
1852+
self.use_model_manager = not isinstance(models, list)
1853+
self.model_tag = model_tag
1854+
self.loader_func = loader_func
1855+
if not self.use_model_manager and len(models) > 1:
18151856
self.model_router = multi_process_shared.MultiProcessShared(
18161857
lambda: _ModelRoutingStrategy(),
18171858
tag=f'{model_tag}_counter',
18181859
always_proxy=True).acquire()
18191860

18201861
def next_model(self):
1862+
if self.use_model_manager:
1863+
loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag)
1864+
return self.models.acquire_model(self.model_tag, loader_wrapper)
1865+
18211866
if len(self.models) == 1:
18221867
# Short circuit if there's no routing strategy needed in order to
18231868
# avoid the cross-process call
18241869
return self.models[0]
18251870

18261871
return self.models[self.model_router.next_model_index(len(self.models))]
18271872

1873+
def release_model(self, model_tag: str, model: Any):
1874+
if self.use_model_manager:
1875+
self.models.release_model(model_tag, model)
1876+
18281877
def all_models(self):
1878+
if self.use_model_manager:
1879+
return self.models.all_models()[self.model_tag]
18291880
return self.models
18301881

1882+
def force_reset(self):
1883+
if self.use_model_manager:
1884+
self.models.force_reset()
1885+
18311886

18321887
class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
18331888
def __init__(
@@ -1836,7 +1891,9 @@ def __init__(
18361891
clock,
18371892
metrics_namespace,
18381893
load_model_at_runtime: bool = False,
1839-
model_tag: str = "RunInference"):
1894+
model_tag: str = "RunInference",
1895+
use_model_manager: bool = False,
1896+
model_manager_args: Optional[dict[str, Any]] = None):
18401897
"""A DoFn implementation generic to frameworks.
18411898
18421899
Args:
@@ -1860,6 +1917,8 @@ def __init__(
18601917
# _cur_tag is the tag of the actually loaded model
18611918
self._model_tag = model_tag
18621919
self._cur_tag = model_tag
1920+
self.use_model_manager = use_model_manager
1921+
self._model_manager_args = model_manager_args or {}
18631922

18641923
def _load_model(
18651924
self,
@@ -1894,7 +1953,15 @@ def load():
18941953
model_tag = side_input_model_path
18951954
# Ensure the tag we're loading is valid, if not replace it with a valid tag
18961955
self._cur_tag = self._model_metadata.get_valid_tag(model_tag)
1897-
if self._model_handler.share_model_across_processes():
1956+
if self.use_model_manager:
1957+
logging.info("Using Model Manager to manage models automatically.")
1958+
model_manager = multi_process_shared.MultiProcessShared(
1959+
lambda: ModelManager(**self._model_manager_args),
1960+
tag='model_manager',
1961+
always_proxy=True).acquire()
1962+
model_wrapper = _SharedModelWrapper(
1963+
model_manager, self._cur_tag, self._model_handler.load_model)
1964+
elif self._model_handler.share_model_across_processes():
18981965
models = []
18991966
for copy_tag in _get_tags_for_copies(self._cur_tag,
19001967
self._model_handler.model_copies()):
@@ -1949,8 +2016,15 @@ def _run_inference(self, batch, inference_args):
19492016
start_time = _to_microseconds(self._clock.time_ns())
19502017
try:
19512018
model = self._model.next_model()
2019+
if isinstance(model, str):
2020+
# ModelManager with MultiProcessShared returns the model tag
2021+
unique_tag = model
2022+
model = multi_process_shared.MultiProcessShared(
2023+
lambda: None, tag=model, always_proxy=True).acquire()
19522024
result_generator = self._model_handler.run_inference(
19532025
batch, model, inference_args)
2026+
if self.use_model_manager:
2027+
self._model.release_model(self._model_tag, unique_tag)
19542028
except BaseException as e:
19552029
if self._metrics_collector:
19562030
self._metrics_collector.failed_batches_counter.inc()

0 commit comments

Comments
 (0)