diff --git a/kafka/partitioner/default.py b/kafka/partitioner/default.py index a33b850cc..69613fe63 100644 --- a/kafka/partitioner/default.py +++ b/kafka/partitioner/default.py @@ -8,24 +8,18 @@ class DefaultPartitioner: If key is None, selects partition randomly from available, or from all partitions if none are currently available """ - @classmethod - def __call__(cls, key, all_partitions, available): - """ - Get the partition corresponding to key - :param key: partitioning key - :param all_partitions: list of all partitions sorted by partition ID - :param available: list of available partitions in no particular order - :return: one of the values from all_partitions or available - """ - if key is None: - if available: - return random.choice(available) - return random.choice(all_partitions) - - idx = murmur2(key) - idx &= 0x7fffffff - idx %= len(all_partitions) - return all_partitions[idx] + def partition(self, topic, serialized_key, cluster): + if topic not in cluster.topics(): + raise ValueError("Topic %s not found in ClusterMetadata" % (topic,)) + all_partitions = sorted(cluster.partitions_for_topic(topic)) + available = list(cluster.available_partitions_for_topic(topic)) + if serialized_key is not None: + idx = murmur2(serialized_key) + idx &= 0x7fffffff + idx %= len(all_partitions) + return all_partitions[idx] + pool = available if available else all_partitions + return random.choice(pool) # https://github.com/apache/kafka/blob/0.8.2/clients/src/main/java/org/apache/kafka/common/utils/Utils.java#L244 diff --git a/kafka/partitioner/sticky.py b/kafka/partitioner/sticky.py index cdda61101..a58b22a79 100644 --- a/kafka/partitioner/sticky.py +++ b/kafka/partitioner/sticky.py @@ -16,106 +16,91 @@ """ import random +import threading -from kafka.partitioner.default import murmur2 +from kafka.partitioner.default import DefaultPartitioner -class StickyPartitioner: +class StickyPartitioner(DefaultPartitioner): """Partitioner that sticks null-key records to one partition per topic until ``on_new_batch`` rotates it. - Thread-safety: the underlying ``_sticky`` dict is mutated only by - individually-atomic Python ops (get / setitem / contains). Two - concurrent partitioners may pick different sticky partitions; the - last write wins and both choices are valid, so no lock is needed. + Thread-safety: ``_sticky`` mutations are protected by ``_lock`` so + concurrent ``send()`` callers can't observe a torn read-modify-write. """ def __init__(self): self._sticky = {} # topic -> partition_id - # Java's accumulator distinguishes "first batch created on a - # partition" (no rotation) from "existing batch filled, new one - # being created" (rotate). Our accumulator collapses both into - # ``new_batch_created=True``, so the partitioner absorbs the - # *first* on_new_batch event per sticky and only rotates on the - # subsequent one. Without this, we'd rotate on every record - # whose partition has no existing batch, defeating stickiness. - self._sticky_seen_batch = set() # topics whose current sticky has had >=1 batch event + self._lock = threading.Lock() - def partition(self, topic, key, all_partitions, available): + def partition(self, topic, key, cluster): """Choose a partition for the next record. Arguments: topic (str): topic to partition on. key (bytes or None): partitioning key. - all_partitions (list[int]): every partition ID for the topic, - sorted ascending. - available (list[int]): partitions whose leader is currently - known (may be empty when metadata is stale). + cluster (ClusterMetadata): metadata for cluster; provides + all and available partitions for topic. + + Raises: + ValueError: if topic is not in ClusterMetadata Returns: int: chosen partition ID. """ + if topic not in cluster.topics(): + raise ValueError("Topic %s not found in ClusterMetadata" % (topic,)) if key is not None: - idx = murmur2(key) - idx &= 0x7fffffff - idx %= len(all_partitions) - return all_partitions[idx] + return super().partition(topic, key, cluster) # Null key: reuse the sticky partition if still valid. - partition = self._sticky.get(topic) - if partition is not None: - if available: - if partition in available: + with self._lock: + partition = self._sticky.get(topic) + if partition is not None: + all_partitions = sorted(cluster.partitions_for_topic(topic)) + available = list(cluster.available_partitions_for_topic(topic)) + if available: + if partition in available: + return partition + elif partition in all_partitions: return partition - elif partition in all_partitions: - return partition - # Stale (leader unavailable, topic shrunk); fall through to re-pick. - return self._pick_sticky(topic, all_partitions, available) + # Stale (leader unavailable, topic shrunk); fall through to re-pick. + return self._pick_sticky_locked(topic, cluster) - def on_new_batch(self, topic, all_partitions, prev_partition): - """Hook called by ``KafkaProducer`` when the accumulator just - opened a new batch for ``topic`` on ``prev_partition``. + def on_new_batch(self, topic, cluster, prev_partition): + """Hook called by ``KafkaProducer`` on the abort-for-new-batch + retry path: rotate the sticky for ``topic`` so the next + null-key record lands on a different partition. - The *first* event per sticky is absorbed silently: it - corresponds to the first batch ever being created on the - partition we just picked, which is expected - we want - subsequent records to keep landing there. The *second* event - means the previous batch filled up and a new one was opened; - that's the signal to rotate to a different partition so the - next records build up a fresh dense batch elsewhere. + Stale events (where another thread already rotated us off + ``prev_partition``) are no-ops. """ - if self._sticky.get(topic) != prev_partition: - # Someone else (or a key-routed send) already moved us off - # this partition; don't override their choice. - return - if topic not in self._sticky_seen_batch: - self._sticky_seen_batch.add(topic) - return - # Existing batch filled; rotate. - self._sticky_seen_batch.discard(topic) - self._pick_sticky(topic, all_partitions, None, - avoid=prev_partition) + with self._lock: + if self._sticky.get(topic) != prev_partition: + # Another caller already rotated us; don't override. + return + self._pick_sticky_locked(topic, cluster, avoid=prev_partition) - def _pick_sticky(self, topic, all_partitions, available, avoid=None): - pool = available if available else all_partitions - candidates = [p for p in pool if p != avoid] if avoid is not None else pool - if not candidates: - # Single-partition topic, or only the avoid-partition is - # available - no rotation possible. - candidates = pool - partition = random.choice(candidates) + def _pick_sticky_locked(self, topic, cluster, avoid=None): + """Pick a new sticky partition for ``topic``. Must be called with + ``self._lock`` held. Returns None when the topic is no longer in + cluster metadata (caller is expected to no-op in that case).""" + all_partitions = cluster.partitions_for_topic(topic) + if not all_partitions: + return None + all_partitions = sorted(all_partitions) + available = list(cluster.available_partitions_for_topic(topic) or ()) + if available: + if len(available) == 1: + partition = available[0] + else: + # >= 2 available: pick uniformly, avoiding ``avoid`` if set. + candidates = [p for p in available if p != avoid] if avoid is not None else available + if not candidates: + candidates = available + partition = random.choice(candidates) + else: + # No partitions are currently available - pick from the full + # set without enforcing ``!= avoid`` + partition = random.choice(all_partitions) self._sticky[topic] = partition - # Reset the seen-batch flag; the new sticky has had no batches yet. - self._sticky_seen_batch.discard(topic) return partition - - # Compatibility shim: legacy code paths that treat partitioners as - # bare callables (key, all_partitions, available) still work, though - # they lose the per-topic stickiness. - def __call__(self, key, all_partitions, available): - if key is not None: - idx = murmur2(key) - idx &= 0x7fffffff - idx %= len(all_partitions) - return all_partitions[idx] - pool = available if available else all_partitions - return random.choice(pool) diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index e74d061bf..8db5d7ebe 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -234,9 +234,9 @@ class KafkaProducer: would have the effect of reducing the number of requests sent but would add up to 5ms of latency to records sent in the absence of load. Default: 0. - partitioner (callable): Callable used to determine which partition + partitioner (Partitioner): Used to determine which partition each message is assigned to. Called (after key serialization): - partitioner(key_bytes, all_partitions, available_partitions). + partitioner.partition(topic, key_bytes, cluster_metadata). The default partitioner implementation hashes each non-None key using the same murmur2 algorithm as the java client so that messages with the same key are assigned to the same partition. @@ -391,7 +391,7 @@ class KafkaProducer: 'retries': float('inf'), 'batch_size': 16384, 'linger_ms': 0, - 'partitioner': StickyPartitioner(), + 'partitioner': DefaultPartitioner(), 'connections_max_idle_ms': 9 * 60 * 1000, 'max_block_ms': 60000, 'max_request_size': 1048576, @@ -857,6 +857,8 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timest sum(len(h_key.encode("utf-8")) + len(h_value) for h_key, h_value in headers) if headers else -1, ).failure(e) + # Track if the user passed an explicit partition b/c sticky logic does not apply + explicit_partition = partition is not None partition = self._partition(topic, partition, key, value, key_bytes, value_bytes) assert partition is not None, f'Partitioner did not assign a partition for topic {topic}!' @@ -874,24 +876,34 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timest if self._transaction_manager and self._transaction_manager.is_transactional(): self._transaction_manager.maybe_add_partition_to_transaction(tp) - result = self._accumulator.append(tp, timestamp_ms, key_bytes, value_bytes, headers) - future, batch_is_full, new_batch_created = result + # KIP-480: when sticky-aware partitioning is in play (no explicit + # partition, no key), try once with abort_on_new_batch=True. If the + # accumulator would have to allocate a fresh batch for this partition, + # rotate the sticky partition first and re-pick. The record that + # *triggers* the new batch then lands on the rotated partition, not + # the next one. + sticky_eligible = not explicit_partition and key_bytes is None + result = self._accumulator.append(tp, timestamp_ms, key_bytes, value_bytes, headers, + abort_on_new_batch=sticky_eligible) + future, batch_is_full, new_batch_created, abort_for_new_batch = result + if abort_for_new_batch: + prev_partition = partition + on_new_batch = getattr(self.config['partitioner'], 'on_new_batch', None) + if on_new_batch is not None: + on_new_batch(topic, self._metadata, prev_partition) + # Re-pick - sticky cache may now point at a different partition. + partition = self._partition(topic, None, key, value, key_bytes, value_bytes) + tp = TopicPartition(topic, partition) + if self._transaction_manager and self._transaction_manager.is_transactional(): + self._transaction_manager.maybe_add_partition_to_transaction(tp) + result = self._accumulator.append(tp, timestamp_ms, key_bytes, value_bytes, headers, + abort_on_new_batch=False) + future, batch_is_full, new_batch_created, _ = result + if batch_is_full or new_batch_created: log.debug("%s: Waking up the sender since %s is either full or" " getting a new batch", str(self), tp) self._sender.wakeup() - # KIP-480: notify a sticky-aware partitioner that this null-key - # record opened a new batch on `partition`, so the next null-key - # send for `topic` rotates to a different partition. Keyed - # records hash deterministically and don't participate in sticky - # rotation, so skip the hook for them. - if new_batch_created and key_bytes is None: - partitioner = self.config['partitioner'] - on_new_batch = getattr(partitioner, 'on_new_batch', None) - if on_new_batch is not None: - all_partitions = self._metadata.partitions_for_topic(topic) - if all_partitions is not None: - on_new_batch(topic, sorted(all_partitions), partition) return future def flush(self, timeout=None): @@ -984,28 +996,17 @@ def _serialize(self, f, topic, data): def _partition(self, topic, partition, key, value, serialized_key, serialized_value): - all_partitions = self._metadata.partitions_for_topic(topic) - available = self._metadata.available_partitions_for_topic(topic) - if all_partitions is None or available is None: + if topic not in self._metadata.topics(): return None if partition is not None: assert partition >= 0 - assert partition in all_partitions, 'Unrecognized partition' + all_partitions = self._metadata.partitions_for_topic(topic) + assert all_partitions is not None and partition in all_partitions, ( + 'Unrecognized partition %s for topic %s' % (partition, topic)) return partition - # Prefer the topic-aware partition() method (KIP-480 sticky - # partitioner needs the topic for its per-topic stickiness). - # Fall back to the legacy callable interface so user-supplied - # custom partitioners written against pre-KIP-480 kafka-python - # continue to work unchanged. - partitioner = self.config['partitioner'] - if hasattr(partitioner, 'partition'): - return partitioner.partition(topic, serialized_key, - sorted(all_partitions), - list(available)) - return partitioner(serialized_key, - sorted(all_partitions), - list(available)) + return self.config['partitioner'].partition( + topic, serialized_key, self._metadata) def metrics(self, raw=False): """Get metrics on producer performance. @@ -1032,4 +1033,4 @@ def metrics(self, raw=False): return metrics def __str__(self): - return "" % (self.config['client_id'], self.config['transactional_id']) + return "" % (self.config.get('client_id', None), self.config.get('transactional_id', None)) diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index ea8f707d3..20bb2f5b4 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -110,7 +110,8 @@ def _tp_lock(self, tp): self._tp_locks[tp] = threading.Lock() return self._tp_locks[tp] - def append(self, tp, timestamp_ms, key, value, headers, now=None): + def append(self, tp, timestamp_ms, key, value, headers, now=None, + abort_on_new_batch=False): """Add a record to the accumulator, return the append result. The append result will contain the future metadata, and flag for @@ -123,9 +124,14 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None): key (bytes): The key for the record value (bytes): The value for the record headers (List[Tuple[str, bytes]]): The header fields for the record + abort_on_new_batch (bool): KIP-480. When True, return early with + ``abort_for_new_batch=True`` instead of allocating a new + batch when no in-progress batch has room. Caller is expected + to consult the partitioner's ``on_new_batch`` hook, re-pick + the partition, and retry with ``abort_on_new_batch=False``. Returns: - tuple: (future, batch_is_full, new_batch_created) + tuple: (future, batch_is_full, new_batch_created, abort_for_new_batch) """ assert isinstance(tp, TopicPartition), 'not TopicPartition' assert not self._closed, 'RecordAccumulator is closed' @@ -142,7 +148,12 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None): future = last.try_append(timestamp_ms, key, value, headers, now=now) if future is not None: batch_is_full = len(dq) > 1 or last.records.is_full() - return future, batch_is_full, False + return future, batch_is_full, False, False + + if abort_on_new_batch: + # KIP-480: don't allocate a new batch yet. Caller will + # rotate the sticky partition and retry. + return None, False, False, True with self._tp_lock(tp): # Need to check if producer is closed again after grabbing the @@ -156,7 +167,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None): # Somebody else found us a batch, return the one we # waited for! Hopefully this doesn't happen often... batch_is_full = len(dq) > 1 or last.records.is_full() - return future, batch_is_full, False + return future, batch_is_full, False, False if self._transaction_manager and self.config['message_version'] < 2: raise Errors.UnsupportedVersionError("Attempting to use idempotence with a broker which" @@ -176,7 +187,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None): dq.append(batch) self._incomplete.add(batch) batch_is_full = len(dq) > 1 or batch.records.is_full() - return future, batch_is_full, True + return future, batch_is_full, True, False finally: self._appends_in_progress.decrement() diff --git a/test/producer/test_partitioner.py b/test/producer/test_partitioner.py index a4b2b0cbf..2068562ad 100644 --- a/test/producer/test_partitioner.py +++ b/test/producer/test_partitioner.py @@ -1,22 +1,49 @@ +from unittest.mock import MagicMock + import pytest from kafka.partitioner import DefaultPartitioner, StickyPartitioner, murmur2 +def _cluster(topic, all_partitions, available=None): + """Build a mock ClusterMetadata that returns the given partition + sets for ``topic`` (matches the interface DefaultPartitioner / + StickyPartitioner actually call: ``topics()``, + ``partitions_for_topic()``, ``available_partitions_for_topic()``).""" + if available is None: + available = all_partitions + cluster = MagicMock() + cluster.topics.return_value = {topic} + cluster.partitions_for_topic.return_value = set(all_partitions) + cluster.available_partitions_for_topic.return_value = set(available) + return cluster + + def test_default_partitioner(): partitioner = DefaultPartitioner() - all_partitions = available = list(range(100)) + all_partitions = list(range(100)) + cluster = _cluster('t', all_partitions) # partitioner should return the same partition for the same key - p1 = partitioner(b'foo', all_partitions, available) - p2 = partitioner(b'foo', all_partitions, available) + p1 = partitioner.partition('t', b'foo', cluster) + p2 = partitioner.partition('t', b'foo', cluster) assert p1 == p2 assert p1 in all_partitions # when key is None, choose one of available partitions - assert partitioner(None, all_partitions, [123]) == 123 + cluster_single = _cluster('t', all_partitions, available=[123]) + cluster_single.partitions_for_topic.return_value = set(all_partitions) | {123} + assert partitioner.partition('t', None, cluster_single) == 123 - # with fallback to all_partitions - assert partitioner(None, all_partitions, []) in all_partitions + # with fallback to all_partitions when available is empty + cluster_no_avail = _cluster('t', all_partitions, available=[]) + assert partitioner.partition('t', None, cluster_no_avail) in all_partitions + + +def test_default_partitioner_unknown_topic(): + partitioner = DefaultPartitioner() + cluster = _cluster('t', list(range(10))) + with pytest.raises(ValueError, match='not found'): + partitioner.partition('other', b'k', cluster) @pytest.mark.parametrize("bytes_payload,partition_number", [ @@ -25,9 +52,10 @@ def test_default_partitioner(): ]) def test_murmur2_java_compatibility(bytes_payload, partition_number): partitioner = DefaultPartitioner() - all_partitions = available = list(range(1000)) + all_partitions = list(range(1000)) + cluster = _cluster('t', all_partitions) # compare with output from Kafka's org.apache.kafka.clients.producer.Partitioner - assert partitioner(bytes_payload, all_partitions, available) == partition_number + assert partitioner.partition('t', bytes_payload, cluster) == partition_number def test_murmur2_not_ascii(): @@ -43,89 +71,175 @@ class TestStickyPartitioner: def test_keyed_records_hash_like_default(self): sticky = StickyPartitioner() default = DefaultPartitioner() - all_partitions = available = list(range(100)) - assert (sticky.partition('t', b'foo', all_partitions, available) - == default(b'foo', all_partitions, available)) - assert (sticky.partition('t', b'bar', all_partitions, available) - == default(b'bar', all_partitions, available)) - - def test_null_key_sticks_until_second_on_new_batch(self): - """The *first* on_new_batch event is absorbed (it's the first - batch being opened on the newly-picked sticky - exactly what we - want). Rotation only happens on the *second* event, which - signals that the previous batch filled up. Without this, the - partitioner would rotate on every record whose partition had - no existing batch.""" + all_partitions = list(range(100)) + cluster = _cluster('t', all_partitions) + assert (sticky.partition('t', b'foo', cluster) + == default.partition('t', b'foo', cluster)) + assert (sticky.partition('t', b'bar', cluster) + == default.partition('t', b'bar', cluster)) + + def test_unknown_topic_raises(self): + sticky = StickyPartitioner() + cluster = _cluster('t', [0, 1]) + with pytest.raises(ValueError, match='not found'): + sticky.partition('other', None, cluster) + + def test_null_key_sticks_until_on_new_batch(self): + """A stream of null-key records pins to the chosen partition; + ``on_new_batch`` rotates immediately (no "absorb first event" + hack - KafkaProducer.send invokes the hook only on the + abort-for-new-batch retry path, matching Java).""" sticky = StickyPartitioner() - all_partitions = available = list(range(10)) - p1 = sticky.partition('t', None, all_partitions, available) + all_partitions = list(range(10)) + cluster = _cluster('t', all_partitions) + p1 = sticky.partition('t', None, cluster) for _ in range(50): - assert sticky.partition('t', None, all_partitions, available) == p1 - - # First on_new_batch: opens the first batch on p1 - no rotation. - sticky.on_new_batch('t', all_partitions, p1) - assert sticky.partition('t', None, all_partitions, available) == p1 + assert sticky.partition('t', None, cluster) == p1 - # Second on_new_batch: previous batch filled - rotate. - sticky.on_new_batch('t', all_partitions, p1) - p2 = sticky.partition('t', None, all_partitions, available) - assert p2 != p1, 'second on_new_batch should rotate' + # on_new_batch rotates immediately. + sticky.on_new_batch('t', cluster, p1) + p2 = sticky.partition('t', None, cluster) + assert p2 != p1, 'on_new_batch should rotate' for _ in range(50): - assert sticky.partition('t', None, all_partitions, available) == p2 + assert sticky.partition('t', None, cluster) == p2 + + def test_on_new_batch_rotation_avoids_prev_partition(self): + """Regression test for the prior bug where on_new_batch passed + available=None to _pick_sticky_locked and the rotation could + randomly land back on the same partition. Now that on_new_batch + passes the cluster through, the available-aware avoid logic + kicks in and rotation is guaranteed.""" + sticky = StickyPartitioner() + all_partitions = list(range(10)) + cluster = _cluster('t', all_partitions) + p1 = sticky.partition('t', None, cluster) + # Many rotations should never pick p1 again when many partitions + # are available and p1 is the avoid target. + for _ in range(200): + sticky._sticky['t'] = p1 # restore state between iterations + sticky.on_new_batch('t', cluster, p1) + assert sticky._sticky['t'] != p1, ( + 'rotation must avoid prev when other available partitions exist') def test_per_topic_state_independent(self): """Stickiness is per-topic; rotating one topic doesn't affect another.""" sticky = StickyPartitioner() - all_partitions = available = list(range(10)) - p_a = sticky.partition('a', None, all_partitions, available) - p_b = sticky.partition('b', None, all_partitions, available) - # Two on_new_batch events on 'a' to actually rotate it. - sticky.on_new_batch('a', all_partitions, p_a) - sticky.on_new_batch('a', all_partitions, p_a) + all_partitions = list(range(10)) + cluster_a = _cluster('a', all_partitions) + cluster_b = _cluster('b', all_partitions) + # Combined cluster mock answers for both topics. + cluster = MagicMock() + cluster.topics.return_value = {'a', 'b'} + cluster.partitions_for_topic.side_effect = ( + lambda t: set(all_partitions)) + cluster.available_partitions_for_topic.side_effect = ( + lambda t: set(all_partitions)) + p_a = sticky.partition('a', None, cluster) + p_b = sticky.partition('b', None, cluster) + # Rotate 'a' once. + sticky.on_new_batch('a', cluster, p_a) # 'b' is untouched. - assert sticky.partition('b', None, all_partitions, available) == p_b + assert sticky.partition('b', None, cluster) == p_b def test_unavailable_sticky_partition_repicks(self): """If the stuck partition's leader becomes unavailable, the next partition() call repicks from the available set.""" sticky = StickyPartitioner() all_partitions = list(range(10)) - p1 = sticky.partition('t', None, all_partitions, all_partitions) + cluster = _cluster('t', all_partitions, available=all_partitions) + p1 = sticky.partition('t', None, cluster) # Now only partitions != p1 are available. - available = [p for p in all_partitions if p != p1] - p2 = sticky.partition('t', None, all_partitions, available) + cluster.available_partitions_for_topic.return_value = ( + set(all_partitions) - {p1}) + p2 = sticky.partition('t', None, cluster) assert p2 != p1 - assert p2 in available + assert p2 in (set(all_partitions) - {p1}) def test_single_partition_topic_cannot_rotate(self): """on_new_batch on a single-partition topic just keeps the same partition - there's nothing else to rotate to.""" sticky = StickyPartitioner() - all_partitions = available = [0] - assert sticky.partition('t', None, all_partitions, available) == 0 - sticky.on_new_batch('t', all_partitions, 0) - assert sticky.partition('t', None, all_partitions, available) == 0 + cluster = _cluster('t', [0]) + assert sticky.partition('t', None, cluster) == 0 + sticky.on_new_batch('t', cluster, 0) + assert sticky.partition('t', None, cluster) == 0 def test_on_new_batch_ignores_stale_prev_partition(self): """If a key-routed send or another caller already rotated the sticky between when we picked and when on_new_batch fires, the hook is a no-op (don't override their choice).""" sticky = StickyPartitioner() - all_partitions = available = list(range(10)) - p1 = sticky.partition('t', None, all_partitions, available) + all_partitions = list(range(10)) + cluster = _cluster('t', all_partitions) + p1 = sticky.partition('t', None, cluster) # Simulate someone else rotating away. sticky._sticky['t'] = (p1 + 1) % len(all_partitions) current = sticky._sticky['t'] - sticky.on_new_batch('t', all_partitions, prev_partition=p1) + sticky.on_new_batch('t', cluster, prev_partition=p1) assert sticky._sticky['t'] == current, 'should not overwrite live sticky' - def test_legacy_callable_interface_still_works(self): - """A user-supplied custom partitioner written against the old - callable signature must keep working. We exercise that shim - via StickyPartitioner itself (it implements both forms).""" + def test_no_available_picks_without_avoid(self): + """When ``available_partitions_for_topic`` returns empty, Java + picks random % partitions.size() without enforcing ``!= avoid`` + — ensure we mirror that and don't filter the avoided partition + out of an already-degraded fallback set.""" + sticky = StickyPartitioner() + cluster = _cluster('t', [0, 1, 2], available=[]) + sticky._sticky['t'] = 1 # pre-seed sticky + observed = set() + for _ in range(200): + partition = sticky._pick_sticky_locked('t', cluster, avoid=1) + observed.add(partition) + assert observed == {0, 1, 2}, ( + 'avoid should not filter when available is empty; got %s' % observed) + + def test_single_available_partition_repeats_even_if_avoided(self): + """Single-element ``available`` must always return that one, + even if it equals ``avoid`` (Java's nextPartition does the same).""" + sticky = StickyPartitioner() + cluster = _cluster('t', [0, 1, 2], available=[1]) + for _ in range(50): + partition = sticky._pick_sticky_locked('t', cluster, avoid=1) + assert partition == 1 + + def test_pick_sticky_unknown_topic_returns_none(self): + """_pick_sticky_locked is the defensive layer for on_new_batch's + no-pre-validation call; a topic that disappeared from cluster + metadata returns None (caller no-ops).""" + sticky = StickyPartitioner() + cluster = MagicMock() + cluster.partitions_for_topic.return_value = None + assert sticky._pick_sticky_locked('gone', cluster) is None + + def test_thread_safety_under_contention(self): + """Concurrent ``partition`` + ``on_new_batch`` calls from many + threads should never raise (no torn read-modify-write) and the + final ``_sticky[topic]`` value must be one of the valid + partitions.""" + import threading sticky = StickyPartitioner() - all_partitions = available = list(range(100)) - # __call__ (no topic) ignores stickiness; just verify it returns - # a valid partition for both keyed and null-key inputs. - assert sticky(b'foo', all_partitions, available) in all_partitions - assert sticky(None, all_partitions, available) in all_partitions + topic = 't' + partitions = list(range(20)) + cluster = _cluster(topic, partitions) + errors = [] + stop = threading.Event() + + def hammer(): + try: + while not stop.is_set(): + p = sticky.partition(topic, None, cluster) + sticky.on_new_batch(topic, cluster, p) + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=hammer) for _ in range(8)] + for t in threads: + t.start() + # Let them race for a moment. + import time as _time + _time.sleep(0.2) + stop.set() + for t in threads: + t.join(timeout=2) + assert not errors, 'concurrent ops raised: %r' % errors + assert sticky._sticky[topic] in partitions diff --git a/test/producer/test_producer.py b/test/producer/test_producer.py index 46e1a5bc2..90e4945c3 100644 --- a/test/producer/test_producer.py +++ b/test/producer/test_producer.py @@ -27,11 +27,13 @@ def test_kafka_producer_context_manager_closes_on_exit(): assert threading.active_count() == threads -def test_partition_uses_topic_aware_api_when_available(): - """_partition routes through partitioner.partition(topic, ...) when - the configured partitioner exposes it (KIP-480 sticky path).""" +def test_partition_calls_partitioner_partition_with_cluster(): + """_partition routes through partitioner.partition(topic, key, cluster) + — the new signature passes the ClusterMetadata directly so the + partitioner can call available_partitions_for_topic / topics() itself.""" producer = KafkaProducer.__new__(KafkaProducer) producer._metadata = MagicMock() + producer._metadata.topics.return_value = {'t'} producer._metadata.partitions_for_topic.return_value = {0, 1, 2} producer._metadata.available_partitions_for_topic.return_value = {0, 1, 2} @@ -41,27 +43,118 @@ def test_partition_uses_topic_aware_api_when_available(): result = producer._partition('t', None, None, None, b'key-bytes', b'val') assert result == 1 - partitioner.partition.assert_called_once_with('t', b'key-bytes', [0, 1, 2], [0, 1, 2]) + partitioner.partition.assert_called_once_with('t', b'key-bytes', producer._metadata) -def test_partition_falls_back_to_legacy_callable(): - """Custom partitioners written against the legacy callable signature - (no .partition method) keep working unchanged.""" +def test_partition_explicit_partition_skips_partitioner(): + """Explicit partition= argument bypasses the partitioner entirely. + The partition must still be in the topic's known set.""" producer = KafkaProducer.__new__(KafkaProducer) producer._metadata = MagicMock() + producer._metadata.topics.return_value = {'t'} producer._metadata.partitions_for_topic.return_value = {0, 1, 2} - producer._metadata.available_partitions_for_topic.return_value = {0, 1, 2} + partitioner = MagicMock() + producer.config = {'partitioner': partitioner} + + assert producer._partition('t', 1, None, None, b'k', b'v') == 1 + partitioner.partition.assert_not_called() + - # A plain function - no .partition attribute - must still work. - calls = [] - def legacy_partitioner(key, all_partitions, available): - calls.append((key, all_partitions, available)) - return 2 - producer.config = {'partitioner': legacy_partitioner} +def test_partition_explicit_partition_rejects_unknown_partition(): + producer = KafkaProducer.__new__(KafkaProducer) + producer._metadata = MagicMock() + producer._metadata.topics.return_value = {'t'} + producer._metadata.partitions_for_topic.return_value = {0, 1, 2} + producer.config = {'partitioner': MagicMock()} + with pytest.raises(AssertionError): + producer._partition('t', 99, None, None, b'k', b'v') - result = producer._partition('t', None, None, None, b'k', b'v') - assert result == 2 - assert calls == [(b'k', [0, 1, 2], [0, 1, 2])] + +def _producer_for_send_test(partitioner): + """Build a real KafkaProducer but replace the accumulator + sender + with mocks so ``send()`` doesn't try to actually push data.""" + producer = KafkaProducer(api_version=(2, 1), partitioner=partitioner) + producer._accumulator = MagicMock() + producer._sender = MagicMock() + producer._metadata = MagicMock() + producer._metadata.topics.return_value = {'t'} + producer._metadata.partitions_for_topic.return_value = set(range(20)) + producer._metadata.available_partitions_for_topic.return_value = set(range(20)) + return producer + + +def _success_result(): + from kafka.producer.future import FutureRecordMetadata, FutureProduceResult + from kafka.structs import TopicPartition + return (FutureRecordMetadata( + FutureProduceResult(TopicPartition('t', 0)), + 0, 0, 0, 0, 0, 0), + False, True, False) + + +def test_send_null_key_triggers_on_new_batch_via_abort_retry(): + """KIP-480 Java-faithful flow: a null-key send whose accumulator has + no in-progress batch must invoke ``partitioner.on_new_batch`` (rotate + the sticky) and re-pick the partition before the actual append, + matching KafkaProducer.doSend's abort-for-new-batch retry path.""" + partitioner = MagicMock(spec=['partition', 'on_new_batch']) + partitioner.partition.side_effect = [3, 7] # initial pick, post-rotate + producer = _producer_for_send_test(partitioner) + abort = (None, False, False, True) + producer._accumulator.append.side_effect = [abort, _success_result()] + + try: + producer.send('t', value=b'msg') + # Initial pick + post-rotate re-pick. + assert partitioner.partition.call_count == 2 + # on_new_batch fired exactly once, with the cluster metadata and + # the *initial* sticky. + partitioner.on_new_batch.assert_called_once_with( + 't', producer._metadata, 3) + # Two appends: first aborted, second landed the record on partition 7. + assert producer._accumulator.append.call_count == 2 + second_call = producer._accumulator.append.call_args_list[1] + tp_arg = second_call.args[0] + assert tp_arg.partition == 7 + finally: + producer.close(timeout=1) + + +def test_send_keyed_skips_on_new_batch(): + """Keyed records bypass the sticky abort-retry path — on_new_batch + must not fire.""" + partitioner = MagicMock(spec=['partition', 'on_new_batch']) + partitioner.partition.return_value = 0 + producer = _producer_for_send_test(partitioner) + producer._accumulator.append.return_value = _success_result() + + try: + producer.send('t', key=b'k', value=b'v') + partitioner.on_new_batch.assert_not_called() + # Keyed records pass abort_on_new_batch=False directly — one append. + assert producer._accumulator.append.call_count == 1 + kwargs = producer._accumulator.append.call_args.kwargs + assert kwargs.get('abort_on_new_batch') is False + finally: + producer.close(timeout=1) + + +def test_send_with_explicit_partition_skips_on_new_batch(): + """Explicit partition overrides the partitioner entirely — no + rotation hook should fire.""" + partitioner = MagicMock(spec=['partition', 'on_new_batch']) + producer = _producer_for_send_test(partitioner) + producer._accumulator.append.return_value = _success_result() + + try: + producer.send('t', value=b'v', partition=1) + partitioner.partition.assert_not_called() + partitioner.on_new_batch.assert_not_called() + # Explicit partition also goes straight to abort_on_new_batch=False. + kwargs = producer._accumulator.append.call_args.kwargs + assert kwargs.get('abort_on_new_batch') is False + finally: + producer.close(timeout=1) def test_idempotent_producer_reset_producer_id(cluster): diff --git a/test/producer/test_record_accumulator.py b/test/producer/test_record_accumulator.py index aaea6da58..bcfd5fad1 100644 --- a/test/producer/test_record_accumulator.py +++ b/test/producer/test_record_accumulator.py @@ -163,3 +163,36 @@ def test_expired_batches(cluster, tp): accum.muted.add(tp) expired_batches = accum.expired_batches(now=now) assert len(expired_batches) == 1, "The batch should not be expired when the partition is muted" + + +def test_abort_on_new_batch_returns_sentinel(tp): + """KIP-480 plumbing: with abort_on_new_batch=True and no in-progress + batch, append must return (None, False, False, True) instead of + allocating a fresh batch.""" + accum = RecordAccumulator() + future, batch_full, new_batch, abort = accum.append( + tp, 0, b'k', b'v', [], now=0, abort_on_new_batch=True) + assert future is None + assert batch_full is False + assert new_batch is False + assert abort is True + # And no batch was actually allocated. + assert tp not in accum._batches or not accum._batches[tp] + + +def test_abort_on_new_batch_appends_to_existing(tp): + """If a batch already exists with room, abort_on_new_batch=True still + succeeds — only the new-batch-allocation path is gated.""" + accum = RecordAccumulator() + # First append (with abort_on_new_batch=False) creates the batch. + future1, _, new1, abort1 = accum.append( + tp, 0, b'k1', b'v1', [], now=0, abort_on_new_batch=False) + assert future1 is not None + assert new1 is True + assert abort1 is False + # Second append with abort_on_new_batch=True lands in the existing batch. + future2, _, new2, abort2 = accum.append( + tp, 0, b'k2', b'v2', [], now=0, abort_on_new_batch=True) + assert future2 is not None + assert new2 is False + assert abort2 is False