Source code for faust.transport.drivers.aiokafka

"""Message transport using :pypi:`aiokafka`."""
import asyncio
import typing

from collections import deque
from typing import (
    Any,
    Awaitable,
    Callable,
    ClassVar,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Set,
    Tuple,
    Type,
    cast,
    no_type_check,
)

import aiokafka
import aiokafka.abc
import opentracing
from aiokafka.consumer.group_coordinator import OffsetCommitRequest
from aiokafka.errors import (
    CommitFailedError,
    ConsumerStoppedError,
    IllegalStateError,
    KafkaError,
)
from aiokafka.structs import (
    OffsetAndMetadata,
    TopicPartition as _TopicPartition,
)
from aiokafka.util import parse_kafka_version
from kafka.errors import (
    NotControllerError,
    TopicAlreadyExistsError as TopicExistsError,
    for_code,
)
from kafka.partitioner.default import DefaultPartitioner
from kafka.partitioner.hashed import murmur2
from kafka.protocol.metadata import MetadataRequest_v1
from mode import Service, get_logger
from mode.utils.futures import StampedeWrapper
from mode.utils.objects import cached_property
from mode.utils.times import Seconds, want_seconds
from mode.utils.typing import Deque
from opentracing.ext import tags
from yarl import URL

from faust.auth import (
    GSSAPICredentials,
    SASLCredentials,
    SSLCredentials,
)
from faust.exceptions import (
    ConsumerNotStarted,
    ImproperlyConfigured,
    NotReady,
    ProducerSendError,
)
from faust.transport import base
from faust.transport.consumer import (
    ConsumerThread,
    RecordMap,
    ThreadDelegateConsumer,
    ensure_TPset,
)
from faust.types import ConsumerMessage, HeadersArg, RecordMetadata, TP
from faust.types.auth import CredentialsT
from faust.types.transports import (
    ConsumerT,
    PartitionerT,
    ProducerT,
)
from faust.utils.kafka.protocol.admin import CreateTopicsRequest
from faust.utils.tracing import (
    noop_span,
    set_current_span,
    traced_from_parent_span,
)

__all__ = ['Consumer', 'Producer', 'Transport']

if not hasattr(aiokafka, '__robinhood__'):  # pragma: no cover
    raise RuntimeError(
        'Please install robinhood-aiokafka, not aiokafka')

logger = get_logger(__name__)

DEFAULT_GENERATION_ID = OffsetCommitRequest.DEFAULT_GENERATION_ID


def server_list(urls: List[URL], default_port: int) -> List[str]:
    """Convert list of urls to list of servers accepted by :pypi:`aiokafka`."""
    default_host = '127.0.0.1'
    return [f'{u.host or default_host}:{u.port or default_port}' for u in urls]
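
# Illustrative conversion (broker URLs are hypothetical; 9092 is the default
# Kafka port passed in by the transport):
#
#     server_list([URL('kafka://broker1:9092'), URL('kafka://broker2')], 9092)
#     # -> ['broker1:9092', 'broker2:9092']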


class ConsumerRebalanceListener(aiokafka.abc.ConsumerRebalanceListener):
    # kafka's ridiculous class based callback interface makes this hacky.

    def __init__(self, thread: ConsumerThread) -> None:
        self._thread: ConsumerThread = thread

    def on_partitions_revoked(
            self, revoked: Iterable[_TopicPartition]) -> Awaitable:
        """Call when partitions are being revoked."""
        thread = self._thread
        # XXX Must call app.on_rebalance_start as early as possible.
        # We call it here, in the synchronous callback, so we know it runs
        # even if the returned coroutine is never awaited.
        thread.app.on_rebalance_start()

        # this way we should also get a warning if the coroutine
        # is never awaited.
        return thread.on_partitions_revoked(ensure_TPset(revoked))

    async def on_partitions_assigned(
            self, assigned: Iterable[_TopicPartition]) -> None:
        """Call when partitions are being assigned."""
        await self._thread.on_partitions_assigned(ensure_TPset(assigned))


class Consumer(ThreadDelegateConsumer):
    """Kafka consumer using :pypi:`aiokafka`."""

    logger = logger

    RebalanceListener: ClassVar[Type[ConsumerRebalanceListener]]
    RebalanceListener = ConsumerRebalanceListener

    consumer_stopped_errors: ClassVar[Tuple[Type[BaseException], ...]] = (
        ConsumerStoppedError,
    )

    def _new_consumer_thread(self) -> ConsumerThread:
        return AIOKafkaConsumerThread(self, loop=self.loop, beacon=self.beacon)

    async def create_topic(self,
                           topic: str,
                           partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 30.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare topic on server."""
        await self._thread.create_topic(
            topic,
            partitions,
            replication,
            config=config,
            timeout=timeout,
            retention=retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    def _new_topicpartition(self, topic: str, partition: int) -> TP:
        return cast(TP, _TopicPartition(topic, partition))

    def _to_message(self, tp: TP, record: Any) -> ConsumerMessage:
        timestamp: Optional[int] = record.timestamp
        timestamp_s: float = cast(float, None)
        if timestamp is not None:
            timestamp_s = timestamp / 1000.0
        return ConsumerMessage(
            record.topic,
            record.partition,
            record.offset,
            timestamp_s,
            record.timestamp_type,
            record.headers,
            record.key,
            record.value,
            record.checksum,
            record.serialized_key_size,
            record.serialized_value_size,
            tp,
        )

    async def on_stop(self) -> None:
        """Call when consumer is stopping."""
        await super().on_stop()
        transport = cast(Transport, self.transport)
        transport._topic_waiters.clear()
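

# ``_to_message`` above converts aiokafka's integer millisecond record
# timestamps into the float seconds that ConsumerMessage expects,
# e.g. (value is illustrative):
#
#     1594000000000 / 1000.0   # -> 1594000000.0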


class AIOKafkaConsumerThread(ConsumerThread):
    _consumer: Optional[aiokafka.AIOKafkaConsumer] = None

    _pending_rebalancing_spans: Deque[opentracing.Span]

    def __post_init__(self) -> None:
        self._partitioner: PartitionerT = (
            self.app.conf.producer_partitioner or DefaultPartitioner())
        self._rebalance_listener = self.consumer.RebalanceListener(self)
        self._pending_rebalancing_spans = deque()

    async def on_start(self) -> None:
        """Call when consumer starts."""
        self._consumer = self._create_consumer(loop=self.thread_loop)
        await self._consumer.start()

    async def on_thread_stop(self) -> None:
        """Call when consumer thread is stopping."""
        if self._consumer is not None:
            await self._consumer.stop()

    def _create_consumer(
            self,
            loop: asyncio.AbstractEventLoop) -> aiokafka.AIOKafkaConsumer:
        transport = cast(Transport, self.transport)
        if self.app.client_only:
            return self._create_client_consumer(transport, loop=loop)
        else:
            return self._create_worker_consumer(transport, loop=loop)

    def _create_worker_consumer(
            self,
            transport: 'Transport',
            loop: asyncio.AbstractEventLoop) -> aiokafka.AIOKafkaConsumer:
        isolation_level: str = 'read_uncommitted'
        conf = self.app.conf
        if self.consumer.in_transaction:
            isolation_level = 'read_committed'
        self._assignor = self.app.assignor
        auth_settings = credentials_to_aiokafka_auth(
            conf.broker_credentials, conf.ssl_context)
        max_poll_interval = conf.broker_max_poll_interval or 0
        request_timeout = conf.broker_request_timeout
        session_timeout = conf.broker_session_timeout
        if session_timeout > request_timeout:
            raise ImproperlyConfigured(
                f'Setting broker_session_timeout={session_timeout} '
                f'cannot be greater than '
                f'broker_request_timeout={request_timeout}')
        return aiokafka.AIOKafkaConsumer(
            loop=loop,
            client_id=conf.broker_client_id,
            group_id=conf.id,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            partition_assignment_strategy=[self._assignor],
            enable_auto_commit=False,
            auto_offset_reset=conf.consumer_auto_offset_reset,
            max_poll_records=conf.broker_max_poll_records,
            max_poll_interval_ms=int(max_poll_interval * 1000.0),
            max_partition_fetch_bytes=conf.consumer_max_fetch_size,
            fetch_max_wait_ms=1500,
            request_timeout_ms=int(request_timeout * 1000.0),
            check_crcs=conf.broker_check_crcs,
            session_timeout_ms=int(session_timeout * 1000.0),
            heartbeat_interval_ms=int(conf.broker_heartbeat_interval * 1000.0),
            isolation_level=isolation_level,
            traced_from_parent_span=self.traced_from_parent_span,
            start_rebalancing_span=self.start_rebalancing_span,
            start_coordinator_span=self.start_coordinator_span,
            on_generation_id_known=self.on_generation_id_known,
            flush_spans=self.flush_spans,
            **auth_settings,
        )

    def _create_client_consumer(
            self,
            transport: 'Transport',
            loop: asyncio.AbstractEventLoop) -> aiokafka.AIOKafkaConsumer:
        conf = self.app.conf
        auth_settings = credentials_to_aiokafka_auth(
            conf.broker_credentials, conf.ssl_context)
        max_poll_interval = conf.broker_max_poll_interval or 0
        return aiokafka.AIOKafkaConsumer(
            loop=loop,
            client_id=conf.broker_client_id,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            request_timeout_ms=int(conf.broker_request_timeout * 1000.0),
            enable_auto_commit=True,
            max_poll_records=conf.broker_max_poll_records,
            max_poll_interval_ms=int(max_poll_interval * 1000.0),
            auto_offset_reset=conf.consumer_auto_offset_reset,
            check_crcs=conf.broker_check_crcs,
            **auth_settings,
        )

    @cached_property
    def trace_category(self) -> str:
        return f'{self.app.conf.name}-_aiokafka'

    def start_rebalancing_span(self) -> opentracing.Span:
        return self._start_span('rebalancing', lazy=True)

    def start_coordinator_span(self) -> opentracing.Span:
        return self._start_span('coordinator')

    def _start_span(self, name: str, *,
                    lazy: bool = False) -> opentracing.Span:
        tracer = self.app.tracer
        if tracer is not None:
            span = tracer.get_tracer(self.trace_category).start_span(
                operation_name=name,
            )
            span.set_tag(tags.SAMPLING_PRIORITY, 1)
            self.app._span_add_default_tags(span)
            set_current_span(span)
            if lazy:
                self._transform_span_lazy(span)
            return span
        else:
            return noop_span()

    @no_type_check
    def _transform_span_lazy(self, span: opentracing.Span) -> None:
        # XXX slow
        consumer = self
        if typing.TYPE_CHECKING:
            # MyPy completely disallows the statements below
            # claiming it is an illegal dynamic baseclass.
            # We know mypy, but do it anyway :D
            pass
        else:
            cls = span.__class__

            class LazySpan(cls):

                def finish() -> None:
                    consumer._span_finish(span)

            span._real_finish, span.finish = span.finish, LazySpan.finish

    def _span_finish(self, span: opentracing.Span) -> None:
        assert self._consumer is not None
        if self._consumer._coordinator.generation == DEFAULT_GENERATION_ID:
            self._on_span_generation_pending(span)
        else:
            self._on_span_generation_known(span)

    def _on_span_generation_pending(self, span: opentracing.Span) -> None:
        self._pending_rebalancing_spans.append(span)

    def _on_span_generation_known(self, span: opentracing.Span) -> None:
        if self._consumer:
            coordinator = self._consumer._coordinator
            coordinator_id = coordinator.coordinator_id
            app_id = self.app.conf.id
            generation = coordinator.generation
            member_id = coordinator.member_id

            try:
                op_name = span.operation_name
                set_tag = span.set_tag
            except AttributeError:  # pragma: no cover
                pass  # not a real span
            else:
                trace_id_str = f'reb-{app_id}-{generation}'
                trace_id = murmur2(trace_id_str.encode())

                span.context.trace_id = trace_id
                if op_name.endswith('.REPLACE_WITH_MEMBER_ID'):
                    span.set_operation_name(f'rebalancing node {member_id}')
                set_tag('kafka_generation', generation)
                set_tag('kafka_member_id', member_id)
                set_tag('kafka_coordinator_id', coordinator_id)
                self.app._span_add_default_tags(span)
                span._real_finish()

    def _on_span_cancelled_early(self, span: opentracing.Span) -> None:
        op_name = span.operation_name
        span.set_operation_name(f'{op_name} (CANCELLED)')
        span._real_finish()

    def traced_from_parent_span(self,
                                parent_span: opentracing.Span,
                                lazy: bool = False,
                                **extra_context: Any) -> Callable:
        return traced_from_parent_span(
            parent_span,
            callback=self._transform_span_lazy if lazy else None,
            **extra_context)

    def flush_spans(self) -> None:
        while self._pending_rebalancing_spans:
            span = self._pending_rebalancing_spans.popleft()
            self._on_span_cancelled_early(span)

    def on_generation_id_known(self) -> None:
        while self._pending_rebalancing_spans:
            span = self._pending_rebalancing_spans.popleft()
            self._on_span_generation_known(span)

    def close(self) -> None:
        """Close consumer for graceful shutdown."""
        if self._consumer is not None:
            self._consumer.set_close()
            self._consumer._coordinator.set_close()

    async def subscribe(self, topics: Iterable[str]) -> None:
        """Reset subscription (requires rebalance)."""
        # XXX pattern does not work :/
        await self.call_thread(
            self._ensure_consumer().subscribe,
            topics=set(topics),
            listener=self._rebalance_listener,
        )

    async def seek_to_committed(self) -> Mapping[TP, int]:
        """Seek partitions to the last committed offset."""
        return await self.call_thread(
            self._ensure_consumer().seek_to_committed)

    async def commit(self, offsets: Mapping[TP, int]) -> bool:
        """Commit topic offsets."""
        return await self.call_thread(self._commit, offsets)

    async def _commit(self, offsets: Mapping[TP, int]) -> bool:
        consumer = self._ensure_consumer()
        try:
            aiokafka_offsets = {
                tp: OffsetAndMetadata(offset, '')
                for tp, offset in offsets.items()
            }
            await consumer.commit(aiokafka_offsets)
        except CommitFailedError as exc:
            if 'already rebalanced' in str(exc):
                return False
            self.log.exception('Committing raised exception: %r', exc)
            await self.crash(exc)
            return False
        except IllegalStateError as exc:
            self.log.exception(
                'Got exception: %r\nCurrent assignment: %r',
                exc, self.assignment())
            await self.crash(exc)
            return False
        return True

    async def position(self, tp: TP) -> Optional[int]:
        """Return the current position for topic partition."""
        return await self.call_thread(
            self._ensure_consumer().position, tp)

    async def seek_to_beginning(self, *partitions: _TopicPartition) -> None:
        """Seek list of offsets to the first available offset."""
        await self.call_thread(
            self._ensure_consumer().seek_to_beginning, *partitions)

    async def seek_wait(self, partitions: Mapping[TP, int]) -> None:
        """Seek partitions to specific offset and wait for operation."""
        consumer = self._ensure_consumer()
        await self.call_thread(self._seek_wait, consumer, partitions)

    async def _seek_wait(self,
                         consumer: Consumer,
                         partitions: Mapping[TP, int]) -> None:
        for tp, offset in partitions.items():
            self.log.dev('SEEK %r -> %r', tp, offset)
            consumer.seek(tp, offset)
            if offset > 0:
                self.consumer._read_offset[tp] = offset
        await asyncio.gather(*[
            consumer.position(tp) for tp in partitions
        ])

    def seek(self, partition: TP, offset: int) -> None:
        """Seek partition to specific offset."""
        self._ensure_consumer().seek(partition, offset)

    def assignment(self) -> Set[TP]:
        """Return the current assignment."""
        return ensure_TPset(self._ensure_consumer().assignment())

    def highwater(self, tp: TP) -> int:
        """Return the last offset in a specific partition."""
        if self.consumer.in_transaction:
            return self._ensure_consumer().last_stable_offset(tp)
        else:
            return self._ensure_consumer().highwater(tp)

    def topic_partitions(self, topic: str) -> Optional[int]:
        """Return the number of partitions configured for topic by name."""
        if self._consumer is not None:
            return self._consumer._coordinator._metadata_snapshot.get(topic)
        return None

    async def earliest_offsets(self,
                               *partitions: TP) -> Mapping[TP, int]:
        """Return the earliest offsets for a list of partitions."""
        return await self.call_thread(
            self._ensure_consumer().beginning_offsets, partitions)

    async def highwaters(self, *partitions: TP) -> Mapping[TP, int]:
        """Return the last offsets for a list of partitions."""
        return await self.call_thread(self._highwaters, partitions)

    async def _highwaters(self, partitions: List[TP]) -> Mapping[TP, int]:
        consumer = self._ensure_consumer()
        if self.consumer.in_transaction:
            return {
                tp: consumer.last_stable_offset(tp)
                for tp in partitions
            }
        else:
            return cast(Mapping[TP, int],
                        await consumer.end_offsets(partitions))

    def _ensure_consumer(self) -> aiokafka.AIOKafkaConsumer:
        if self._consumer is None:
            raise ConsumerNotStarted('Consumer thread not yet started')
        return self._consumer

    async def getmany(self,
                      active_partitions: Optional[Set[TP]],
                      timeout: float) -> RecordMap:
        """Fetch batch of messages from server."""
        # Implementation for the Fetcher service.
        _consumer = self._ensure_consumer()
        # NOTE: Since we are enqueing the fetch request,
        # we need to check when dequeued that we are not in a rebalancing
        # state at that point to return early, or we
        # will create a deadlock (fetch request starts after flow stopped)
        return await self.call_thread(
            self._fetch_records,
            _consumer,
            active_partitions,
            timeout=timeout,
            max_records=_consumer._max_poll_records,
        )

    async def _fetch_records(self,
                             consumer: aiokafka.AIOKafkaConsumer,
                             active_partitions: Set[TP],
                             timeout: float = None,
                             max_records: int = None) -> RecordMap:
        if not self.consumer.flow_active:
            return {}
        fetcher = consumer._fetcher
        if consumer._closed or fetcher._closed:
            raise ConsumerStoppedError()
        with fetcher._subscriptions.fetch_context():
            return await fetcher.fetched_records(
                active_partitions,
                timeout=timeout,
                max_records=max_records,
            )

    async def create_topic(self,
                           topic: str,
                           partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 30.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare topic on server."""
        transport = cast(Transport, self.consumer.transport)
        _consumer = self._ensure_consumer()
        _retention = (int(want_seconds(retention) * 1000.0)
                      if retention else None)
        await self.call_thread(
            transport._create_topic,
            self,
            _consumer._client,
            topic,
            partitions,
            replication,
            config=config,
            timeout=int(want_seconds(timeout) * 1000.0),
            retention=_retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    def key_partition(self,
                      topic: str,
                      key: Optional[bytes],
                      partition: int = None) -> Optional[int]:
        """Hash key to determine partition destination."""
        consumer = self._ensure_consumer()
        metadata = consumer._client.cluster
        partitions_for_topic = metadata.partitions_for_topic(topic)
        if partitions_for_topic is None:
            return None
        if partition is not None:
            assert partition >= 0
            assert partition in partitions_for_topic, \
                'Unrecognized partition'
            return partition
        all_partitions = list(partitions_for_topic)
        available = list(metadata.available_partitions_for_topic(topic))
        return self._partitioner(key, all_partitions, available)
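

# A minimal sketch of the key hashing performed by ``key_partition`` above,
# using kafka-python's DefaultPartitioner (key and partition ids are
# hypothetical):
#
#     partitioner = DefaultPartitioner()
#     partitioner(b'order-4711', [0, 1, 2], [0, 1, 2])
#     # -> murmur2-hashes the key onto one of the three partitions;
#     #    the same key always maps to the same partition.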


class Producer(base.Producer):
    """Kafka producer using :pypi:`aiokafka`."""

    logger = logger

    allow_headers: bool = True
    _producer: Optional[aiokafka.AIOKafkaProducer] = None

    def __post_init__(self) -> None:
        self._send_on_produce_message = self.app.on_produce_message.send
        if self.partitioner is None:
            self.partitioner = DefaultPartitioner()
        if self._api_version != 'auto':
            wanted_api_version = parse_kafka_version(self._api_version)
            if wanted_api_version < (0, 11):
                self.allow_headers = False

    def _settings_default(self) -> Mapping[str, Any]:
        transport = cast(Transport, self.transport)
        return {
            'bootstrap_servers': server_list(
                transport.url, transport.default_port),
            'client_id': self.client_id,
            'acks': self.acks,
            'linger_ms': self.linger_ms,
            'max_batch_size': self.max_batch_size,
            'max_request_size': self.max_request_size,
            'compression_type': self.compression_type,
            'on_irrecoverable_error': self._on_irrecoverable_error,
            'security_protocol': 'SSL' if self.ssl_context else 'PLAINTEXT',
            'partitioner': self.partitioner,
            'request_timeout_ms': int(self.request_timeout * 1000),
            'api_version': self._api_version,
        }

    def _settings_auth(self) -> Mapping[str, Any]:
        return credentials_to_aiokafka_auth(
            self.credentials, self.ssl_context)

    async def begin_transaction(self, transactional_id: str) -> None:
        """Begin transaction by id."""
        await self._ensure_producer().begin_transaction(transactional_id)

    async def commit_transaction(self, transactional_id: str) -> None:
        """Commit transaction by id."""
        await self._ensure_producer().commit_transaction(transactional_id)

    async def abort_transaction(self, transactional_id: str) -> None:
        """Abort and rollback transaction by id."""
        await self._ensure_producer().abort_transaction(transactional_id)

    async def stop_transaction(self, transactional_id: str) -> None:
        """Stop transaction by id."""
        await self._ensure_producer().stop_transaction(transactional_id)

    async def maybe_begin_transaction(self, transactional_id: str) -> None:
        """Begin transaction (if one does not already exist)."""
        await self._ensure_producer().maybe_begin_transaction(transactional_id)

    async def commit_transactions(
            self,
            tid_to_offset_map: Mapping[str, Mapping[TP, int]],
            group_id: str,
            start_new_transaction: bool = True) -> None:
        """Commit transactions."""
        await self._ensure_producer().commit(
            tid_to_offset_map, group_id,
            start_new_transaction=start_new_transaction,
        )

    def _settings_extra(self) -> Mapping[str, Any]:
        if self.app.in_transaction:
            return {'acks': 'all'}
        return {}

    def _new_producer(self) -> aiokafka.AIOKafkaProducer:
        return self._producer_type(
            loop=self.loop,
            **{**self._settings_default(),
               **self._settings_auth(),
               **self._settings_extra()},
        )

    @property
    def _producer_type(self) -> Type[aiokafka.BaseProducer]:
        if self.app.in_transaction:
            return aiokafka.MultiTXNProducer
        return aiokafka.AIOKafkaProducer

    async def _on_irrecoverable_error(self, exc: BaseException) -> None:
        consumer = self.transport.app.consumer
        if consumer is not None:  # pragma: no cover
            # coverage executes this line, but does not mark as covered.
            await consumer.crash(exc)
        await self.crash(exc)

    async def create_topic(self,
                           topic: str,
                           partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 20.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare topic on server."""
        _retention = (int(want_seconds(retention) * 1000.0)
                      if retention else None)
        producer = self._ensure_producer()
        await cast(Transport, self.transport)._create_topic(
            self,
            producer.client,
            topic,
            partitions,
            replication,
            config=config,
            timeout=int(want_seconds(timeout) * 1000.0),
            retention=_retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    def _ensure_producer(self) -> aiokafka.BaseProducer:
        if self._producer is None:
            raise NotReady('Producer service not yet started')
        return self._producer

    async def on_start(self) -> None:
        """Call when producer starts."""
        await super().on_start()
        producer = self._producer = self._new_producer()
        self.beacon.add(producer)
        await producer.start()

    async def on_stop(self) -> None:
        """Call when producer stops."""
        await super().on_stop()
        cast(Transport, self.transport)._topic_waiters.clear()
        producer, self._producer = self._producer, None
        if producer is not None:
            await producer.stop()

    async def send(self, topic: str, key: Optional[bytes],
                   value: Optional[bytes],
                   partition: Optional[int],
                   timestamp: Optional[float],
                   headers: Optional[HeadersArg],
                   *,
                   transactional_id: str = None) -> Awaitable[RecordMetadata]:
        """Schedule message to be transmitted by producer."""
        producer = self._ensure_producer()
        if headers is not None:
            if isinstance(headers, Mapping):
                headers = list(headers.items())
        self._send_on_produce_message(
            key=key, value=value,
            partition=partition,
            timestamp=timestamp,
            headers=headers,
        )
        if headers is not None and not self.allow_headers:
            headers = None
        timestamp_ms = int(timestamp * 1000.0) if timestamp else timestamp
        try:
            return cast(Awaitable[RecordMetadata], await producer.send(
                topic, value,
                key=key,
                partition=partition,
                timestamp_ms=timestamp_ms,
                headers=headers,
                transactional_id=transactional_id,
            ))
        except KafkaError as exc:
            raise ProducerSendError(f'Error while sending: {exc!r}') from exc

    async def send_and_wait(self, topic: str, key: Optional[bytes],
                            value: Optional[bytes],
                            partition: Optional[int],
                            timestamp: Optional[float],
                            headers: Optional[HeadersArg],
                            *,
                            transactional_id: str = None) -> RecordMetadata:
        """Send message and wait for it to be transmitted."""
        fut = await self.send(
            topic,
            key=key,
            value=value,
            partition=partition,
            timestamp=timestamp,
            headers=headers,
            transactional_id=transactional_id,
        )
        return await fut

    async def flush(self) -> None:
        """Wait for producer to finish transmitting all buffered messages."""
        await self.buffer.flush()
        if self._producer is not None:
            await self._producer.flush()

    def key_partition(self, topic: str, key: bytes) -> TP:
        """Hash key to determine partition destination."""
        producer = self._ensure_producer()
        partition = producer._partition(
            topic,
            partition=None,
            key=None,
            value=None,
            serialized_key=key,
            serialized_value=None,
        )
        return TP(topic, partition)

    def supports_headers(self) -> bool:
        """Return :const:`True` if message headers are supported."""
        producer = self._ensure_producer()
        client = producer.client
        if client is None:
            raise NotReady('Producer client not yet connected')
        return client.api_version >= (0, 11)
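

# ``send`` above normalizes a header mapping into the list-of-tuples form
# passed on to aiokafka, and converts float-second timestamps to integer
# milliseconds (values are illustrative):
#
#     list({'trace-id': b'abc123'}.items())   # -> [('trace-id', b'abc123')]
#     int(1594000000.5 * 1000.0)              # -> 1594000000500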


class Transport(base.Transport):
    """Kafka transport using :pypi:`aiokafka`."""

    Consumer: ClassVar[Type[ConsumerT]]
    Consumer = Consumer
    Producer: ClassVar[Type[ProducerT]]
    Producer = Producer

    default_port = 9092
    driver_version = f'aiokafka={aiokafka.__version__}'

    _topic_waiters: MutableMapping[str, StampedeWrapper]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._topic_waiters = {}

    def _topic_config(self,
                      retention: int = None,
                      compacting: bool = None,
                      deleting: bool = None) -> MutableMapping[str, Any]:
        config: MutableMapping[str, Any] = {}
        cleanup_flags: Set[str] = set()
        if compacting:
            cleanup_flags |= {'compact'}
        if deleting:
            cleanup_flags |= {'delete'}
        if cleanup_flags:
            config['cleanup.policy'] = ','.join(sorted(cleanup_flags))
        if retention:
            config['retention.ms'] = retention
        return config

    async def _create_topic(self,
                            owner: Service,
                            client: aiokafka.AIOKafkaClient,
                            topic: str,
                            partitions: int,
                            replication: int,
                            **kwargs: Any) -> None:
        assert topic is not None
        try:
            wrap = self._topic_waiters[topic]
        except KeyError:
            wrap = self._topic_waiters[topic] = StampedeWrapper(
                self._really_create_topic,
                owner,
                client,
                topic,
                partitions,
                replication,
                loop=asyncio.get_event_loop(), **kwargs)
        try:
            await wrap()
        except Exception:
            self._topic_waiters.pop(topic, None)
            raise

    async def _get_controller_node(
            self,
            owner: Service,
            client: aiokafka.AIOKafkaClient,
            timeout: int = 30000) -> Optional[int]:  # pragma: no cover
        nodes = [broker.nodeId for broker in client.cluster.brokers()]
        for node_id in nodes:
            if node_id is None:
                raise NotReady('Not connected to Kafka Broker')
            request = MetadataRequest_v1([])
            wait_result = await owner.wait(
                client.send(node_id, request),
                timeout=timeout,
            )
            if wait_result.stopped:
                owner.log.info('Shutting down - skipping creation.')
                return None
            response = wait_result.result
            return response.controller_id
        raise Exception('Controller node not found')

    async def _really_create_topic(
            self,
            owner: Service,
            client: aiokafka.AIOKafkaClient,
            topic: str,
            partitions: int,
            replication: int,
            *,
            config: Mapping[str, Any] = None,
            timeout: int = 30000,
            retention: int = None,
            compacting: bool = None,
            deleting: bool = None,
            ensure_created: bool = False) -> None:  # pragma: no cover
        owner.log.info('Creating topic %r', topic)

        if topic in client.cluster.topics():
            owner.log.debug('Topic %r exists, skipping creation.', topic)
            return

        protocol_version = 1
        extra_configs = config or {}
        config = self._topic_config(retention, compacting, deleting)
        config.update(extra_configs)

        controller_node = await self._get_controller_node(owner, client,
                                                          timeout=timeout)
        owner.log.debug('Found controller: %r', controller_node)

        if controller_node is None:
            if owner.should_stop:
                owner.log.info('Shutting down hence controller not found')
                return
            else:
                raise Exception('Controller node is None')

        request = CreateTopicsRequest[protocol_version](
            [(topic, partitions, replication, [], list(config.items()))],
            timeout,
            False,
        )
        wait_result = await owner.wait(
            client.send(controller_node, request),
            timeout=timeout,
        )
        if wait_result.stopped:
            owner.log.debug('Shutting down - skipping creation.')
            return
        response = wait_result.result

        assert len(response.topic_error_codes), 'single topic'

        _, code, reason = response.topic_error_codes[0]

        if code != 0:
            if not ensure_created and code == TopicExistsError.errno:
                owner.log.debug(
                    'Topic %r exists, skipping creation.', topic)
                return
            elif code == NotControllerError.errno:
                raise RuntimeError(f'Invalid controller: {controller_node}')
            else:
                raise for_code(code)(
                    f'Cannot create topic: {topic} ({code}): {reason}')
        else:
            owner.log.info('Topic %r created.', topic)
            return
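

# For illustration, the topic configuration mapping produced by
# ``Transport._topic_config`` above (argument values are hypothetical):
#
#     transport._topic_config(retention=86400000,
#                             compacting=True, deleting=True)
#     # -> {'cleanup.policy': 'compact,delete', 'retention.ms': 86400000}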


def credentials_to_aiokafka_auth(credentials: CredentialsT = None,
                                 ssl_context: Any = None) -> Mapping:
    if credentials is not None:
        if isinstance(credentials, SSLCredentials):
            return {
                'security_protocol': credentials.protocol.value,
                'ssl_context': credentials.context,
            }
        elif isinstance(credentials, SASLCredentials):
            return {
                'security_protocol': credentials.protocol.value,
                'sasl_mechanism': credentials.mechanism.value,
                'sasl_plain_username': credentials.username,
                'sasl_plain_password': credentials.password,
                'ssl_context': credentials.ssl_context,
            }
        elif isinstance(credentials, GSSAPICredentials):
            return {
                'security_protocol': credentials.protocol.value,
                'sasl_mechanism': credentials.mechanism.value,
                'sasl_kerberos_service_name':
                    credentials.kerberos_service_name,
                'sasl_kerberos_domain_name':
                    credentials.kerberos_domain_name,
                'ssl_context': credentials.ssl_context,
            }
        else:
            raise ImproperlyConfigured(
                f'aiokafka does not support {credentials}')
    elif ssl_context is not None:
        return {
            'security_protocol': 'SSL',
            'ssl_context': ssl_context,
        }
    else:
        return {'security_protocol': 'PLAINTEXT'}
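

# Illustrative mappings (credential values are hypothetical; SASLCredentials
# is assumed to use its default PLAIN mechanism over SASL_PLAINTEXT):
#
#     credentials_to_aiokafka_auth(
#         SASLCredentials(username='faust', password='secret'))
#     # -> {'security_protocol': 'SASL_PLAINTEXT',
#     #     'sasl_mechanism': 'PLAIN',
#     #     'sasl_plain_username': 'faust',
#     #     'sasl_plain_password': 'secret',
#     #     'ssl_context': None}
#
#     credentials_to_aiokafka_auth(None, ssl_context=None)
#     # -> {'security_protocol': 'PLAINTEXT'}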