11# pylint: skip-file
22from __future__ import absolute_import
33
4- import pytest
4+ import collections
55import io
6+ import time
7+
8+ import pytest
9+ from unittest .mock import call
10+
11+ from kafka .vendor import six
612
713from kafka .client_async import KafkaClient
14+ import kafka .errors as Errors
815from kafka .protocol .broker_api_versions import BROKER_API_VERSIONS
916from kafka .producer .kafka import KafkaProducer
1017from kafka .protocol .produce import ProduceRequest
1118from kafka .producer .record_accumulator import RecordAccumulator , ProducerBatch
1219from kafka .producer .sender import Sender
20+ from kafka .producer .transaction_state import TransactionState
1321from kafka .record .memory_records import MemoryRecordsBuilder
1422from kafka .structs import TopicPartition
1523
@@ -20,8 +28,18 @@ def accumulator():
2028
2129
@pytest.fixture
def sender(client, accumulator):
    """Build a Sender wired to the shared client and accumulator fixtures."""
    instance = Sender(client, client.cluster, accumulator)
    return instance
33+
34+
def producer_batch(topic='foo', partition=0, magic=2):
    """Return a closed, single-record ProducerBatch for use in sender tests."""
    builder = MemoryRecordsBuilder(magic=magic, compression_type=0,
                                   batch_size=100000)
    batch = ProducerBatch(TopicPartition(topic, partition), builder)
    batch.try_append(0, None, b'msg', [])
    batch.records.close()
    return batch
2543
2644
2745@pytest .mark .parametrize (("api_version" , "produce_version" ), [
@@ -30,13 +48,169 @@ def sender(client, accumulator, metrics, mocker):
3048 ((0 , 9 ), 1 ),
3149 ((0 , 8 , 0 ), 0 )
3250])
def test_produce_request(sender, api_version, produce_version):
    """_produce_request builds a request of the broker-appropriate version."""
    sender._client._api_versions = BROKER_API_VERSIONS[api_version]
    magic = KafkaProducer.max_usable_produce_magic(api_version)
    request = sender._produce_request(0, 0, 0, [producer_batch(magic=magic)])
    assert isinstance(request, ProduceRequest[produce_version])
57+
58+
@pytest.mark.parametrize(("api_version", "produce_version"), [
    ((2, 1), 7),
])
def test_create_produce_requests(sender, api_version, produce_version):
    """_create_produce_requests builds one ProduceRequest per node, each of
    the version appropriate for the broker's api_version.

    Fix: removed the local ``tp = TopicPartition('foo', 0)`` from the
    original, which was assigned but never used.
    """
    sender._client._api_versions = BROKER_API_VERSIONS[api_version]
    magic = KafkaProducer.max_usable_produce_magic(api_version)
    batches_by_node = collections.defaultdict(list)
    # Five batches for each of three nodes.
    for node in range(3):
        for _ in range(5):
            batches_by_node[node].append(producer_batch(magic=magic))
    produce_requests_by_node = sender._create_produce_requests(batches_by_node)
    assert len(produce_requests_by_node) == 3
    for node in range(3):
        assert isinstance(produce_requests_by_node[node], ProduceRequest[produce_version])
74+
75+
def test_complete_batch_success(sender):
    """A clean completion resolves the produce future with
    (base_offset, timestamp_ms, log_start_offset)."""
    batch = producer_batch()
    assert not batch.produce_future.is_done

    # error=None signals success; base offset is 0.
    sender._complete_batch(batch, None, 0, timestamp_ms=123, log_start_offset=456)

    assert batch.is_done
    assert batch.produce_future.is_done
    assert batch.produce_future.succeeded()
    assert batch.produce_future.value == (0, 123, 456)
86+
87+
def test_complete_batch_transaction(sender):
    """Completing a batch advances the transactional sequence number by the
    batch's record count."""
    sender._transaction_state = TransactionState()
    state = sender._transaction_state
    batch = producer_batch()
    tp = batch.topic_partition
    assert state.sequence_number(tp) == 0
    assert state.producer_id_and_epoch.producer_id == batch.producer_id

    # error=None signals success; base offset is 0.
    sender._complete_batch(batch, None, 0)
    assert batch.is_done
    assert state.sequence_number(tp) == batch.record_count
