Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions kafka/net/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@ def sasl_enabled(self):

async def _sasl_authenticate(self):
# Step 1: SaslHandshake to negotiate mechanism
version = self.broker_version_data.api_version(SaslHandshakeRequest, max_version=1)
request = SaslHandshakeRequest[version](self.config['sasl_mechanism'])
request = SaslHandshakeRequest(
mechanism=self.config['sasl_mechanism'],
max_version=1)
try:
response = await self._send_request(request)
except Exception as exc:
Expand All @@ -415,6 +416,7 @@ async def _sasl_authenticate(self):
return

# Step 2: SASL authentication exchange
version = response.API_VERSION
try:
mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
host=self.transport.getPeer()[0], **self.config)
Expand All @@ -425,7 +427,7 @@ async def _sasl_authenticate(self):
while not mechanism.is_done():
token = mechanism.auth_bytes()
if version == 1:
auth_request = SaslAuthenticateRequest[0](token)
auth_request = SaslAuthenticateRequest(token, version=0)
else:
auth_request = SaslBytesRequest(token)

Expand Down
46 changes: 22 additions & 24 deletions kafka/producer/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,34 +668,32 @@ def _produce_request(self, node_id, acks, timeout, batches):
Returns:
ProduceRequest (version depends on client api_versions)
"""
produce_records_by_partition = collections.defaultdict(dict)
max_version = 9
min_version = 0
Topic = ProduceRequest.TopicProduceData
Partition = Topic.PartitionProduceData
topic_data = collections.defaultdict(list)
for batch in batches:
topic = batch.topic_partition.topic
partition = batch.topic_partition.partition

buf = batch.records.buffer()
produce_records_by_partition[topic][partition] = buf
partition = Partition(
index=batch.topic_partition.partition,
records=batch.records.buffer(),
)
topic_data[topic].append(partition)

version = self._client.api_version(ProduceRequest, max_version=8)
topic_partition_data = [
(topic, list(partition_info.items()))
for topic, partition_info in produce_records_by_partition.items()]
transactional_id = self._transaction_manager.transactional_id if self._transaction_manager else None
if version >= 3:
return ProduceRequest[version](
transactional_id=transactional_id,
acks=acks,
timeout_ms=timeout,
topic_data=topic_partition_data,
)
else:
if transactional_id is not None:
log.warning('%s: Broker does not support ProduceRequest v3+, required for transactional_id', str(self))
return ProduceRequest[version](
acks=acks,
timeout_ms=timeout,
topic_data=topic_partition_data,
)
if transactional_id is not None:
min_version = 3

return ProduceRequest(
transactional_id=transactional_id,
acks=acks,
timeout_ms=timeout,
topic_data=[Topic(name=topic, partition_data=partitions)
for topic, partitions in topic_data.items()],
min_version=min_version,
max_version=max_version,
)

def wakeup(self):
"""Wake up the selector associated with this send thread."""
Expand Down
85 changes: 68 additions & 17 deletions test/producer/test_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from kafka.cluster import ClusterMetadata
import kafka.errors as Errors
from kafka.protocol.broker_version_data import BrokerVersionData
from kafka.producer.kafka import KafkaProducer
from kafka.protocol.producer import ProduceRequest
from kafka.producer.future import FutureRecordMetadata
Expand Down Expand Up @@ -95,37 +94,89 @@ def transaction_manager(cluster):
metadata=cluster)


@pytest.mark.parametrize(("api_version", "produce_version"), [
def _capture(captured):
"""Build a respond_fn that records the negotiated wire api_version."""
def fn(api_key, api_version, correlation_id, request_bytes):
captured['api_version'] = api_version
# acks=1 expects a response; an empty topic list is valid at every version.
return ProduceResponse(throttle_time_ms=0, responses=[])
return fn


@pytest.mark.parametrize("broker, produce_version", [
((2, 1), 7),
((0, 10, 0), 2),
((0, 9), 1),
((0, 8, 0), 0)
])
def test_produce_request(sender, api_version, produce_version):
sender._client._manager.broker_version_data = BrokerVersionData(api_version)
magic = KafkaProducer.max_usable_produce_magic(api_version)
((0, 8, 0), 0),
], indirect=['broker'])
def test_produce_request_negotiates_wire_version(sender, broker, manager, produce_version):
"""``Sender._produce_request`` returns a ProduceRequest with no fixed
version; the connection negotiates the wire version against the broker's
api_versions table at send time. We verify by capturing the api_version
that arrives at the broker."""
# Bootstrap so cluster metadata knows about the MockBroker node.
manager.bootstrap(timeout_ms=5000)

magic = KafkaProducer.max_usable_produce_magic(broker.broker_version)
batch = producer_batch(magic=magic)
produce_request = sender._produce_request(0, 0, 0, [batch])
produce_request = sender._produce_request(0, 1, 0, [batch]) # acks=1
assert isinstance(produce_request, ProduceRequest)
assert produce_request.version == produce_version
# Version is not pinned at construction — that's the whole point.
assert produce_request.API_VERSION is None

captured = {}
broker.respond_fn(ProduceResponse, _capture(captured))

future = manager.send(produce_request, node_id=0)
manager.run(manager.wait_for, future, 5000)

@pytest.mark.parametrize(("api_version", "produce_version"), [
assert captured['api_version'] == produce_version


@pytest.mark.parametrize("broker, produce_version", [
((2, 1), 7),
])
def test_create_produce_requests(sender, api_version, produce_version):
sender._client._manager.broker_version_data = BrokerVersionData(api_version)
tp = TopicPartition('foo', 0)
magic = KafkaProducer.max_usable_produce_magic(api_version)
((0, 10, 0), 2),
((0, 9), 1),
((0, 8, 0), 0),
], indirect=['broker'])
def test_create_produce_requests_negotiates_wire_version(
sender, broker, manager, produce_version):
"""``_create_produce_requests`` builds one ProduceRequest per node;
each one negotiates independently against its broker's api_versions
table. We send each through the MockBroker (all routed to the single
MockBroker node via shared metadata) and assert each arrived at the
expected wire version."""
# Advertise three broker entries (all pointing at this single MockBroker)
# so ``manager.send(..., node_id=n)`` resolves for nodes 1 and 2 as well.
# Must happen *before* bootstrap so the metadata response carries them.
from kafka.protocol.metadata import MetadataResponse
Broker = MetadataResponse.MetadataResponseBroker
broker.set_metadata(brokers=[
Broker(node_id=n, host=broker.host, port=broker.port, rack=None)
for n in range(3)
])
manager.bootstrap(timeout_ms=5000)

magic = KafkaProducer.max_usable_produce_magic(broker.broker_version)
batches_by_node = collections.defaultdict(list)
for node in range(3):
for _ in range(5):
batches_by_node[node].append(producer_batch(magic=magic))
produce_requests_by_node = sender._create_produce_requests(batches_by_node)
assert len(produce_requests_by_node) == 3

for node in range(3):
assert isinstance(produce_requests_by_node[node], ProduceRequest)
assert produce_requests_by_node[node].version == produce_version
request = produce_requests_by_node[node]
assert isinstance(request, ProduceRequest)
assert request.API_VERSION is None

captured = {}
broker.respond_fn(ProduceResponse, _capture(captured))
future = manager.send(request, node_id=node)
manager.run(manager.wait_for, future, 5000)
assert captured['api_version'] == produce_version, (
'node %d: expected v%d got v%s'
% (node, produce_version, captured.get('api_version')))


def test_complete_batch_success(sender):
Expand Down
6 changes: 2 additions & 4 deletions test/test_mock_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def test_send_and_receive(self):
client.await_ready(node_id, timeout_ms=5000)

# Send a MetadataRequest directly via client.send
version = client.api_version(MetadataRequest, max_version=8)
request = MetadataRequest[version]()
request = MetadataRequest(max_version=9)
future = client.send(node_id, request)
_poll_for_future(client, future)

Expand All @@ -251,8 +250,7 @@ def test_fail_next_aborts_request(self):
error = Errors.KafkaConnectionError('simulated')
broker.fail_next(MetadataRequest, error=error)

version = client.api_version(MetadataRequest, max_version=8)
future = client.send(node_id, MetadataRequest[version]())
future = client.send(node_id, MetadataRequest(max_version=9))
_poll_for_future(client, future)

assert future.failed()
Expand Down
Loading