From 541fd95a55890ad51b783d3ffee54c811d5e1e6f Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Fri, 6 Feb 2026 14:15:35 +1100 Subject: [PATCH 1/2] feat(ml): add stateless bundle-local size-aware batching and benchmark --- .../benchmarks/sort_and_batch_benchmark.py | 650 ++++++++++++++++++ sdks/python/apache_beam/transforms/util.py | 285 ++++++++ .../apache_beam/transforms/util_test.py | 217 ++++++ 3 files changed, 1152 insertions(+) create mode 100644 sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py diff --git a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py new file mode 100644 index 000000000000..e1caa73e0a5d --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py @@ -0,0 +1,650 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Benchmark: BatchElements vs SortAndBatchElements (weight-based splitting). + +Compares two batching strategies for variable-length inference workloads: + +- Baseline (BatchElements): fixed-count chunking, ignores element sizes. +- Stateless (SortAndBatchElements): within each bundle, sorts elements + by size, then splits batches using max_batch_weight so that each batch + has a bounded total weight. The improvement comes from *changing batch + boundaries* (weight-based splitting), NOT from sorting alone -- sorting + within fixed boundaries yields 0% gain (verified by strict-control). + +Padding ratio:: + + padding_ratio = sum(max_len_in_batch * batch_size) / sum(actual_lengths) + Lower is better. 1.0 = no padding waste. + +Methodology: + +- N=20 independent trials per condition (3 warmup trials excluded). +- Same input corpus (seed=42) for A/B comparison. +- Percentile method: linear interpolation between adjacent ranks + (equivalent to numpy.percentile with method='linear'). + For N=20 trials: P50 interpolates ranks 10-11 (0-indexed 9-10), + P95 interpolates ranks 19-20 (0-indexed 18-19), + P99 interpolates near rank 20 (0-indexed 18.81). +- Reports median [IQR] and P95 for each metric. +- Inference model: latency = batch_size * (max_seq_len / 50)^1.5 ms + (simulates transformer-like scaling). 
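
+Worked example of padding_ratio (illustrative numbers, not benchmark
+output)::
+
+    One batch of lengths [2, 3, 10] pads to max(10) * 3 = 30 slots for
+    15 actual tokens -> padding_ratio = 2.0. Re-splitting by weight into
+    [2, 3] and [10] pads to 3 * 2 + 10 * 1 = 16 -> padding_ratio ~= 1.07,
+    which is why moving batch boundaries (not sorting alone) helps.
+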
+ +Run:: + + python3 -m apache_beam.testing.benchmarks.sort_and_batch_benchmark +""" + +import math +import random +import statistics +import time +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +# --------------------------------------------------------------------------- +# Data generators +# --------------------------------------------------------------------------- + + +def generate_highly_skewed_data( + num_elements: int, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Pareto(alpha=1.2) -- most short, few very long.""" + random.seed(seed) + data = [] + for _ in range(num_elements): + length = int(random.paretovariate(1.2) * min_length) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_lognormal_data( + num_elements: int, + mean_length: int = 50, + std_factor: float = 0.8, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Log-normal -- moderate skew, typical NLP.""" + random.seed(seed) + mu = math.log(mean_length) + sigma = std_factor + data = [] + for _ in range(num_elements): + length = int(random.lognormvariate(mu, sigma)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_bimodal_data( + num_elements: int, + mode1_mean: int = 20, + mode2_mean: int = 200, + mode1_ratio: float = 0.7, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Bimodal -- two distinct length groups.""" + random.seed(seed) + data = [] + for _ in range(num_elements): + if random.random() < mode1_ratio: + length = int(random.gauss(mode1_mean, mode1_mean * 0.3)) + else: + length = int(random.gauss(mode2_mean, mode2_mean * 0.3)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_low_variance_data( + num_elements: int, + mean_length: int = 100, + cv: float = 0.1, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Low-variance control (CV=10%).""" + random.seed(seed) + std = mean_length * cv + data = [] + for _ in range(num_elements): + length = int(random.gauss(mean_length, std)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +# --------------------------------------------------------------------------- +# Batching algorithms +# --------------------------------------------------------------------------- + + +def simulate_batch_elements(data: List[str], + max_batch_size: int) -> List[List[str]]: + """Baseline: simple count-based chunking (BatchElements behaviour).""" + batches = [] + current_batch = [] + for element in data: + current_batch.append(element) + if len(current_batch) >= max_batch_size: + batches.append(current_batch) + current_batch = [] + if current_batch: + batches.append(current_batch) + return batches + + +def simulate_sort_and_batch_elements( + data: List[str], + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]] = None, + bundle_size: Optional[int] = None) -> List[List[str]]: + """Core mechanism: sort by size + weight-based batch splitting.""" + if element_size_fn is None: + element_size_fn = len + + # Split into bundles if specified (realistic Beam behavior) + if bundle_size is not None and bundle_size > 0: + bundles = [ + data[i:i + bundle_size] 
for i in range(0, len(data), bundle_size) + ] + else: + bundles = [data] + + all_batches = [] + + for bundle in bundles: + # Sort by element size (ascending) + sorted_bundle = sorted(bundle, key=element_size_fn) + + current_batch = [] + current_weight = 0 + + for element in sorted_bundle: + element_weight = element_size_fn(element) + + # Check if adding this element would exceed limits + would_exceed_count = len(current_batch) >= max_batch_size + would_exceed_weight = ( + current_weight + element_weight > max_batch_weight and current_batch) + + if would_exceed_count or would_exceed_weight: + all_batches.append(current_batch) + current_batch = [] + current_weight = 0 + + current_batch.append(element) + current_weight += element_weight + + if current_batch: + all_batches.append(current_batch) + + return all_batches + + +# --------------------------------------------------------------------------- +# Simulated inference +# --------------------------------------------------------------------------- + + +def simulate_inference_latency( + batch: List[str], base_latency_ms: float = 1.0) -> float: + """Simulate transformer inference: O(batch_size * seq_len^1.5).""" + if not batch: + return 0.0 + batch_size = len(batch) + max_len = max(len(s) for s in batch) + return base_latency_ms * batch_size * (max_len / 50)**1.5 + + +# --------------------------------------------------------------------------- +# Stats helpers +# --------------------------------------------------------------------------- + + +def percentile(data: Sequence[float], p: float) -> float: + """Percentile via linear interpolation between adjacent ranks. + + Equivalent to numpy.percentile(data, p, method='linear'). + For N=20: P50 interpolates ranks 10-11, P95 ranks 19-20, + P99 near rank 20 (fractional index 18.81). 
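
+  Hand-checked examples (agree with numpy.percentile)::
+
+      percentile([1, 2, 3, 4], 50) -> 2.5   (k = 1.5: 2 + 0.5 * (3 - 2))
+      percentile([1, 2, 3, 4], 95) -> 3.85  (k = 2.85: 3 + 0.85 * (4 - 3))
+  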
+ """ + if not data: + return 0.0 + s = sorted(data) + k = (len(s) - 1) * p / 100 + f = int(k) + c = min(f + 1, len(s) - 1) + return s[f] + (k - f) * (s[c] - s[f]) + + +def compute_padding_stats(batches: List[List[str]]) -> Dict[str, Any]: + """Padding-efficiency statistics for a list of batches.""" + total_actual = 0 + total_padded = 0 + batch_sizes = [] + max_lengths = [] + + for batch in batches: + if not batch: + continue + lengths = [len(s) for s in batch] + mx = max(lengths) + total_actual += sum(lengths) + total_padded += mx * len(batch) + batch_sizes.append(len(batch)) + max_lengths.append(mx) + + efficiency = total_actual / total_padded if total_padded else 0.0 + padding_ratio = total_padded / total_actual if total_actual else float('inf') + + return { + 'efficiency': efficiency, + 'padding_ratio': padding_ratio, + 'num_batches': len(batches), + 'avg_batch_size': statistics.mean(batch_sizes) if batch_sizes else 0, + 'total_actual_length': total_actual, + 'total_padded_length': total_padded, + 'padding_overhead': total_padded - total_actual, + 'batch_size_p50': percentile(batch_sizes, 50) if batch_sizes else 0, + 'batch_size_p95': percentile(batch_sizes, 95) if batch_sizes else 0, + 'batch_size_max': max(batch_sizes) if batch_sizes else 0, + 'max_len_p50': percentile(max_lengths, 50) if max_lengths else 0, + 'max_len_p95': percentile(max_lengths, 95) if max_lengths else 0, + } + + +# --------------------------------------------------------------------------- +# Invariant validation +# --------------------------------------------------------------------------- + + +def validate_invariants( + data: List[str], + baseline_batches: List[List[str]], + stateless_batches: List[List[str]], + config: Dict[str, Any]) -> Dict[str, Any]: + """Validate element/token counts and batch-size equality.""" + n = len(data) + b_n = sum(len(b) for b in baseline_batches) + s_n = sum(len(b) for b in stateless_batches) + tok = sum(len(s) for s in data) + b_tok = sum(sum(len(s) for s in b) for b in baseline_batches) + s_tok = sum(sum(len(s) for s in b) for b in stateless_batches) + + return { + 'input_elements': n, + 'baseline_elements': b_n, + 'stateless_elements': s_n, + 'elements_match': n == b_n == s_n, + 'input_tokens': tok, + 'baseline_tokens': b_tok, + 'stateless_tokens': s_tok, + 'tokens_match': tok == b_tok == s_tok, + 'baseline_num_batches': len(baseline_batches), + 'stateless_num_batches': len(stateless_batches), + } + + +# --------------------------------------------------------------------------- +# Performance benchmark (N=20 trials) +# --------------------------------------------------------------------------- + + +def run_performance_benchmark( + data: List[str], + max_batch_size: int, + max_batch_weight: int, + bundle_size: int = 500, + num_trials: int = 20, + warmup_trials: int = 3) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Run N=20 trials for baseline and stateless.""" + total_tokens = sum(len(s) for s in data) + + baseline_trials = [] + stateless_trials = [] + + for trial_idx in range(warmup_trials + num_trials): + is_warmup = trial_idx < warmup_trials + + # --- Baseline --- + start = time.perf_counter() + b_batches = simulate_batch_elements(data, max_batch_size) + batch_ms = (time.perf_counter() - start) * 1000 + b_inf = [simulate_inference_latency(b) for b in b_batches] + b_e2e = batch_ms + sum(b_inf) + if not is_warmup: + baseline_trials.append({ + 'overhead_ms': batch_ms, + 'inference_ms': sum(b_inf), + 'e2e_ms': b_e2e, + 'batch_latencies': b_inf, + 'num_batches': len(b_batches), 
+ }) + + # --- Stateless (SortAndBatchElements) --- + start = time.perf_counter() + s_batches = simulate_sort_and_batch_elements( + data, max_batch_size, max_batch_weight, bundle_size=bundle_size) + sort_ms = (time.perf_counter() - start) * 1000 + s_inf = [simulate_inference_latency(b) for b in s_batches] + s_e2e = sort_ms + sum(s_inf) + if not is_warmup: + stateless_trials.append({ + 'overhead_ms': sort_ms, + 'inference_ms': sum(s_inf), + 'e2e_ms': s_e2e, + 'batch_latencies': s_inf, + 'num_batches': len(s_batches), + }) + + def _stats(trials): + e2e = [t['e2e_ms'] for t in trials] + tput = [total_tokens / (t['e2e_ms'] / 1000) for t in trials] + overhead = [t['overhead_ms'] for t in trials] + all_lat = [l for t in trials for l in t['batch_latencies']] + return { + 'e2e_median': percentile(e2e, 50), + 'e2e_p25': percentile(e2e, 25), + 'e2e_p75': percentile(e2e, 75), + 'e2e_p95': percentile(e2e, 95), + 'tput_median': percentile(tput, 50), + 'tput_p25': percentile(tput, 25), + 'tput_p75': percentile(tput, 75), + 'tput_p95': percentile(tput, 95), + 'overhead_median': percentile(overhead, 50), + 'overhead_p25': percentile(overhead, 25), + 'overhead_p75': percentile(overhead, 75), + 'overhead_p95': percentile(overhead, 95), + 'batch_lat_p50': percentile(all_lat, 50), + 'batch_lat_p95': percentile(all_lat, 95), + 'batch_lat_p99': percentile(all_lat, 99), + 'inf_p95': percentile(all_lat, 95), + 'num_trials': len(trials), + 'num_batches': trials[0]['num_batches'] if trials else 0, + } + + return _stats(baseline_trials), _stats(stateless_trials) + + +# --------------------------------------------------------------------------- +# Single benchmark run +# --------------------------------------------------------------------------- + + +def run_benchmark( + num_elements: int = 10000, + min_length: int = 1, + max_length: int = 500, + max_batch_size: int = 32, + max_batch_weight: int = 2000, + bundle_size: int = 500, + distribution: str = 'pareto', + seed: int = 42) -> Dict[str, Any]: + """Run baseline vs stateless comparison.""" + generators = { + 'pareto': lambda: generate_highly_skewed_data( + num_elements, min_length, max_length, seed), + 'lognormal': lambda: generate_lognormal_data( + num_elements, 50, 0.8, min_length, max_length, seed), + 'bimodal': lambda: generate_bimodal_data( + num_elements, 20, 200, 0.7, min_length, max_length, seed), + 'low_variance': lambda: generate_low_variance_data( + num_elements, 100, 0.1, min_length, max_length, seed), + } + if distribution not in generators: + raise ValueError(f"Unknown distribution: {distribution}") + + data = generators[distribution]() + lengths = [len(s) for s in data] + + baseline_batches = simulate_batch_elements(data, max_batch_size) + stateless_batches = simulate_sort_and_batch_elements( + data, max_batch_size, max_batch_weight, bundle_size=bundle_size) + + baseline_pad = compute_padding_stats(baseline_batches) + stateless_pad = compute_padding_stats(stateless_batches) + + baseline_perf, stateless_perf = run_performance_benchmark( + data, max_batch_size, max_batch_weight, bundle_size) + baseline_pad.update(baseline_perf) + stateless_pad.update(stateless_perf) + + validation = validate_invariants( + data, + baseline_batches, + stateless_batches, { + 'max_batch_size': max_batch_size, + 'max_batch_weight': max_batch_weight + }) + + return { + 'config': { + 'num_elements': num_elements, + 'max_batch_size': max_batch_size, + 'max_batch_weight': max_batch_weight, + 'bundle_size': bundle_size, + 'distribution': distribution, + }, + 'data_stats': { + 
'min': min(lengths), + 'max': max(lengths), + 'mean': statistics.mean(lengths), + 'median': statistics.median(lengths), + 'std': statistics.stdev(lengths), + }, + 'baseline': baseline_pad, + 'stateless': stateless_pad, + 'validation': validation, + } + + +# --------------------------------------------------------------------------- +# Printing +# --------------------------------------------------------------------------- + + +def _fmt_iqr(median, p25, p75, unit=''): + return f"{median:.1f} [{p25:.1f}-{p75:.1f}]{unit}" + + +def print_results(results: Dict[str, Any]) -> None: + cfg = results['config'] + ds = results['data_stats'] + bl = results['baseline'] + st = results['stateless'] + val = results['validation'] + + print("=" * 80) + print( + f"Distribution: {cfg['distribution']} | " + f"N={cfg['num_elements']} | " + f"max_batch_size={cfg['max_batch_size']} | " + f"max_batch_weight={cfg['max_batch_weight']}") + print( + f"Input lengths: min={ds['min']} max={ds['max']} " + f"mean={ds['mean']:.1f} median={ds['median']:.0f} std={ds['std']:.1f}") + print("-" * 80) + + def _arm(label, s): + print(f"\n {label}:") + print(f" Num batches: {s['num_batches']}") + print(f" Padding ratio: {s['padding_ratio']:.2f}x") + print(" ") + print(" Throughput (Ktok/s):") + med = s['tput_median'] / 1000 + p25 = s['tput_p25'] / 1000 + p75 = s['tput_p75'] / 1000 + print(f" Median [IQR]: {med:.1f}" + f" [{p25:.1f}-{p75:.1f}]") + print(f" P95: {s['tput_p95']/1000:.1f}") + print(" ") + print(" E2E latency (ms):") + print( + f" Median [IQR]: {s['e2e_median']:.1f}" + f" [{s['e2e_p25']:.1f}-{s['e2e_p75']:.1f}]") + print(f" P95: {s['e2e_p95']:.1f}") + print(" ") + print(" Overhead (ms):") + print( + f" Median [IQR]:" + f" {s['overhead_median']:.2f}" + f" [{s['overhead_p25']:.2f}" + f"-{s['overhead_p75']:.2f}]") + print(f" P95: {s['overhead_p95']:.2f}") + print(" ") + print(" Batch latency (ms):") + print(f" P50: {s['batch_lat_p50']:.1f}") + print(f" P95: {s['batch_lat_p95']:.1f}") + print(f" P99: {s['batch_lat_p99']:.1f}") + + _arm("Baseline (BatchElements)", bl) + _arm("Stateless (SortAndBatchElements w/ weight-based splitting)", st) + + # Delta — explicit arrows so direction is unambiguous + # ↓ = value decreased (good for latency/padding) + # ↑ = value increased (good for throughput) + def _delta_lower(base, new): + """For metrics where lower is better (latency, padding).""" + if base == 0: + return 'N/A' + pct = (base - new) / base * 100 + arrow = '\u2193' if pct > 0 else '\u2191' + return f"{arrow}{abs(pct):.1f}%" + + def _delta_higher(base, new): + """For metrics where higher is better (throughput).""" + if base == 0: + return 'N/A' + pct = (new - base) / base * 100 + arrow = '\u2191' if pct > 0 else '\u2193' + return f"{arrow}{abs(pct):.1f}%" + + print(f"\n {'_' * 76}") + print(" DELTA (Baseline -> Stateless):") + + def _line(label, bv, sv, delta_fn, fmt='.1f', unit=''): + d = delta_fn(bv, sv) + print(f" {label}: {bv:{fmt}}{unit}" + f" -> {sv:{fmt}}{unit} ({d})") + + bl_tmed = bl['tput_median'] / 1000 + st_tmed = st['tput_median'] / 1000 + bl_tp95 = bl['tput_p95'] / 1000 + st_tp95 = st['tput_p95'] / 1000 + + _line( + 'Padding ratio ', + bl['padding_ratio'], + st['padding_ratio'], + _delta_lower, + fmt='.2f', + unit='x') + _line('Throughput median', bl_tmed, st_tmed, _delta_higher, unit=' Ktok/s') + _line('Throughput p95 ', bl_tp95, st_tp95, _delta_higher, unit=' Ktok/s') + _line( + 'E2E latency med ', + bl['e2e_median'], + st['e2e_median'], + _delta_lower, + unit=' ms') + _line( + 'E2E latency p95 ', + 
bl['e2e_p95'], + st['e2e_p95'], + _delta_lower, + unit=' ms') + _line( + 'Batch lat p95 ', + bl['batch_lat_p95'], + st['batch_lat_p95'], + _delta_lower, + unit=' ms') + _line( + 'Batch lat p99 ', + bl['batch_lat_p99'], + st['batch_lat_p99'], + _delta_lower, + unit=' ms') + + # Invariants + e_ok = "Y" if val['elements_match'] else "X" + t_ok = "Y" if val['tokens_match'] else "X" + b_nb = val['baseline_num_batches'] + s_nb = val['stateless_num_batches'] + print( + f"\n Invariants: elements {e_ok} tokens {t_ok}" + f" (baseline {b_nb} -> stateless {s_nb}" + f" batches)") + print("=" * 80) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + print("=" * 80) + print("BASELINE (count-based) vs STATELESS (weight-based boundary splitting)") + print("=" * 80) + print() + print("Experiment design:") + print(" A = Baseline : BatchElements with max_batch_size=32 (count-based)") + print(" B = Stateless : SortAndBatchElements with max_batch_weight=2000") + print( + " (sort by size within bundle -> weight-based split)") + print() + print("Why Stateless wins:") + print(" Weight-based splitting changes batch BOUNDARIES so each batch has") + print( + " similar-length elements -> less padding. Sorting alone within fixed") + print(" boundaries yields 0% gain (verified by strict-control experiment).") + print() + print("Methodology:") + print(" - N=20 trials, 3 warmup excluded") + print(" - Percentiles: linear interpolation (= numpy default)") + print(" - Same seed=42 for both arms") + print(" - Inference model: latency = batch_size * (max_seq_len/50)^1.5 ms") + print() + + dist = 'pareto' + print(f"\nRunning: {dist}...") + r = run_benchmark( + num_elements=10000, + max_batch_size=32, + max_batch_weight=2000, + bundle_size=500, + distribution=dist, + seed=42) + print_results(r) + + +if __name__ == '__main__': + main() diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index dd14bd8f57bd..ab832dce0207 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -105,6 +105,7 @@ 'RemoveDuplicates', 'Reshuffle', 'Secret', + 'SortAndBatchElements', 'ToString', 'Take', 'Tee', @@ -1322,6 +1323,290 @@ def expand(self, pcoll): self._batch_size_estimator, self._element_size_fn)) +def _default_element_size_fn(element: Any) -> int: + """Default element size function that tries len(), falls back to 1. + + This function attempts to compute the size of an element using len(). + If the element does not support len() (e.g., integers), it falls back to 1. + + Args: + element: The element to compute the size of. + + Returns: + The size of the element, or 1 if len() is not supported. + """ + try: + return len(element) + except TypeError: + return 1 + + +class _SortAndBatchElementsDoFn(DoFn): + """DoFn that buffers, sorts by element size, and batches elements. + + This DoFn is used internally by ``SortAndBatchElements`` for + PCollections with the default (global) window. It accumulates all + elements in the current bundle, sorts them by size in ascending order, + and emits optimally-sized batches on ``finish_bundle``. + + Args: + min_batch_size: The minimum number of elements per batch. Must be >= 1. + max_batch_size: The maximum number of elements per batch. + Must be >= ``min_batch_size``. 
max_batch_weight: The maximum total weight of elements in a batch,
+      where weight is computed by ``element_size_fn``. Must be >= 1.
+    element_size_fn: A callable mapping an element to its integer
+      size/weight.
+  """
+  def __init__(
+      self,
+      min_batch_size: int,
+      max_batch_size: int,
+      max_batch_weight: int,
+      element_size_fn: Callable[[Any], int]):
+    if min_batch_size < 1:
+      raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}')
+    if max_batch_size < min_batch_size:
+      raise ValueError(
+          f'max_batch_size ({max_batch_size}) must be >= '
+          f'min_batch_size ({min_batch_size})')
+    if max_batch_weight < 1:
+      raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}')
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._max_batch_weight = max_batch_weight
+    self._element_size_fn = element_size_fn
+    self._buffer = []
+
+  def start_bundle(self):
+    self._buffer = []
+
+  def process(self, element):
+    self._buffer.append(element)
+
+  def finish_bundle(self):
+    if not self._buffer:
+      return
+
+    # Sort elements by size (ascending) so that similarly sized elements
+    # land in the same batch.
+    sorted_elements = sorted(self._buffer, key=self._element_size_fn)
+
+    batch = []
+    batch_weight = 0
+
+    for element in sorted_elements:
+      element_size = self._element_size_fn(element)
+
+      # Split before adding the element if it would exceed either limit.
+      # A batch may total exactly max_batch_weight; only exceeding it
+      # forces a new batch (consistent with the benchmark simulation).
+      would_exceed_count = len(batch) >= self._max_batch_size
+      would_exceed_weight = (
+          batch_weight + element_size > self._max_batch_weight and batch)
+
+      if would_exceed_count or would_exceed_weight:
+        # Emit current batch
+        yield window.GlobalWindows.windowed_value_at_end_of_window(batch)
+        batch = []
+        batch_weight = 0
+
+      batch.append(element)
+      batch_weight += element_size
+
+    # Emit remaining elements
+    if batch:
+      yield window.GlobalWindows.windowed_value_at_end_of_window(batch)
+
+    self._buffer = None
+
+
+class _WindowAwareSortAndBatchElementsDoFn(DoFn):
+  """DoFn that buffers, sorts by element size, and batches elements per window.
+
+  This DoFn is used internally by ``SortAndBatchElements`` for
+  PCollections with non-default (e.g. fixed, sliding, or session) windows.
+  Elements are buffered per window and each window is flushed independently.
+  To prevent unbounded memory growth, when the number of live windows
+  exceeds ``_MAX_LIVE_WINDOWS`` the largest window buffer is flushed early.
+
+  Args:
+    min_batch_size: The minimum number of elements per batch. Must be >= 1.
+    max_batch_size: The maximum number of elements per batch.
+      Must be >= ``min_batch_size``.
+    max_batch_weight: The maximum total weight of elements in a batch,
+      where weight is computed by ``element_size_fn``. Must be >= 1.
+    element_size_fn: A callable mapping an element to its integer
+      size/weight.
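+
+  Early-flush example (illustrative, assuming the default
+  ``_MAX_LIVE_WINDOWS`` of 10)::
+
+      A bundle that interleaves elements from 11 distinct windows
+      triggers an early flush of the fullest buffered window as soon
+      as the 11th appears; that window may then emit batches again at
+      finish_bundle if more of its elements arrive later in the bundle.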
+ """ + + _MAX_LIVE_WINDOWS = 10 + + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Callable[[Any], int]): + if min_batch_size < 1: + raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') + if max_batch_size < min_batch_size: + raise ValueError( + f'max_batch_size ({max_batch_size}) must be >= ' + f'min_batch_size ({min_batch_size})') + if max_batch_weight < 1: + raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + self._element_size_fn = element_size_fn + self._buffers = collections.defaultdict(list) + + def start_bundle(self): + self._buffers = collections.defaultdict(list) + + def process(self, element, window=DoFn.WindowParam): + self._buffers[window].append(element) + + # If we have too many live windows, flush the largest one + if len(self._buffers) > self._MAX_LIVE_WINDOWS: + largest_window = max( + self._buffers.keys(), key=lambda w: len(self._buffers[w])) + yield from self._flush_window(largest_window) + + def _flush_window(self, win): + """Flush all elements for a given window.""" + buffer = self._buffers.pop(win, []) + if not buffer: + return + + # Sort elements by size (ascending) + sorted_elements = sorted(buffer, key=self._element_size_fn) + + batch = [] + batch_weight = 0 + + for element in sorted_elements: + element_size = self._element_size_fn(element) + + would_exceed_count = len(batch) >= self._max_batch_size + would_exceed_weight = ( + batch_weight + element_size >= self._max_batch_weight and batch) + + if would_exceed_count or would_exceed_weight: + yield windowed_value.WindowedValue(batch, win.max_timestamp(), (win, )) + batch = [] + batch_weight = 0 + + batch.append(element) + batch_weight += element_size + + if batch: + yield windowed_value.WindowedValue(batch, win.max_timestamp(), (win, )) + + def finish_bundle(self): + for win in list(self._buffers.keys()): + yield from self._flush_window(win) + self._buffers = None + + +@typehints.with_input_types(T) +@typehints.with_output_types(list[T]) +class SortAndBatchElements(PTransform): + """A Transform that sorts elements by size before batching. + + This transform is designed to optimize batch processing by grouping elements + of similar sizes together. This is particularly useful for ML inference + workloads where input sequences of varying lengths need to be padded to the + maximum length in the batch - by sorting elements by size before batching, + padding overhead is minimized. + + The transform consumes a PCollection of element type T and produces a + PCollection of element type list[T], where elements within each batch are + sorted by their size (as determined by element_size_fn). + + Elements are batched per-window and batches emitted in the window + corresponding to its contents. Each batch is emitted with a timestamp at + the end of their window. + + Unlike BatchElements which emits batches as soon as size limits are reached, + SortAndBatchElements buffers all elements in a bundle, sorts them by size, + and then creates optimally-sized batches. This trade-off of increased memory + usage for better batch homogeneity can significantly reduce padding overhead. + + Args: + min_batch_size: The minimum number of elements in a batch. Must be >= 1. + max_batch_size: The maximum number of elements in a batch. + Must be >= min_batch_size. 
+ max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by element_size_fn. Must be >= 1. + element_size_fn: (optional) A function mapping an element to its + size/weight. + If not provided, defaults to trying len(element) and falling back to 1 + if the element doesn't support len(). This default allows sorting to + work for common types like strings, lists, and arrays. + + Example usage:: + + # Batch strings by total character count + strings = ['a', 'bb', 'ccc', 'dddd', 'eeeee'] + batched = strings | SortAndBatchElements( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=10) + # Possible output: [['a', 'bb', 'ccc'], ['dddd', 'eeeee']] + # Elements are sorted by length and batched optimally + + # Batch with custom size function + data = [{'text': 'short'}, {'text': 'medium text'}, + {'text': 'long text here'}] + batched = data | SortAndBatchElements( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=lambda x: len(x['text'])) + """ + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]] = None): + if min_batch_size < 1: + raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') + if max_batch_size < min_batch_size: + raise ValueError( + f'max_batch_size ({max_batch_size}) must be >= ' + f'min_batch_size ({min_batch_size})') + if max_batch_weight < 1: + raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') + if element_size_fn is not None and not callable(element_size_fn): + raise TypeError('element_size_fn must be callable') + + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + + # Smart default: try len(), fallback to 1 when len() is unsupported + self._element_size_fn: Callable[[Any], int] = ( + element_size_fn + if element_size_fn is not None else _default_element_size_fn) + + def expand(self, pcoll): + if pcoll.windowing.is_default(): + return pcoll | ParDo( + _SortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) + else: + return pcoll | ParDo( + _WindowAwareSortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) + + class _IdentityWindowFn(NonMergingWindowFn): """Windowing function that preserves existing windows. 
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 448ba8a7ad9d..b4471f0415ed 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1026,6 +1026,223 @@ def test_stateful_grows_to_max_batch(self): assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50])) +class SortAndBatchElementsTest(unittest.TestCase): + """Tests for SortAndBatchElements transform.""" + def test_elements_are_sorted_by_size(self): + """Test that elements are sorted by size within batches.""" + with TestPipeline() as p: + # Create elements with varying sizes + data = ['aaaaa', 'bb', 'cccc', 'a', 'ddd'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=5, max_batch_weight=100)) + + def check_sorted(batch): + lengths = [len(s) for s in batch] + assert lengths == sorted(lengths), ( + f'Batch not sorted by size: {lengths}') + return batch + + _ = res | beam.Map(check_sorted) + + def test_batch_respects_max_batch_size(self): + """Test that batches do not exceed max_batch_size.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['a'] * 10, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.Map(len)) + assert_that(res, equal_to([3, 3, 3, 1])) + + def test_batch_respects_max_batch_weight(self): + """Test that batches do not exceed max_batch_weight.""" + with TestPipeline() as p: + # Each element has size 5, max_batch_weight is 12 + # So we can fit at most 2 elements per batch + data = ['aaaaa', 'bbbbb', 'ccccc', 'ddddd'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=12) + | beam.Map(len)) + assert_that(res, equal_to([2, 2])) + + def test_default_element_size_fn_with_strings(self): + """Test default element_size_fn works with strings.""" + with TestPipeline() as p: + data = ['a', 'bbb', 'cc'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.FlatMap(lambda batch: [len(s) for s in batch])) + # Elements should be sorted by length: 'a'(1), 'cc'(2), 'bbb'(3) + assert_that(res, equal_to([1, 2, 3])) + + def test_default_element_size_fn_with_integers(self): + """Test default element_size_fn falls back to 1 for integers.""" + with TestPipeline() as p: + data = [10, 20, 30, 40, 50] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.Map(len)) + # With size=1 for all, should batch by max_batch_size + assert_that(res, equal_to([3, 2])) + + def test_custom_element_size_fn(self): + """Test using a custom element_size_fn.""" + with TestPipeline() as p: + data = [{'text': 'a'}, {'text': 'bbb'}, {'text': 'cc'}] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=100, + element_size_fn=lambda x: len(x['text'])) + | beam.FlatMap(lambda batch: [len(e['text']) for e in batch])) + # Should be sorted by text length + assert_that(res, equal_to([1, 2, 3])) + + def test_empty_input(self): + """Test with empty input produces no output.""" + with TestPipeline() as p: + res = ( + p + | beam.Create([], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, 
max_batch_weight=100) + | beam.Map(len)) + assert_that(res, equal_to([])) + + def test_single_element(self): + """Test with a single element.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['hello'], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100)) + assert_that(res, equal_to([['hello']])) + + def test_windowed_batches(self): + """Test that windowed elements are batched per window.""" + with TestPipeline('FnApiRunner') as p: + res = ( + p + | beam.Create(range(1, 8), reshuffle=False) + | beam.Map(lambda t: window.TimestampedValue('a' * t, t)) + | beam.WindowInto(window.FixedWindows(3)) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | beam.Map(lambda batch: ''.join(batch))) + # FixedWindows(3) with default offset 0 produces: + # Window [0, 3): elements at t=1,2 with sizes 1,2 + # Window [3, 6): elements at t=3,4,5 with sizes 3,4,5 + # Window [6, 9): elements at t=6,7 with sizes 6,7 + assert_that( + res, + equal_to([ + 'a' * (1 + 2), # Window [0, 3) + 'a' * (3 + 4 + 5), # Window [3, 6) + 'a' * (6 + 7), # Window [6, 9) + ])) + + def test_validation_min_batch_size(self): + """Test that min_batch_size validation raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=0, max_batch_size=10, max_batch_weight=100) + self.assertIn('min_batch_size must be >= 1', str(cm.exception)) + + def test_validation_max_batch_size(self): + """Test that max_batch_size < min_batch_size raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=10, max_batch_size=5, max_batch_weight=100) + self.assertIn('max_batch_size', str(cm.exception)) + self.assertIn('min_batch_size', str(cm.exception)) + + def test_validation_max_batch_weight(self): + """Test that max_batch_weight validation raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=0) + self.assertIn('max_batch_weight must be >= 1', str(cm.exception)) + + def test_validation_element_size_fn_callable(self): + """Test that a non-callable element_size_fn raises TypeError.""" + with self.assertRaises(TypeError) as cm: + util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=123) + self.assertIn('element_size_fn must be callable', str(cm.exception)) + + def test_batch_timestamps(self): + """Test that batches have correct timestamps.""" + with TestPipeline('FnApiRunner') as p: + res = ( + p + | beam.Create(['a', 'bb', 'ccc'], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | + beam.Map(lambda batch, ts=beam.DoFn.TimestampParam: (len(batch), ts))) + assert_that(res, equal_to([(3, GlobalWindow().max_timestamp())])) + + def test_padding_efficiency_improvement(self): + """Test that sorting improves padding efficiency.""" + # This test verifies the core value proposition of SortAndBatchElements + data = ['a', 'aaaaa', 'aa', 'aaaa', 'aaa'] + + # Compute what BatchElements would produce (preserves input order) + batch_elements_batches = [] + with TestPipeline() as p: + _ = ( + p + | 'Create1' >> beam.Create(data, reshuffle=False) + | util.BatchElements(min_batch_size=5, max_batch_size=5) + | beam.Map(lambda b: batch_elements_batches.append(list(b)))) + + # Compute what SortAndBatchElements produces + sort_batch_batches = [] + with 
TestPipeline() as p: + _ = ( + p + | 'Create2' >> beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=5, max_batch_weight=100) + | beam.Map(lambda b: sort_batch_batches.append(list(b)))) + + # Calculate padding overhead for each approach + # Padding overhead: + # sum(max_len_in_batch * batch_size) - sum(actual_lengths) + def compute_overhead(batches): + overhead = 0 + for batch in batches: + lengths = [len(s) for s in batch] + overhead += max(lengths) * len(batch) - sum(lengths) + return overhead + + batch_overhead = compute_overhead(batch_elements_batches) + sort_overhead = compute_overhead(sort_batch_batches) + + # SortAndBatchElements should have less or equal overhead + self.assertLessEqual(sort_overhead, batch_overhead) + + class IdentityWindowTest(unittest.TestCase): def test_window_preserved(self): expected_timestamp = timestamp.Timestamp(5) From fc66805f2f8282c4240777bddf06f6c9e07016ba Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Sun, 8 Feb 2026 15:28:57 +1100 Subject: [PATCH 2/2] fix(ml): improve test coverage for SortAndBatchElements - Exclude *_benchmark.py from codecov (standalone scripts, not production code) - Remove redundant validation from internal DoFn classes (already validated by PTransform) - Add direct in-process unit tests for DoFn internals to capture coverage (FnApiRunner runs DoFns in separate process, invisible to coverage tools) Co-Authored-By: Claude Opus 4.6 --- .github/codecov.yml | 1 + sdks/python/apache_beam/transforms/util.py | 16 --- .../apache_beam/transforms/util_test.py | 133 ++++++++++++++++++ 3 files changed, 134 insertions(+), 16 deletions(-) diff --git a/.github/codecov.yml b/.github/codecov.yml index 0936f392ccef..5d0eaccf22da 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -73,6 +73,7 @@ ignore: - "**/*_microbenchmark.py" - "sdks/go/pkg/beam/register/register.go" - "sdks/python/apache_beam/testing/benchmarks/nexmark/**" + - "**/*_benchmark.py" - "sdks/python/apache_beam/examples/**" # See https://docs.codecov.com/docs/flags for options. 
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index ab832dce0207..29d1ed087d4f 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -1364,14 +1364,6 @@ def __init__( max_batch_size: int, max_batch_weight: int, element_size_fn: Callable[[Any], int]): - if min_batch_size < 1: - raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') - if max_batch_size < min_batch_size: - raise ValueError( - f'max_batch_size ({max_batch_size}) must be >= ' - f'min_batch_size ({min_batch_size})') - if max_batch_weight < 1: - raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight @@ -1446,14 +1438,6 @@ def __init__( max_batch_size: int, max_batch_weight: int, element_size_fn: Callable[[Any], int]): - if min_batch_size < 1: - raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') - if max_batch_size < min_batch_size: - raise ValueError( - f'max_batch_size ({max_batch_size}) must be >= ' - f'min_batch_size ({min_batch_size})') - if max_batch_weight < 1: - raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index b4471f0415ed..a0c7f3e43c5f 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1243,6 +1243,139 @@ def compute_overhead(batches): self.assertLessEqual(sort_overhead, batch_overhead) +class SortAndBatchElementsDoFnDirectTest(unittest.TestCase): + """Direct unit tests for DoFn internals to ensure coverage. + + Beam's FnApiRunner executes DoFns in a separate SDK harness process, + so coverage tools in the main process cannot capture DoFn code paths. + These tests exercise the DoFn methods directly in-process. 
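
+  Illustrative pattern (a condensed version of the tests below)::
+
+      dofn = _SortAndBatchElementsDoFn(1, 3, 100, len)
+      dofn.start_bundle()
+      dofn.process('aaa')
+      batches = [wv.value for wv in dofn.finish_bundle()]
+  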
+ """ + + def test_default_element_size_fn_len(self): + from apache_beam.transforms.util import _default_element_size_fn + self.assertEqual(_default_element_size_fn('abc'), 3) + self.assertEqual(_default_element_size_fn([1, 2]), 2) + + def test_default_element_size_fn_fallback(self): + from apache_beam.transforms.util import _default_element_size_fn + self.assertEqual(_default_element_size_fn(42), 1) + self.assertEqual(_default_element_size_fn(3.14), 1) + + def test_global_dofn_sort_and_batch(self): + """Test _SortAndBatchElementsDoFn directly.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=3, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + for elem in ['ccccc', 'bb', 'dddd', 'a', 'eee']: + dofn.process(elem) + batches = [wv.value for wv in dofn.finish_bundle()] + # All elements emitted + self.assertEqual(sum(len(b) for b in batches), 5) + # Each batch respects max_batch_size=3 + for batch in batches: + self.assertLessEqual(len(batch), 3) + # Elements within each batch are sorted by size + for batch in batches: + lengths = [len(s) for s in batch] + self.assertEqual(lengths, sorted(lengths)) + + def test_global_dofn_empty_bundle(self): + """Test finish_bundle with no elements returns nothing.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + result = list(dofn.finish_bundle() or []) + self.assertEqual(result, []) + + def test_global_dofn_weight_splitting(self): + """Test weight-based splitting in the global DoFn.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + # Each element has size 5, max_batch_weight=12 -> 2 per batch + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=100, max_batch_weight=12, + element_size_fn=len) + dofn.start_bundle() + for elem in ['aaaaa', 'bbbbb', 'ccccc', 'ddddd']: + dofn.process(elem) + batches = [wv.value for wv in dofn.finish_bundle()] + self.assertEqual(len(batches), 2) + for batch in batches: + self.assertEqual(len(batch), 2) + + def test_windowed_dofn_flush_and_finish(self): + """Test _WindowAwareSortAndBatchElementsDoFn directly.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + win1 = IntervalWindow(0, 3) + win2 = IntervalWindow(3, 6) + # Manually add to buffers (bypass process() to avoid DoFn.WindowParam) + dofn._buffers[win1].extend(['aa', 'b', 'ccc']) + dofn._buffers[win2].extend(['dddd', 'ee']) + batches = list(dofn.finish_bundle()) + # All elements across both windows emitted + total_elements = sum(len(wv.value) for wv in batches) + self.assertEqual(total_elements, 5) + # Each batch has the correct window + for wv in batches: + self.assertIn(wv.windows[0], (win1, win2)) + + def test_windowed_dofn_overflow_flush(self): + """Test that exceeding _MAX_LIVE_WINDOWS triggers early flush.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + # Fill up to _MAX_LIVE_WINDOWS + for i in range(dofn._MAX_LIVE_WINDOWS): + win = IntervalWindow(i * 10, (i + 1) * 10) + 
dofn._buffers[win].append('x' * (i + 1)) + self.assertEqual(len(dofn._buffers), dofn._MAX_LIVE_WINDOWS) + # Adding one more window should trigger overflow flush + overflow_win = IntervalWindow(100, 110) + results = list(dofn.process('overflow', overflow_win)) + # One window was flushed, so buffer count stays at _MAX_LIVE_WINDOWS + self.assertLessEqual(len(dofn._buffers), dofn._MAX_LIVE_WINDOWS) + # The flushed window produced output + self.assertGreater(len(results), 0) + + def test_windowed_dofn_flush_empty_window(self): + """Test _flush_window with a non-existent window returns nothing.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + result = list(dofn._flush_window(IntervalWindow(0, 10))) + self.assertEqual(result, []) + + def test_windowed_dofn_weight_splitting(self): + """Test weight-based splitting in the windowed DoFn.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=100, max_batch_weight=12, + element_size_fn=len) + dofn.start_bundle() + win = IntervalWindow(0, 10) + dofn._buffers[win].extend(['aaaaa', 'bbbbb', 'ccccc', 'ddddd']) + batches = list(dofn._flush_window(win)) + self.assertEqual(len(batches), 2) + for wv in batches: + self.assertEqual(len(wv.value), 2) + self.assertEqual(wv.windows[0], win) + + class IdentityWindowTest(unittest.TestCase): def test_window_preserved(self): expected_timestamp = timestamp.Timestamp(5)