98+
99+
@pytest.mark.parametrize(("error", "refresh_metadata"), [
    (Errors.KafkaConnectionError, True),
    (Errors.CorruptRecordError, False),
    (Errors.UnknownTopicOrPartitionError, True),
    (Errors.NotLeaderForPartitionError, True),
    (Errors.MessageSizeTooLargeError, False),
    (Errors.InvalidTopicError, False),
    (Errors.RecordListTooLargeError, False),
    (Errors.NotEnoughReplicasError, False),
    (Errors.NotEnoughReplicasAfterAppendError, False),
    (Errors.InvalidRequiredAcksError, False),
    (Errors.TopicAuthorizationFailedError, False),
    (Errors.UnsupportedForMessageFormatError, False),
    (Errors.InvalidProducerEpochError, False),
    (Errors.ClusterAuthorizationFailedError, False),
    (Errors.TransactionalIdAuthorizationFailedError, False),
])
def test_complete_batch_error(sender, error, refresh_metadata):
    """Batch errors fail the produce future; only leadership/connectivity
    errors additionally force a metadata refresh (ttl drops to 0)."""
    cluster = sender._client.cluster
    # Make metadata look fresh so ttl() starts out positive.
    cluster._last_successful_refresh_ms = (time.time() - 10) * 1000
    cluster._need_update = False
    assert cluster.ttl() > 0

    batch = producer_batch()
    sender._complete_batch(batch, error, -1)

    if refresh_metadata:
        assert cluster.ttl() == 0
    else:
        assert cluster.ttl() > 0
    assert batch.is_done
    assert batch.produce_future.failed()
    assert isinstance(batch.produce_future.exception, error)
130+
131+
@pytest.mark.parametrize(("error", "retry"), [
    (Errors.KafkaConnectionError, True),
    (Errors.CorruptRecordError, False),
    (Errors.UnknownTopicOrPartitionError, True),
    (Errors.NotLeaderForPartitionError, True),
    (Errors.MessageSizeTooLargeError, False),
    (Errors.InvalidTopicError, False),
    (Errors.RecordListTooLargeError, False),
    (Errors.NotEnoughReplicasError, True),
    (Errors.NotEnoughReplicasAfterAppendError, True),
    (Errors.InvalidRequiredAcksError, False),
    (Errors.TopicAuthorizationFailedError, False),
    (Errors.UnsupportedForMessageFormatError, False),
    (Errors.InvalidProducerEpochError, False),
    (Errors.ClusterAuthorizationFailedError, False),
    (Errors.TransactionalIdAuthorizationFailedError, False),
])
def test_complete_batch_retry(sender, accumulator, mocker, error, retry):
    """Retriable errors re-enqueue the batch until retries are exhausted;
    non-retriable errors fail the batch immediately."""
    sender.config['retries'] = 1
    mocker.spy(sender, '_fail_batch')
    mocker.patch.object(accumulator, 'reenqueue')
    batch = producer_batch()

    sender._complete_batch(batch, error, -1)

    if not retry:
        assert batch.is_done
        assert isinstance(batch.produce_future.exception, error)
        return

    # First failure of a retriable error is re-enqueued, not failed.
    assert not batch.is_done
    accumulator.reenqueue.assert_called_with(batch)
    # reenqueue normally bumps attempts, but it is mocked here.
    batch.attempts += 1
    sender._complete_batch(batch, error, -1)
    assert batch.is_done
    assert isinstance(batch.produce_future.exception, error)
165+
166+
def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker):
    """A batch whose producer id no longer matches the transaction state is
    failed instead of retried, even for an otherwise retriable error."""
    sender._transaction_state = TransactionState()
    sender.config['retries'] = 1
    mocker.spy(sender, '_fail_batch')
    mocker.patch.object(accumulator, 'reenqueue')
    retriable = Errors.NotLeaderForPartitionError
    batch = producer_batch()

    sender._complete_batch(batch, retriable, -1)
    assert not batch.is_done
    accumulator.reenqueue.assert_called_with(batch)

    batch.records._producer_id = 123  # simulate a different producer_id
    assert batch.producer_id != sender._transaction_state.producer_id_and_epoch.producer_id
    sender._complete_batch(batch, retriable, -1)
    assert batch.is_done
    assert isinstance(batch.produce_future.exception, retriable)
182+
183+
def test_fail_batch(sender, accumulator, mocker):
    """_fail_batch resets the transactional producer id and fails the batch
    with the original offsets and exception."""
    sender._transaction_state = TransactionState()
    mocker.patch.object(TransactionState, 'reset_producer_id')
    batch = producer_batch()
    mocker.patch.object(batch, 'done')
    assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id

    exc = Exception('error')
    sender._fail_batch(batch, base_offset=0, timestamp_ms=None,
                       exception=exc, log_start_offset=None)

    sender._transaction_state.reset_producer_id.assert_called_once()
    batch.done.assert_called_with(base_offset=0, timestamp_ms=None,
                                  exception=exc, log_start_offset=None)
194+
195+
def test_handle_produce_response():
    """TODO: exercise Sender._handle_produce_response (placeholder)."""
198+
199+
def test_failed_produce(sender, mocker):
    """_failed_produce completes every batch with the error and offset -1."""
    mocker.patch.object(sender, '_complete_batch')
    batches = ['foo', 'bar', 'fizzbuzz']
    sender._failed_produce(batches, 0, 'error')
    expected = [call(b, 'error', -1) for b in batches]
    sender._complete_batch.assert_has_calls(expected)
209+
210+
def test_maybe_wait_for_producer_id():
    """TODO: exercise Sender._maybe_wait_for_producer_id (placeholder)."""
213+
214+
def test_run_once():
    """TODO: exercise Sender.run_once (placeholder)."""
0 commit comments