Source code for faust.transport.drivers.aiokafka

"""Message transport using :pypi:`aiokafka`."""
import asyncio
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    ClassVar,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import aiokafka
import aiokafka.abc
from aiokafka.errors import (
    CommitFailedError,
    ConsumerStoppedError,
    IllegalStateError,
    KafkaError,
)
from aiokafka.structs import (
    ConsumerRecord,
    OffsetAndMetadata,
    TopicPartition as _TopicPartition,
)
from rhkafka.errors import (
    NotControllerError,
    TopicAlreadyExistsError as TopicExistsError,
    for_code,
)
from rhkafka.partitioner.default import DefaultPartitioner
from rhkafka.protocol.metadata import MetadataRequest_v1
from mode import Service, flight_recorder, get_logger
from mode.threads import MethodQueue, QueueServiceThread
from mode.utils.compat import OrderedDict
from mode.utils.futures import StampedeWrapper
from mode.utils.locks import Event
from mode.utils.times import Seconds, want_seconds
from yarl import URL

from faust.exceptions import ConsumerNotStarted, ProducerSendError
from faust.transport import base
from faust.types import AppT, ConsumerMessage, Message, RecordMetadata, TP
from faust.types.transports import ConsumerT, ProducerT
from faust.utils import terminal
from faust.utils.kafka.protocol.admin import CreateTopicsRequest

# Public names exported by ``from ... import *``.
__all__ = ['Consumer', 'Producer', 'Transport']

# This driver refuses to load against upstream aiokafka: the Robinhood fork
# is detected through the ``__robinhood__`` attribute on the package.
# NOTE(review): presumably required because the code below touches
# fork-specific internals (e.g. ``_fetcher.fetched_records``,
# ``_coordinator.set_close``) -- confirm which ones.
if not hasattr(aiokafka, '__robinhood__'):
    raise RuntimeError(
        'Please install robinhood-aiokafka, not aiokafka')

# Shape of a fetched batch as returned by ConsumerThread.getmany()
# (which calls the aiokafka fetcher's ``fetched_records``):
# a mapping of TP to the buffer-list of records fetched for it.
RecordMap = Mapping[TP, List[ConsumerRecord]]

# But we want to process records from topics in round-robin order.
# Consumer._records_to_topic_index() regroups a RecordMap by topic name;
# each value is a _TopicBuffer that round-robins over the record buffers
# of every partition of that topic.
# The next message available in a topic is then obtained with:
#   next(topic_index['topic_name'])
TopicIndexMap = MutableMapping[str, '_TopicBuffer']

# Either faust's TP or aiokafka's TopicPartition; _ensure_TP() normalizes
# both to faust's TP.
_TPTypes = Union[TP, _TopicPartition]

logger = get_logger(__name__)


def server_list(urls: List[URL], default_port: int) -> List[str]:
    """Render each URL as a ``host:port`` string for the Kafka client.

    A URL without a host falls back to ``127.0.0.1``, and one without a
    port falls back to *default_port*.  Order is preserved.
    """
    fallback_host = '127.0.0.1'
    servers: List[str] = []
    for url in urls:
        host = url.host or fallback_host
        port = url.port or default_port
        servers.append(f'{host}:{port}')
    return servers


def _ensure_TP(tp: _TPTypes) -> TP:
    """Return *tp* as a faust :class:`TP`.

    Instances that already are faust ``TP`` are returned unchanged; any
    other value (e.g. aiokafka's ``TopicPartition``) is rebuilt from its
    ``topic`` and ``partition`` attributes.
    """
    if isinstance(tp, TP):
        return tp
    return TP(tp.topic, tp.partition)


class _TopicBuffer(Iterator):
    _buffers: Dict[TP, Iterator[ConsumerRecord]]
    _it: Optional[Iterator[ConsumerRecord]]

    def __init__(self) -> None:
        # note: this is a regular dict, but ordered on Python 3.6
        # we use this alias to signify it must be ordered.
        self._buffers = OrderedDict()
        # getmany calls next(_TopicBuffer), and does not call iter(),
        # so the first call to next caches an iterator.
        self._it = None

    def add(self, tp: TP, buffer: List[ConsumerRecord]) -> None:
        assert tp not in self._buffers
        self._buffers[tp] = iter(buffer)

    def __iter__(self) -> Iterator[Tuple[TP, ConsumerRecord]]:
        buffers = self._buffers
        buffers_items = buffers.items
        buffers_remove = buffers.pop
        sentinel = object()
        to_remove: Set[TP] = set()
        mark_as_to_remove = to_remove.add
        while buffers:
            for tp in to_remove:
                buffers_remove(tp, None)
            for tp, buffer in buffers_items():
                item = next(buffer, sentinel)
                if item is sentinel:
                    mark_as_to_remove(tp)
                    continue
                yield tp, item

    def __next__(self) -> Tuple[TP, ConsumerRecord]:
        # Note: this method is not in normal iteration
        # as __iter__ returns generator.
        it = self._it
        if it is None:
            it = self._it = iter(self)
        return it.__next__()


