From 3ced4aec25f609f634acf5416f66ebf6c4fecb55 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 7 Apr 2025 15:01:48 -0700 Subject: [PATCH 1/2] test_create_produce_requests --- test/test_sender.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/test/test_sender.py b/test/test_sender.py index eedc43d25..7f7f0c18b 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -1,9 +1,12 @@ # pylint: skip-file from __future__ import absolute_import -import pytest +import collections import io +import pytest +from kafka.vendor import six + from kafka.client_async import KafkaClient from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS from kafka.producer.kafka import KafkaProducer @@ -20,8 +23,8 @@ def accumulator(): @pytest.fixture -def sender(client, accumulator, metrics, mocker): - return Sender(client, client.cluster, accumulator, metrics=metrics) +def sender(client, accumulator): + return Sender(client, client.cluster, accumulator) @pytest.mark.parametrize(("api_version", "produce_version"), [ @@ -30,7 +33,7 @@ def sender(client, accumulator, metrics, mocker): ((0, 9), 1), ((0, 8, 0), 0) ]) -def test_produce_request(sender, mocker, api_version, produce_version): +def test_produce_request(sender, api_version, produce_version): sender._client._api_versions = BROKER_API_VERSIONS[api_version] tp = TopicPartition('foo', 0) magic = KafkaProducer.max_usable_produce_magic(api_version) @@ -40,3 +43,24 @@ def test_produce_request(sender, mocker, api_version, produce_version): records.close() produce_request = sender._produce_request(0, 0, 0, [batch]) assert isinstance(produce_request, ProduceRequest[produce_version]) + + +@pytest.mark.parametrize(("api_version", "produce_version"), [ + ((2, 1), 7), +]) +def test_create_produce_requests(sender, api_version, produce_version): + sender._client._api_versions = BROKER_API_VERSIONS[api_version] + tp = TopicPartition('foo', 0) + magic = KafkaProducer.max_usable_produce_magic(api_version) + batches_by_node = collections.defaultdict(list) + for node in range(3): + for _ in range(5): + records = MemoryRecordsBuilder( + magic=1, compression_type=0, batch_size=100000) + batches_by_node[node].append(ProducerBatch(tp, records)) + records.close() + + produce_requests_by_node = sender._create_produce_requests(batches_by_node) + assert len(produce_requests_by_node) == 3 + for node in range(3): + assert isinstance(produce_requests_by_node[node], ProduceRequest[produce_version]) From 4bc69d3ac497bdb3f2e9034925910df2b3b9ab80 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 7 Apr 2025 18:14:05 -0700 Subject: [PATCH 2/2] More Sender tests --- test/test_sender.py | 170 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 160 insertions(+), 10 deletions(-) diff --git a/test/test_sender.py b/test/test_sender.py index 7f7f0c18b..a1a775b59 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -3,16 +3,21 @@ import collections import io +import time import pytest +from unittest.mock import call + from kafka.vendor import six from kafka.client_async import KafkaClient +import kafka.errors as Errors from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS from kafka.producer.kafka import KafkaProducer from kafka.protocol.produce import ProduceRequest from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch from kafka.producer.sender import Sender +from kafka.producer.transaction_state import TransactionState from kafka.record.memory_records import MemoryRecordsBuilder from kafka.structs import TopicPartition @@ -27,6 +32,16 @@ def sender(client, accumulator): return Sender(client, client.cluster, accumulator) +def producer_batch(topic='foo', partition=0, magic=2): + tp = TopicPartition(topic, partition) + records = MemoryRecordsBuilder( + magic=magic, compression_type=0, batch_size=100000) + batch = ProducerBatch(tp, records) + batch.try_append(0, None, b'msg', []) + batch.records.close() + return batch + + @pytest.mark.parametrize(("api_version", "produce_version"), [ ((2, 1), 7), ((0, 10, 0), 2), @@ -35,12 +50,8 @@ def sender(client, accumulator): ]) def test_produce_request(sender, api_version, produce_version): sender._client._api_versions = BROKER_API_VERSIONS[api_version] - tp = TopicPartition('foo', 0) magic = KafkaProducer.max_usable_produce_magic(api_version) - records = MemoryRecordsBuilder( - magic=1, compression_type=0, batch_size=100000) - batch = ProducerBatch(tp, records) - records.close() + batch = producer_batch(magic=magic) produce_request = sender._produce_request(0, 0, 0, [batch]) assert isinstance(produce_request, ProduceRequest[produce_version]) @@ -55,12 +66,151 @@ def test_create_produce_requests(sender, api_version, produce_version): batches_by_node = collections.defaultdict(list) for node in range(3): for _ in range(5): - records = MemoryRecordsBuilder( - magic=1, compression_type=0, batch_size=100000) - batches_by_node[node].append(ProducerBatch(tp, records)) - records.close() - + batches_by_node[node].append(producer_batch(magic=magic)) produce_requests_by_node = sender._create_produce_requests(batches_by_node) assert len(produce_requests_by_node) == 3 for node in range(3): assert isinstance(produce_requests_by_node[node], ProduceRequest[produce_version]) + + +def test_complete_batch_success(sender): + batch = producer_batch() + assert not batch.produce_future.is_done + + # No error, base_offset 0 + sender._complete_batch(batch, None, 0, timestamp_ms=123, log_start_offset=456) + assert batch.is_done + assert batch.produce_future.is_done + assert batch.produce_future.succeeded() + assert batch.produce_future.value == (0, 123, 456) + + +def test_complete_batch_transaction(sender): + sender._transaction_state = TransactionState() + batch = producer_batch() + assert sender._transaction_state.sequence_number(batch.topic_partition) == 0 + assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id + + # No error, base_offset 0 + sender._complete_batch(batch, None, 0) + assert batch.is_done + assert sender._transaction_state.sequence_number(batch.topic_partition) == batch.record_count + + +@pytest.mark.parametrize(("error", "refresh_metadata"), [ + (Errors.KafkaConnectionError, True), + (Errors.CorruptRecordError, False), + (Errors.UnknownTopicOrPartitionError, True), + (Errors.NotLeaderForPartitionError, True), + (Errors.MessageSizeTooLargeError, False), + (Errors.InvalidTopicError, False), + (Errors.RecordListTooLargeError, False), + (Errors.NotEnoughReplicasError, False), + (Errors.NotEnoughReplicasAfterAppendError, False), + (Errors.InvalidRequiredAcksError, False), + (Errors.TopicAuthorizationFailedError, False), + (Errors.UnsupportedForMessageFormatError, False), + (Errors.InvalidProducerEpochError, False), + (Errors.ClusterAuthorizationFailedError, False), + (Errors.TransactionalIdAuthorizationFailedError, False), +]) +def test_complete_batch_error(sender, error, refresh_metadata): + sender._client.cluster._last_successful_refresh_ms = (time.time() - 10) * 1000 + sender._client.cluster._need_update = False + assert sender._client.cluster.ttl() > 0 + batch = producer_batch() + sender._complete_batch(batch, error, -1) + if refresh_metadata: + assert sender._client.cluster.ttl() == 0 + else: + assert sender._client.cluster.ttl() > 0 + assert batch.is_done + assert batch.produce_future.failed() + assert isinstance(batch.produce_future.exception, error) + + +@pytest.mark.parametrize(("error", "retry"), [ + (Errors.KafkaConnectionError, True), + (Errors.CorruptRecordError, False), + (Errors.UnknownTopicOrPartitionError, True), + (Errors.NotLeaderForPartitionError, True), + (Errors.MessageSizeTooLargeError, False), + (Errors.InvalidTopicError, False), + (Errors.RecordListTooLargeError, False), + (Errors.NotEnoughReplicasError, True), + (Errors.NotEnoughReplicasAfterAppendError, True), + (Errors.InvalidRequiredAcksError, False), + (Errors.TopicAuthorizationFailedError, False), + (Errors.UnsupportedForMessageFormatError, False), + (Errors.InvalidProducerEpochError, False), + (Errors.ClusterAuthorizationFailedError, False), + (Errors.TransactionalIdAuthorizationFailedError, False), +]) +def test_complete_batch_retry(sender, accumulator, mocker, error, retry): + sender.config['retries'] = 1 + mocker.spy(sender, '_fail_batch') + mocker.patch.object(accumulator, 'reenqueue') + batch = producer_batch() + sender._complete_batch(batch, error, -1) + if retry: + assert not batch.is_done + accumulator.reenqueue.assert_called_with(batch) + batch.attempts += 1 # normally handled by accumulator.reenqueue, but it's mocked + sender._complete_batch(batch, error, -1) + assert batch.is_done + assert isinstance(batch.produce_future.exception, error) + else: + assert batch.is_done + assert isinstance(batch.produce_future.exception, error) + + +def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker): + sender._transaction_state = TransactionState() + sender.config['retries'] = 1 + mocker.spy(sender, '_fail_batch') + mocker.patch.object(accumulator, 'reenqueue') + error = Errors.NotLeaderForPartitionError + batch = producer_batch() + sender._complete_batch(batch, error, -1) + assert not batch.is_done + accumulator.reenqueue.assert_called_with(batch) + batch.records._producer_id = 123 # simulate different producer_id + assert batch.producer_id != sender._transaction_state.producer_id_and_epoch.producer_id + sender._complete_batch(batch, error, -1) + assert batch.is_done + assert isinstance(batch.produce_future.exception, error) + + +def test_fail_batch(sender, accumulator, mocker): + sender._transaction_state = TransactionState() + mocker.patch.object(TransactionState, 'reset_producer_id') + batch = producer_batch() + mocker.patch.object(batch, 'done') + assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id + error = Exception('error') + sender._fail_batch(batch, base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) + sender._transaction_state.reset_producer_id.assert_called_once() + batch.done.assert_called_with(base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) + + +def test_handle_produce_response(): + pass + + +def test_failed_produce(sender, mocker): + mocker.patch.object(sender, '_complete_batch') + mock_batches = ['foo', 'bar', 'fizzbuzz'] + sender._failed_produce(mock_batches, 0, 'error') + sender._complete_batch.assert_has_calls([ + call('foo', 'error', -1), + call('bar', 'error', -1), + call('fizzbuzz', 'error', -1), + ]) + + +def test_maybe_wait_for_producer_id(): + pass + + +def test_run_once(): + pass