class ConsumerRebalanceListener(aiokafka.abc.ConsumerRebalanceListener):
    """Relay aiokafka rebalance callbacks to the owning consumer thread.

    aiokafka's ``subscribe()`` takes a listener *object* with a
    class-based callback interface, so this small adapter forwards both
    callbacks to :class:`ConsumerThread`.
    """

    def __init__(self, thread: 'ConsumerThread') -> None:
        self._thread: 'ConsumerThread' = thread

    async def on_partitions_revoked(
            self, revoked: Iterable[_TopicPartition]) -> None:
        """Forward the revoked partitions to the consumer thread."""
        owner = self._thread
        await owner.on_partitions_revoked(revoked)

    async def on_partitions_assigned(
            self, assigned: Iterable[_TopicPartition]) -> None:
        """Forward the newly assigned partitions to the consumer thread."""
        owner = self._thread
        await owner.on_partitions_assigned(assigned)


class Consumer(base.Consumer):
    """Kafka consumer using :pypi:`aiokafka`.

    The actual :class:`aiokafka.AIOKafkaConsumer` is owned by a separate
    :class:`ConsumerThread`; this class forwards requests to that thread
    and receives rebalance callbacks back via :attr:`_method_queue`.
    """

    logger = logger

    RebalanceListener: ClassVar[Type[ConsumerRebalanceListener]]
    RebalanceListener = ConsumerRebalanceListener

    _thread: 'ConsumerThread'
    # ``None`` means "not computed yet": _get_active_partitions() then
    # derives it from the current assignment minus paused partitions.
    _active_partitions: Optional[Set[_TopicPartition]]
    _paused_partitions: Set[_TopicPartition]
    fetch_timeout: float = 10.0

    consumer_stopped_errors: ClassVar[Tuple[Type[BaseException], ...]] = (
        ConsumerStoppedError,
    )

    # Toggled by stop_flow()/resume_flow(); getmany() waits on
    # can_resume_flow while flow is stopped.
    flow_active: bool = True
    can_resume_flow: Event

    #: Main thread method queue.
    #: The consumer is running in a separate thread, and so we send
    #: requests to it via a queue.
    #: Sometimes the thread needs to call code owned by the main thread,
    #: such as App.on_partitions_revoked, and in that case the thread
    #: uses this method queue.
    _method_queue: MethodQueue

    def on_init(self) -> None:
        """Reset partition bookkeeping and create queue and thread."""
        self._active_partitions = None
        self._paused_partitions = set()
        self.can_resume_flow = Event()
        self._method_queue = MethodQueue(loop=self.loop, beacon=self.beacon)
        self._thread = ConsumerThread(
            self,
            loop=self.loop,
            beacon=self.beacon,
        )

    async def on_restart(self) -> None:
        """Recreate state and consumer thread when the service restarts."""
        self.on_init()

    def _get_active_partitions(self) -> Set[_TopicPartition]:
        # Returns the cached set itself; pause/resume mutate it in place.
        tps = self._active_partitions
        if tps is None:
            # need aiokafka._TopicPartition, not faust.TP
            return self._set_active_tps(self.assignment())
        return tps

    def _set_active_tps(self,
                        tps: Set[_TopicPartition]) -> Set[_TopicPartition]:
        # Cache a copy of ``tps`` minus the currently paused partitions.
        tps = self._active_partitions = set(tps)  # copy!
        tps.difference_update(self._paused_partitions)
        return tps

    def _create_consumer(
            self,
            loop: asyncio.AbstractEventLoop) -> aiokafka.AIOKafkaConsumer:
        # Client-only apps get a group-less consumer with auto commit;
        # workers get a group member using the app's partition assignor.
        transport = cast(Transport, self.transport)
        if self.app.client_only:
            return self._create_client_consumer(transport, loop=loop)
        else:
            return self._create_worker_consumer(transport, loop=loop)

    def _create_worker_consumer(
            self,
            transport: 'Transport',
            loop: asyncio.AbstractEventLoop) -> aiokafka.AIOKafkaConsumer:
        conf = self.app.conf
        self._assignor = self.app.assignor
        return aiokafka.AIOKafkaConsumer(
            loop=loop,
            client_id=conf.broker_client_id,
            group_id=conf.id,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            partition_assignment_strategy=[self._assignor],
            enable_auto_commit=False,
            auto_offset_reset='earliest',
            max_poll_records=conf.broker_max_poll_records,
            max_partition_fetch_bytes=conf.consumer_max_fetch_size,
            fetch_max_wait_ms=1500,
            request_timeout_ms=int(conf.broker_request_timeout * 1000.0),
            check_crcs=conf.broker_check_crcs,
            session_timeout_ms=int(conf.broker_session_timeout * 1000.0),
            heartbeat_interval_ms=int(
                conf.broker_heartbeat_interval * 1000.0),
            security_protocol="SSL" if conf.ssl_context else "PLAINTEXT",
            ssl_context=conf.ssl_context,
        )

    def _create_client_consumer(
            self,
            transport: 'Transport',
            loop: asyncio.AbstractEventLoop) -> aiokafka.AIOKafkaConsumer:
        conf = self.app.conf
        return aiokafka.AIOKafkaConsumer(
            loop=loop,
            client_id=conf.broker_client_id,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            request_timeout_ms=int(conf.broker_request_timeout * 1000.0),
            enable_auto_commit=True,
            max_poll_records=conf.broker_max_poll_records,
            auto_offset_reset='earliest',
            check_crcs=conf.broker_check_crcs,
            security_protocol="SSL" if conf.ssl_context else "PLAINTEXT",
            ssl_context=conf.ssl_context,
        )

    async def create_topic(self, topic: str, partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 30.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare ``topic`` on the server via the consumer thread.

        ``timeout`` and ``retention`` are forwarded unchanged (as
        :data:`Seconds`): :meth:`ConsumerThread.create_topic` performs the
        conversion to milliseconds.  Converting here as well would
        multiply both values by 1000 twice, and would crash when
        ``retention`` is ``None``.
        """
        await self._thread.create_topic(
            topic,
            partitions,
            replication,
            config=config,
            timeout=timeout,
            retention=retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    async def on_start(self) -> None:
        """Start the method queue and the consumer thread."""
        await self.add_runtime_dependency(self._method_queue)
        await self.add_runtime_dependency(self._thread)

    async def threadsafe_partitions_revoked(
            self,
            receiver_loop: asyncio.AbstractEventLoop,
            revoked: Set[TP]) -> None:
        """Run on_partitions_revoked on the main thread and wait for it.

        Called from the consumer thread; ``receiver_loop`` is that
        thread's loop, on which the completion future is created.
        """
        promise = await self._method_queue.call(
            receiver_loop.create_future(),
            self.on_partitions_revoked,
            revoked,
        )
        # wait for main-thread to finish processing request
        await promise

    async def threadsafe_partitions_assigned(
            self,
            receiver_loop: asyncio.AbstractEventLoop,
            assigned: Set[TP]) -> None:
        """Run on_partitions_assigned on the main thread and wait for it.

        Called from the consumer thread; ``receiver_loop`` is that
        thread's loop, on which the completion future is created.
        """
        promise = await self._method_queue.call(
            receiver_loop.create_future(),
            self.on_partitions_assigned,
            assigned,
        )
        # wait for main-thread to finish processing request
        await promise

    async def subscribe(self, topics: Iterable[str]) -> None:
        """Subscribe to ``topics`` (delegated to the consumer thread)."""
        await self._thread.subscribe(topics=topics)

    async def getmany(self,
                      timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
        """Fetch one batch of records and yield ``(tp, message)`` pairs.

        Waits for :meth:`resume_flow` if flow is stopped, fetches records
        for the active (non-paused) partitions only, and yields them
        round-robin across topics.  Iteration ends early if
        :meth:`stop_flow` is called while yielding.
        """
        if not self.flow_active:
            await self.wait(self.can_resume_flow)
        # Implementation for the Fetcher service.

        active_partitions = self._get_active_partitions()
        _next = next

        records: RecordMap = {}
        if active_partitions:
            # Fetch records only if active partitions to avoid the risk of
            # fetching all partitions in the beginning when none of the
            # partitions is paused/resumed.
            records = await self._thread.getmany(
                active_partitions,
                timeout=timeout,
            )
        else:
            # We should still release to the event loop
            await self.sleep(1)
        if self.should_stop:
            return
        create_message = ConsumerMessage  # localize

        # ``records`` contains a mapping from TP to a list of messages.
        # If there are two agents, consuming from topics t1 and t2,
        # the normal order of iteration would be to process each
        # tp in the dict:
        #    for tp, messages in records.items():
        #        for message in messages:
        #            yield tp, message
        #
        # The problem with this is that if we have prefetched 16k records
        # for one partition, the other partitions won't even start
        # processing before those 16k records are completed.
        #
        # So we try round-robin between the tps instead:
        #
        #    iterators: Dict[TP, Iterator] = {
        #        tp: iter(messages)
        #        for tp, messages in records.items()
        #    }
        #    while iterators:
        #        for tp, messages in iterators.items():
        #            yield tp, next(messages)
        #            # remove from iterators if empty.
        #
        # The problem with this implementation is that
        # the records mapping is ordered by TP, so records.keys()
        # will look like this:
        #
        #  TP(topic='bar', partition=0)
        #  TP(topic='bar', partition=1)
        #  TP(topic='bar', partition=2)
        #  TP(topic='bar', partition=3)
        #  TP(topic='foo', partition=0)
        #  TP(topic='foo', partition=1)
        #  TP(topic='foo', partition=2)
        #  TP(topic='foo', partition=3)
        #
        # If there are 100 partitions for each topic,
        # it will process 100 items in the first topic, then 100 items
        # in the other topic, but even worse if partition counts
        # vary greatly, t1 has 1000 partitions and t2
        # has 1 partition, then t2 will end up being starved most of
        # the time.
        #
        # We solve this by going round-robin through each topic.
        topic_index = self._records_to_topic_index(records, active_partitions)
        to_remove: Set[str] = set()
        sentinel = object()
        while topic_index:
            if not self.flow_active:
                break
            for topic in to_remove:
                topic_index.pop(topic, None)
            for topic, messages in topic_index.items():
                if not self.flow_active:
                    break
                item = _next(messages, sentinel)
                if item is sentinel:
                    # this topic is now empty,
                    # but we cannot remove from dict while iterating over
                    # it, so move that to the outer loop.
                    to_remove.add(topic)
                    continue
                tp, record = item  # type: ignore
                if tp in active_partitions:
                    highwater_mark = self._thread.highwater(tp)
                    self.app.monitor.track_tp_end_offset(tp, highwater_mark)
                    # convert timestamp to seconds from int milliseconds.
                    timestamp: Optional[int] = record.timestamp
                    timestamp_s: float = cast(float, None)
                    if timestamp is not None:
                        timestamp_s = timestamp / 1000.0
                    yield tp, create_message(
                        record.topic,
                        record.partition,
                        record.offset,
                        timestamp_s,
                        record.timestamp_type,
                        record.key,
                        record.value,
                        record.checksum,
                        record.serialized_key_size,
                        record.serialized_value_size,
                        tp,
                    )

    def _records_to_topic_index(
            self,
            records: RecordMap,
            active_partitions: Set[_TopicPartition]) -> TopicIndexMap:
        # Group per-partition record lists by topic, one _TopicBuffer per
        # topic.  ``active_partitions`` is currently unused here.
        topic_index: TopicIndexMap = {}
        for tp, messages in records.items():
            try:
                entry = topic_index[tp.topic]
            except KeyError:
                entry = topic_index[tp.topic] = _TopicBuffer()
            entry.add(tp, messages)
        return topic_index

    def _new_topicpartition(self, topic: str, partition: int) -> TP:
        return cast(TP, _TopicPartition(topic, partition))

    def _new_offsetandmetadata(self, offset: int, meta: str) -> Any:
        return OffsetAndMetadata(offset, meta)

    async def on_stop(self) -> None:
        """Commit outstanding offsets and forget pending topic creations."""
        await super().on_stop()  # wait_empty
        await self.commit()
        transport = cast(Transport, self.transport)
        transport._topic_waiters.clear()

    async def perform_seek(self) -> None:
        """Seek to committed offsets and record them locally.

        Partitions without a committed offset (``None``) are skipped; a
        committed offset of ``0`` is stored as a read offset of ``None``.
        """
        read_offset = self._read_offset
        _committed_offsets = await self._thread.seek_to_committed()
        committed_offsets = {
            _ensure_TP(tp): offset
            for tp, offset in _committed_offsets.items()
            if offset is not None
        }
        read_offset.update({
            tp: offset if offset else None
            for tp, offset in committed_offsets.items()
        })
        self._committed_offset.update(committed_offsets)

    async def _commit(self, offsets: Mapping[TP, Tuple[int, str]]) -> bool:
        """Commit ``offsets`` for currently assigned partitions.

        Offsets for partitions no longer assigned are discarded (and
        logged).  Returns ``True`` if a commit was performed, ``False``
        if there was nothing to commit or the commit failed.  Unexpected
        commit errors crash the consumer service.
        """
        table = terminal.logtable(
            [(str(tp), str(offset), meta)
             for tp, (offset, meta) in offsets.items()],
            title='Commit Offsets',
            headers=['TP', 'Offset', 'Metadata'],
        )
        self.log.dev('COMMITTING OFFSETS:\n%s', table)
        try:
            assignment = self.assignment()
            commitable: Dict[TP, OffsetAndMetadata] = {}
            revoked: Dict[TP, OffsetAndMetadata] = {}
            commitable_offsets: Dict[TP, int] = {}
            for tp, (offset, meta) in offsets.items():
                offset_and_metadata = self._new_offsetandmetadata(
                    offset, meta)
                if tp in assignment:
                    commitable_offsets[tp] = offset
                    commitable[tp] = offset_and_metadata
                else:
                    revoked[tp] = offset_and_metadata
            if revoked:
                self.log.info(
                    'Discarded commit for revoked partitions that '
                    'will be eventually processed again: %r',
                    revoked,
                )
            if not commitable:
                return False
            with flight_recorder(self.log, timeout=300.0) as on_timeout:
                on_timeout.info('+aiokafka_consumer.commit()')
                await self._thread.commit(commitable)
                on_timeout.info('-aiokafka._consumer.commit()')
            self._committed_offset.update(commitable_offsets)
            self.app.monitor.on_tp_commit(commitable_offsets)
            self._last_batch = None
            return True
        except CommitFailedError as exc:
            if 'already rebalanced' in str(exc):
                return False
            self.log.exception('Committing raised exception: %r', exc)
            await self.crash(exc)
            return False
        except IllegalStateError as exc:
            self.log.exception(f'Got exception: {exc}\n'
                               f'Current assignment: {self.assignment()}')
            await self.crash(exc)
            return False

    def stop_flow(self) -> None:
        """Pause message delivery from :meth:`getmany`."""
        self.flow_active = False
        self.can_resume_flow.clear()

    def resume_flow(self) -> None:
        """Resume message delivery from :meth:`getmany`."""
        self.flow_active = True
        self.can_resume_flow.set()

    def pause_partitions(self, tps: Iterable[TP]) -> None:
        """Stop fetching from ``tps`` until they are resumed."""
        tpset = set(tps)
        self._get_active_partitions().difference_update(tpset)
        self._paused_partitions.update(tpset)

    def resume_partitions(self, tps: Iterable[TP]) -> None:
        """Start fetching from ``tps`` again.

        ``tps`` may be any iterable, including a one-shot generator: it is
        materialized once and that set is used for both updates.
        """
        tpset = set(tps)
        # Use the materialized set: re-iterating ``tps`` would yield
        # nothing when it is an iterator already consumed by set().
        self._get_active_partitions().update(tpset)
        self._paused_partitions.difference_update(tpset)

    async def position(self, tp: TP) -> Optional[int]:
        """Return the consumer position for ``tp``."""
        return await self._thread.position(tp)

    async def seek_wait(self, partitions: Mapping[TP, int]) -> None:
        """Seek partitions to offsets and wait for positions to settle."""
        return await self._thread.seek_wait(partitions)

    async def _seek_to_beginning(self, *partitions: TP) -> None:
        self.log.dev('SEEK TO BEGINNING: %r', partitions)
        self._read_offset.update((_ensure_TP(tp), None) for tp in partitions)
        await self._thread.seek_to_beginning(*(
            self._new_topicpartition(tp.topic, tp.partition)
            for tp in partitions
        ))

    async def seek(self, partition: TP, offset: int) -> None:
        """Seek ``partition`` to ``offset``."""
        self.log.dev('SEEK %r -> %r', partition, offset)
        # reset livelock detection
        self._last_batch = None
        # set new read offset so we will reread messages
        self._read_offset[_ensure_TP(partition)] = offset if offset else None
        self._thread.seek(partition, offset)

    def assignment(self) -> Set[TP]:
        """Return the set of partitions currently assigned."""
        return self._thread.assignment()

    def highwater(self, tp: TP) -> int:
        """Return the last known highwater mark for ``tp``."""
        return self._thread.highwater(tp)

    def topic_partitions(self, topic: str) -> Optional[int]:
        """Return the partition count for ``topic``, if known."""
        return self._thread.topic_partitions(topic)

    async def earliest_offsets(self,
                               *partitions: TP) -> MutableMapping[TP, int]:
        """Return the earliest available offset of each partition."""
        return await self._thread.earliest_offsets(*partitions)

    async def highwaters(self, *partitions: TP) -> MutableMapping[TP, int]:
        """Return the end offset of each partition."""
        return await self._thread.highwaters(*partitions)

    def close(self) -> None:
        """Mark the underlying aiokafka consumer as closing."""
        self._thread.close()
class ConsumerThread(QueueServiceThread):
    """Service thread that owns the :class:`aiokafka.AIOKafkaConsumer`.

    Most requests from :class:`Consumer` are executed on this thread's
    event loop through :meth:`call_thread`; rebalance callbacks flow back
    to the main thread through the consumer's ``threadsafe_*`` methods.
    """

    app: AppT
    consumer: Consumer
    _consumer: Optional[aiokafka.AIOKafkaConsumer] = None

    def __init__(self, consumer: Consumer, **kwargs: Any) -> None:
        self.consumer = consumer
        transport = self.consumer.transport
        self.app = transport.app
        self._rebalance_listener = consumer.RebalanceListener(self)
        super().__init__(**kwargs)

    async def on_start(self) -> None:
        # The aiokafka consumer is created on, and bound to, this
        # thread's event loop.
        self._consumer = self.consumer._create_consumer(loop=self.thread_loop)
        await self._consumer.start()

    def close(self) -> None:
        if self._consumer is not None:
            self._consumer.set_close()
            self._consumer._coordinator.set_close()

    async def on_partitions_revoked(
            self, revoked: Iterable[_TopicPartition]) -> None:
        self.consumer.app.on_rebalance_start()
        # see comment in on_partitions_assigned
        consumer = self.consumer
        _revoked = cast(Set[TP], set(revoked))
        # remove revoked partitions from active + paused tps.
        if consumer._active_partitions is not None:
            consumer._active_partitions.difference_update(_revoked)
        consumer._paused_partitions.difference_update(_revoked)
        # start callback chain of revoked callbacks.
        await consumer.threadsafe_partitions_revoked(
            self.thread_loop, _revoked)

    async def on_partitions_assigned(
            self, assigned: Iterable[_TopicPartition]) -> None:
        # have to cast to Consumer since ConsumerT interface does not
        # have this attribute (mypy currently thinks a Callable instance
        # variable is an instance method).  Furthermore we have to cast
        # the Kafka TopicPartition namedtuples to our description,
        # that way they are typed and decoupled from the actual client
        # implementation.
        consumer = self.consumer
        _assigned = set(assigned)
        # remove recently revoked tps from set of paused tps.
        consumer._paused_partitions.intersection_update(_assigned)
        # cache set of assigned partitions (_set_active_tps stores a copy,
        # so the callbacks below cannot mutate our active list).
        consumer._set_active_tps(_assigned)
        # start callback chain of assigned callbacks.
        consumer._last_batch = None
        await consumer.threadsafe_partitions_assigned(
            self.thread_loop, _assigned)

    async def subscribe(self, topics: Iterable[str]) -> None:
        # XXX pattern does not work :/
        await self.call_thread(
            self._ensure_consumer().subscribe,
            topics=set(topics),
            listener=self._rebalance_listener,
        )

    async def seek_to_committed(self) -> Mapping[TP, int]:
        return await self.call_thread(
            self._ensure_consumer().seek_to_committed)

    async def commit(self, tps: Any) -> Any:
        return await self.call_thread(
            self._ensure_consumer().commit, tps)

    async def position(self, tp: TP) -> Optional[int]:
        return await self.call_thread(
            self._ensure_consumer().position, tp)

    async def seek_to_beginning(self, *partitions: _TopicPartition) -> None:
        await self.call_thread(
            self._ensure_consumer().seek_to_beginning, *partitions)

    async def seek_wait(self, partitions: Mapping[TP, int]) -> None:
        consumer = self._ensure_consumer()
        await self.call_thread(self._seek_wait, consumer, partitions)

    async def _seek_wait(self,
                         consumer: aiokafka.AIOKafkaConsumer,
                         partitions: Mapping[TP, int]) -> None:
        # ``consumer`` is the aiokafka consumer passed in by seek_wait()
        # (previously mis-annotated as faust's Consumer).
        for tp, offset in partitions.items():
            self.log.dev('SEEK %r -> %r', tp, offset)
            consumer.seek(tp, offset)
        # Awaiting position() for every partition waits for the seeks
        # to take effect.
        await asyncio.gather(*[
            consumer.position(tp) for tp in partitions
        ])

    def seek(self, partition: TP, offset: int) -> None:
        self._ensure_consumer().seek(partition, offset)

    def assignment(self) -> Set[TP]:
        return cast(Set[TP], self._ensure_consumer().assignment())

    def highwater(self, tp: TP) -> int:
        return self._ensure_consumer().highwater(tp)

    def topic_partitions(self, topic: str) -> Optional[int]:
        if self._consumer is not None:
            return self._consumer._coordinator._metadata_snapshot.get(topic)
        return None

    async def earliest_offsets(self,
                               *partitions: TP) -> MutableMapping[TP, int]:
        return await self.call_thread(
            self._ensure_consumer().beginning_offsets, partitions)

    async def highwaters(self, *partitions: TP) -> MutableMapping[TP, int]:
        return await self.call_thread(
            self._ensure_consumer().end_offsets, partitions)

    def _ensure_consumer(self) -> aiokafka.AIOKafkaConsumer:
        # Raises ConsumerNotStarted if on_start() has not run yet.
        if self._consumer is None:
            raise ConsumerNotStarted('Consumer thread not yet started')
        return self._consumer

    async def getmany(self,
                      active_partitions: Set[_TopicPartition],
                      timeout: float) -> RecordMap:
        # Implementation for the Fetcher service.
        _consumer = self._ensure_consumer()
        fetcher = _consumer._fetcher
        if _consumer._closed or fetcher._closed:
            raise ConsumerStoppedError()
        return await self.call_thread(
            fetcher.fetched_records,
            active_partitions,
            timeout=timeout,
        )

    async def create_topic(self, topic: str, partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 30.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare ``topic`` using this thread's Kafka client.

        ``timeout`` and ``retention`` are given in seconds and converted
        to milliseconds here.  ``retention`` may be ``None`` (or falsy),
        in which case no retention is sent, matching
        :meth:`Producer.create_topic`.
        """
        consumer = self.consumer
        transport = cast(Transport, consumer.transport)
        _consumer = self._ensure_consumer()
        # Guard against None: want_seconds(None) * 1000.0 raises TypeError.
        _retention = (int(want_seconds(retention) * 1000.0)
                      if retention else None)
        await self.call_thread(
            transport._create_topic,
            consumer,
            _consumer._client,
            topic,
            partitions,
            replication,
            config=config,
            timeout=int(want_seconds(timeout) * 1000.0),
            retention=_retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )
class Producer(base.Producer):
    """Kafka producer using :pypi:`aiokafka`."""

    logger = logger

    _producer: aiokafka.AIOKafkaProducer

    def on_init(self) -> None:
        """Create the underlying :class:`aiokafka.AIOKafkaProducer`."""
        transport = cast(Transport, self.transport)
        self._producer = aiokafka.AIOKafkaProducer(
            loop=self.loop,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            client_id=self.client_id,
            acks=self.acks,
            linger_ms=self.linger_ms,
            max_batch_size=self.max_batch_size,
            max_request_size=self.max_request_size,
            compression_type=self.compression_type,
            on_irrecoverable_error=self._on_irrecoverable_error,
            security_protocol='SSL' if self.ssl_context else 'PLAINTEXT',
            ssl_context=self.ssl_context,
            partitioner=self.partitioner or DefaultPartitioner(),
            request_timeout_ms=int(self.request_timeout * 1000),
        )

    async def _on_irrecoverable_error(self, exc: BaseException) -> None:
        # Crash the app's consumer (if any) as well as this producer.
        consumer = self.transport.app.consumer
        if consumer is not None:
            await consumer.crash(exc)
        await self.crash(exc)

    async def on_restart(self) -> None:
        """Recreate the aiokafka producer when the service restarts."""
        self.on_init()

    async def create_topic(self, topic: str, partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 20.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare ``topic`` using the producer's Kafka client.

        ``timeout`` and ``retention`` are given in seconds and converted
        to milliseconds; a falsy ``retention`` is sent as ``None``.
        """
        _retention = (int(want_seconds(retention) * 1000.0)
                      if retention else None)
        await cast(Transport, self.transport)._create_topic(
            self,
            self._producer.client,
            topic,
            partitions,
            replication,
            config=config,
            timeout=int(want_seconds(timeout) * 1000.0),
            retention=_retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    async def on_start(self) -> None:
        """Start the aiokafka producer."""
        self.beacon.add(self._producer)
        self._last_batch = None
        await self._producer.start()

    async def on_stop(self) -> None:
        """Forget pending topic creations and stop the aiokafka producer."""
        cast(Transport, self.transport)._topic_waiters.clear()
        self._last_batch = None
        await self._producer.stop()

    async def send(self, topic: str, key: Optional[bytes],
                   value: Optional[bytes],
                   partition: Optional[int]) -> Awaitable[RecordMetadata]:
        """Schedule a message to be sent; return a future for its metadata.

        Raises :exc:`ProducerSendError` if aiokafka raises a KafkaError.
        """
        try:
            return cast(Awaitable[RecordMetadata], await self._producer.send(
                topic, value, key=key, partition=partition))
        except KafkaError as exc:
            raise ProducerSendError(f'Error while sending: {exc!r}') from exc

    async def send_and_wait(self, topic: str, key: Optional[bytes],
                            value: Optional[bytes],
                            partition: Optional[int]) -> RecordMetadata:
        """Send a message and wait for its record metadata."""
        fut = await self.send(
            topic,
            key=key,
            value=value,
            partition=partition,
        )
        return await fut

    async def flush(self) -> None:
        """Wait for all buffered messages to be sent."""
        await self._producer.flush()

    def key_partition(self, topic: str, key: bytes) -> TP:
        """Return the partition the producer would choose for ``key``."""
        partition = self._producer._partition(
            topic,
            partition=None,
            key=None,
            value=None,
            serialized_key=key,
            serialized_value=None,
        )
        return TP(topic, partition)
class Transport(base.Transport):
    """Kafka transport using :pypi:`aiokafka`."""

    Consumer: ClassVar[Type[ConsumerT]] = Consumer
    Producer: ClassVar[Type[ProducerT]] = Producer

    default_port = 9092

    driver_version = f'aiokafka={aiokafka.__version__}'

    # One StampedeWrapper per topic name, so concurrent creation requests
    # for the same topic share a single call; removed again on failure.
    _topic_waiters: MutableMapping[str, StampedeWrapper]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._topic_waiters = {}

    def _topic_config(self,
                      retention: int = None,
                      compacting: bool = None,
                      deleting: bool = None) -> MutableMapping[str, Any]:
        # Build topic config: ``cleanup.policy`` is the sorted,
        # comma-joined set of enabled flags; ``retention.ms`` only if set.
        config: MutableMapping[str, Any] = {}
        cleanup_flags: Set[str] = set()
        if compacting:
            cleanup_flags |= {'compact'}
        if deleting:
            cleanup_flags |= {'delete'}
        if cleanup_flags:
            config['cleanup.policy'] = ','.join(sorted(cleanup_flags))
        if retention:
            config['retention.ms'] = retention
        return config

    async def _create_topic(self,
                            owner: Service,
                            client: aiokafka.AIOKafkaClient,
                            topic: str,
                            partitions: int,
                            replication: int,
                            **kwargs: Any) -> None:
        assert topic is not None
        try:
            wrap = self._topic_waiters[topic]
        except KeyError:
            wrap = self._topic_waiters[topic] = StampedeWrapper(
                self._really_create_topic,
                owner,
                client,
                topic,
                partitions,
                replication,
                loop=self.loop,
                **kwargs)
        try:
            await wrap()
        except Exception:
            # Forget the failed attempt so a later call can retry.
            self._topic_waiters.pop(topic, None)
            raise

    async def _get_controller_node(
            self,
            owner: Service,
            client: aiokafka.AIOKafkaClient,
            timeout: int = 30000) -> Optional[int]:
        # Ask the first known broker for cluster metadata and return the
        # controller id, or None if ``owner`` is stopping.
        nodes = [broker.nodeId for broker in client.cluster.brokers()]
        for node_id in nodes:
            if node_id is None:
                raise RuntimeError('Not connected to Kafka Broker')
            request = MetadataRequest_v1([])
            # NOTE(review): ``timeout`` is in milliseconds here, but is
            # passed unchanged to owner.wait(); confirm whether mode's
            # Service.wait expects seconds.
            wait_result = await owner.wait(
                client.send(node_id, request),
                timeout=timeout,
            )
            if wait_result.stopped:
                owner.log.info('Shutting down - skipping creation.')
                return None
            response = wait_result.result
            return response.controller_id
        raise Exception('Controller node not found')

    async def _really_create_topic(
            self,
            owner: Service,
            client: aiokafka.AIOKafkaClient,
            topic: str,
            partitions: int,
            replication: int,
            *,
            config: Mapping[str, Any] = None,
            timeout: int = 30000,
            retention: int = None,
            compacting: bool = None,
            deleting: bool = None,
            ensure_created: bool = False) -> None:  # ms
        # Send a CreateTopics request to the cluster controller.
        # Entries in ``config`` override those derived from
        # retention/compacting/deleting.
        owner.log.info(f'Creating topic {topic}')

        if topic in client.cluster.topics():
            owner.log.debug(f'Topic {topic} exists, skipping creation.')
            return

        protocol_version = 1
        extra_configs = config or {}
        config = self._topic_config(retention, compacting, deleting)
        config.update(extra_configs)

        controller_node = await self._get_controller_node(owner, client,
                                                          timeout=timeout)
        owner.log.info(f'Found controller: {controller_node}')

        if controller_node is None:
            if owner.should_stop:
                owner.log.info('Shutting down hence controller not found')
                return
            else:
                raise Exception('Controller node is None')

        request = CreateTopicsRequest[protocol_version](
            [(topic, partitions, replication, [], list(config.items()))],
            timeout,
            False,
        )
        # NOTE(review): same milliseconds-vs-seconds question as in
        # _get_controller_node for this owner.wait() timeout.
        wait_result = await owner.wait(
            client.send(controller_node, request),
            timeout=timeout,
        )
        if wait_result.stopped:
            owner.log.info('Shutting down - skipping creation.')
            return
        response = wait_result.result

        assert len(response.topic_error_codes), 'single topic'

        _, code, reason = response.topic_error_codes[0]

        if code != 0:
            if not ensure_created and code == TopicExistsError.errno:
                owner.log.debug(
                    f'Topic {topic} exists, skipping creation.')
                return
            elif code == NotControllerError.errno:
                raise RuntimeError(f'Invalid controller: {controller_node}')
            else:
                raise for_code(code)(
                    f'Cannot create topic: {topic} ({code}): {reason}')
        else:
            owner.log.info(f'Topic {topic} created.')
            return