"""Consumer - fetching messages and managing consumer state.
The Consumer is responsible for:
- Holds reference to the transport that created it
- ... and the app via ``self.transport.app``.
- Has a callback that usually points back to ``Conductor.on_message``.
- Receives messages and calls the callback for every message received.
- Keeps track of the message and its acked/unacked status.
- The Conductor forwards the message to all Streams that subscribe
to the topic the message was sent to.
+ Messages are reference counted, and the Conductor increases
the reference count to the number of subscribed streams.
+ ``Stream.__aiter__`` is set up in a way such that when whatever is
iterating over the stream is finished with the message, a
finally: block will decrease the reference count by one.
+ When the reference count for a message hits zero, the stream will
call ``Consumer.ack(message)``, which will mark that topic +
partition + offset combination as "committable"
+ If all the streams share the same key_type/value_type,
the conductor will only deserialize the payload once.
- Commits the offset at an interval
+ The Consumer has a background thread that periodically commits the
offset.
- If the consumer marked an offset as committable this thread
will advance the committed offset.
+ To find the offset that it can safely advance to the commit thread
will traverse the _acked mapping of TP to list of acked offsets, by
finding a range of consecutive acked offsets (see note in
_new_offset).
"""
import abc
import asyncio
import gc
import typing
from collections import defaultdict
from time import monotonic
from typing import (
Any,
AsyncIterator,
Awaitable,
ClassVar,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
MutableSet,
NamedTuple,
Optional,
Set,
Tuple,
Type,
cast,
)
from weakref import WeakSet
from mode import Service, ServiceT, flight_recorder, get_logger
from mode.threads import MethodQueue, QueueServiceThread
from mode.utils.futures import notify
from mode.utils.locks import Event
from mode.utils.text import pluralize
from mode.utils.times import Seconds
from mode.utils.typing import Counter
from faust.exceptions import ProducerSendError
from faust.types import AppT, ConsumerMessage, Message, RecordMetadata, TP
from faust.types.core import HeadersArg
from faust.types.transports import (
ConsumerCallback,
ConsumerT,
PartitionsAssignedCallback,
PartitionsRevokedCallback,
ProducerT,
TPorTopicSet,
TransactionManagerT,
TransportT,
)
from faust.utils import terminal
from faust.utils.functional import consecutive_numbers
from faust.utils.tracing import traced_from_parent_span
if typing.TYPE_CHECKING:  # pragma: no cover
    from faust.app import App as _App
else:
    class _App: ...  # noqa: E701

__all__ = ['Consumer', 'Fetcher']

# These flags are used for Service.diag, tracking what the consumer
# service is currently doing.
CONSUMER_FETCHING = 'FETCHING'
CONSUMER_PARTITIONS_REVOKED = 'PARTITIONS_REVOKED'
CONSUMER_PARTITIONS_ASSIGNED = 'PARTITIONS_ASSIGNED'
CONSUMER_COMMITTING = 'COMMITTING'
CONSUMER_SEEKING = 'SEEKING'
CONSUMER_WAIT_EMPTY = 'WAIT_EMPTY'

logger = get_logger(__name__)

# Result type of a fetch: partition -> list of raw driver records.
RecordMap = Mapping[TP, List[Any]]
class TopicPartitionGroup(NamedTuple):
    """Tuple of ``(topic, partition, group)``.

    Identifies a topic partition together with the assignor group
    number its topic belongs to.
    """

    # Name of the topic.
    topic: str
    # Partition number within the topic.
    partition: int
    # Assignor group number for the topic.
    group: int
def ensure_TP(tp: Any) -> TP:
    """Convert aiokafka ``TopicPartition`` to Faust ``TP``."""
    # Already the Faust type: pass through unchanged.
    if isinstance(tp, TP):
        return tp
    # Otherwise rebuild from the duck-typed topic/partition attributes.
    return TP(tp.topic, tp.partition)
def ensure_TPset(tps: Iterable[Any]) -> Set[TP]:
    """Convert set of aiokafka ``TopicPartition`` to Faust ``TP``."""
    converted: Set[TP] = set()
    for tp in tps:
        converted.add(ensure_TP(tp))
    return converted
class Fetcher(Service):
    """Service fetching messages from Kafka."""

    app: AppT

    logger = logger

    #: Future wrapping the currently running ``Consumer._drain_messages``
    #: coroutine (set by the ``_fetcher`` background task).
    _drainer: Optional[asyncio.Future] = None

    def __init__(self, app: AppT, **kwargs: Any) -> None:
        self.app = app
        super().__init__(**kwargs)

    async def on_stop(self) -> None:
        """Call when the fetcher is stopping."""
        # Cancel the drainer task and then wait, in one-second windows,
        # for the cancellation to actually take effect.
        if self._drainer is not None and not self._drainer.done():
            self._drainer.cancel()
            while True:
                try:
                    await asyncio.wait_for(self._drainer, timeout=1.0)
                except StopIteration:
                    # Task is cancelled right before coro stops.
                    break
                except asyncio.CancelledError:
                    break
                except asyncio.TimeoutError:
                    # Drainer has not reacted to the cancel yet:
                    # warn and keep waiting.
                    self.log.warning('Fetcher is ignoring cancel or slow :(')
                else:
                    break

    @Service.task
    async def _fetcher(self) -> None:
        # Background task: run the consumer's drain loop for the
        # lifetime of this service, then mark the service as shut down.
        try:
            consumer = cast(Consumer, self.app.consumer)
            self._drainer = asyncio.ensure_future(
                consumer._drain_messages(self),
                loop=self.loop,
            )
            await self._drainer
        except asyncio.CancelledError:
            pass
        finally:
            self.set_shutdown()
class TransactionManager(Service, TransactionManagerT):
    """Manage producer transactions.

    Used when ``processing_guarantee='exactly_once'``: maintains one
    transactional producer per active (topic, partition, group)
    combination and commits consumer offsets through the producer.
    """

    app: AppT

    #: Template used to derive a transactional id from the consumer
    #: group id and a :class:`TopicPartitionGroup`.
    transactional_id_format = '{group_id}-{tpg.group}-{tpg.partition}'

    def __init__(self, transport: TransportT,
                 *,
                 consumer: 'ConsumerT',
                 producer: 'ProducerT',
                 **kwargs: Any) -> None:
        self.transport = transport
        self.app = self.transport.app
        self.consumer = consumer
        self.producer = producer
        super().__init__(**kwargs)

    async def flush(self) -> None:
        """Wait for producer to transmit all pending messages."""
        await self.producer.flush()

    async def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        """Call when the cluster is rebalancing and partitions are revoked."""
        # Make sure nothing is buffered before partitions move away.
        await traced_from_parent_span()(self.flush)()

    async def on_rebalance(self,
                           assigned: Set[TP],
                           revoked: Set[TP],
                           newly_assigned: Set[TP]) -> None:
        """Call when the cluster is rebalancing."""
        T = traced_from_parent_span()
        # Stop producers for revoked partitions.
        revoked_tids = list(sorted(self._tps_to_transactional_ids(revoked)))
        if revoked_tids:
            self.log.info(
                'Stopping %r transactional %s for %r revoked %s...',
                len(revoked_tids),
                pluralize(len(revoked_tids), 'producer'),
                len(revoked),
                pluralize(len(revoked), 'partition'))
            await T(self._stop_transactions, tids=revoked_tids)(revoked_tids)
        # Start producers for assigned partitions
        assigned_tids = list(sorted(self._tps_to_transactional_ids(assigned)))
        if assigned_tids:
            self.log.info(
                'Starting %r transactional %s for %r assigned %s...',
                len(assigned_tids),
                pluralize(len(assigned_tids), 'producer'),
                len(assigned),
                pluralize(len(assigned), 'partition'))
            await T(self._start_transactions,
                    tids=assigned_tids)(assigned_tids)

    async def _stop_transactions(self, tids: Iterable[str]) -> None:
        # Stop the transactional producer for each revoked id.
        T = traced_from_parent_span()
        producer = self.producer
        for transactional_id in tids:
            await T(producer.stop_transaction)(transactional_id)

    async def _start_transactions(self, tids: Iterable[str]) -> None:
        # Begin (creating lazily if needed) a transaction for each
        # newly assigned id.
        T = traced_from_parent_span()
        producer = self.producer
        for transactional_id in tids:
            await T(producer.maybe_begin_transaction)(transactional_id)

    def _tps_to_transactional_ids(self, tps: Set[TP]) -> Set[str]:
        # Map partitions to transactional ids; only active (non-standby)
        # partitions participate in transactions.
        return {
            self.transactional_id_format.format(
                tpg=tpg,
                group_id=self.app.conf.id,
            )
            for tpg in self._tps_to_active_tpgs(tps)
        }

    def _tps_to_active_tpgs(self, tps: Set[TP]) -> Set[TopicPartitionGroup]:
        # Expand each TP with the assignor group number for its topic,
        # skipping standby partitions.
        assignor = self.app.assignor
        return {
            TopicPartitionGroup(
                tp.topic,
                tp.partition,
                assignor.group_for_topic(tp.topic),
            )
            for tp in tps
            if not assignor.is_standby(tp)
        }

    async def send(self, topic: str, key: Optional[bytes],
                   value: Optional[bytes],
                   partition: Optional[int],
                   timestamp: Optional[float],
                   headers: Optional[HeadersArg],
                   *,
                   transactional_id: str = None) -> Awaitable[RecordMetadata]:
        """Schedule message to be sent by producer."""
        # NOTE(review): the ``transactional_id`` argument is discarded
        # here and recomputed from the destination partition below —
        # confirm callers never rely on passing their own id.
        group = transactional_id = None
        p = self.consumer.key_partition(topic, key, partition)
        if p is not None:
            group = self.app.assignor.group_for_topic(topic)
            transactional_id = f'{self.app.conf.id}-{group}-{p}'
        return await self.producer.send(
            topic, key, value, p, timestamp, headers,
            transactional_id=transactional_id,
        )

    async def send_and_wait(self, topic: str, key: Optional[bytes],
                            value: Optional[bytes],
                            partition: Optional[int],
                            timestamp: Optional[float],
                            headers: Optional[HeadersArg],
                            *,
                            transactional_id: str = None) -> RecordMetadata:
        """Send message and wait for it to be transmitted."""
        fut = await self.send(topic, key, value, partition, timestamp, headers)
        return await fut

    async def commit(self, offsets: Mapping[TP, int],
                     start_new_transaction: bool = True) -> bool:
        """Commit offsets for partitions."""
        producer = self.producer
        group_id = self.app.conf.id
        # Group the offsets by the transactional producer responsible
        # for each partition, then commit per producer.
        by_transactional_id: MutableMapping[str, MutableMapping[TP, int]]
        by_transactional_id = defaultdict(dict)
        for tp, offset in offsets.items():
            group = self.app.assignor.group_for_topic(tp.topic)
            transactional_id = f'{group_id}-{group}-{tp.partition}'
            by_transactional_id[transactional_id][tp] = offset
        if by_transactional_id:
            await producer.commit_transactions(
                by_transactional_id, group_id,
                start_new_transaction=start_new_transaction,
            )
        return True

    def key_partition(self, topic: str, key: bytes) -> TP:
        # Not meaningful for the transaction manager itself.
        raise NotImplementedError()

    async def create_topic(self,
                           topic: str,
                           partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 30.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        """Create/declare topic on server."""
        return await self.producer.create_topic(
            topic, partitions, replication,
            config=config,
            timeout=timeout,
            retention=retention,
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    def supports_headers(self) -> bool:
        """Return :const:`True` if the Kafka server supports headers."""
        return self.producer.supports_headers()
class Consumer(Service, ConsumerT):
    """Base Consumer."""

    app: AppT

    logger = logger

    #: Tuple of exception types that may be raised when the
    #: underlying consumer driver is stopped.
    consumer_stopped_errors: ClassVar[Tuple[Type[BaseException], ...]] = ()

    # Mapping of TP to list of gap in offsets.
    _gap: MutableMapping[TP, List[int]]

    # Mapping of TP to list of acked offsets.
    _acked: MutableMapping[TP, List[int]]

    #: Fast lookup to see if tp+offset was acked.
    _acked_index: MutableMapping[TP, Set[int]]

    #: Keeps track of the currently read offset in each TP
    _read_offset: MutableMapping[TP, Optional[int]]

    #: Keeps track of the currently committed offset in each TP.
    _committed_offset: MutableMapping[TP, Optional[int]]

    #: The consumer.wait_empty() method will set this to be notified
    #: when something acks a message.
    _waiting_for_ack: Optional[asyncio.Future] = None

    #: Used by .commit to ensure only one thread is committing at a time.
    #: Other thread starting to commit while a commit is already active,
    #: will wait for the original request to finish, and do nothing.
    _commit_fut: Optional[asyncio.Future] = None

    #: Set of unacked messages: that is messages that we started processing
    #: and that we MUST attempt to complete processing of, before
    #: shutting down or resuming a rebalance.
    _unacked_messages: MutableSet[Message]

    #: Time of last record batch received.
    #: Set only when not set, and reset by commit() so actually
    #: tracks how long it has been since we received a record that
    #: was never committed.
    _last_batch: Counter[TP]

    #: Time of when the consumer was started.
    _time_start: float

    # How often to poll and track log end offsets.
    _end_offset_monitor_interval: float

    #: Commit after this many messages have been acked (or None).
    _commit_every: Optional[int]
    #: Number of messages acked since the last count-based commit.
    _n_acked: int = 0

    #: Currently active (fetchable) partitions, or None when unknown.
    _active_partitions: Optional[Set[TP]]
    #: Partitions currently paused by :meth:`pause_partitions`.
    _paused_partitions: Set[TP]

    #: When False, :meth:`getmany` stops yielding messages (see stop_flow).
    flow_active: bool = True
    #: Event set by :meth:`resume_flow` to unblock waiting fetchers.
    can_resume_flow: Event

    def __init__(self,
                 transport: TransportT,
                 callback: ConsumerCallback,
                 on_partitions_revoked: PartitionsRevokedCallback,
                 on_partitions_assigned: PartitionsAssignedCallback,
                 *,
                 commit_interval: float = None,
                 commit_livelock_soft_timeout: float = None,
                 loop: asyncio.AbstractEventLoop = None,
                 **kwargs: Any) -> None:
        assert callback is not None
        self.transport = transport
        self.app = self.transport.app
        self.in_transaction = self.app.in_transaction
        self.callback = callback
        self._on_message_in = self.app.sensors.on_message_in
        self._on_partitions_revoked = on_partitions_revoked
        self._on_partitions_assigned = on_partitions_assigned
        self._commit_every = self.app.conf.broker_commit_every
        self.scheduler = self.app.conf.ConsumerScheduler()
        # Fall back to app configuration when intervals not given.
        self.commit_interval = (
            commit_interval or self.app.conf.broker_commit_interval)
        self.commit_livelock_soft_timeout = (
            commit_livelock_soft_timeout or
            self.app.conf.broker_commit_livelock_soft_timeout)
        self._gap = defaultdict(list)
        self._acked = defaultdict(list)
        self._acked_index = defaultdict(set)
        self._read_offset = defaultdict(lambda: None)
        self._committed_offset = defaultdict(lambda: None)
        # WeakSet so unacked messages do not keep message objects alive.
        self._unacked_messages = WeakSet()
        self._waiting_for_ack = None
        self._time_start = monotonic()
        self._last_batch = Counter()
        self._end_offset_monitor_interval = self.commit_interval * 2
        self.randomly_assigned_topics = set()
        self.can_resume_flow = Event()
        self._reset_state()
        super().__init__(loop=loop or self.transport.loop, **kwargs)
        self.transactions = self.transport.create_transaction_manager(
            consumer=self,
            producer=self.app.producer,
            beacon=self.beacon,
            loop=self.loop,
        )

    def on_init_dependencies(self) -> Iterable[ServiceT]:
        """Return list of services this consumer depends on."""
        # We start the TransactionManager only if
        # processing_guarantee='exactly_once'
        if self.in_transaction:
            return [self.transactions]
        return []

    def _reset_state(self) -> None:
        # Restore flow/partition bookkeeping to a pristine state
        # (used on init and on restart).
        self._active_partitions = None
        self._paused_partitions = set()
        self.can_resume_flow.clear()
        self.flow_active = True
        self._last_batch.clear()
        self._time_start = monotonic()

    async def on_restart(self) -> None:
        """Call when the consumer is restarted."""
        self._reset_state()
        self.on_init()

    def _get_active_partitions(self) -> Set[TP]:
        # Lazily initialize the active set from the current assignment.
        tps = self._active_partitions
        if tps is None:
            return self._set_active_tps(self.assignment())
        assert all(isinstance(x, TP) for x in tps)
        return tps

    def _set_active_tps(self, tps: Set[TP]) -> Set[TP]:
        # Copy the assignment and subtract any paused partitions.
        xtps = self._active_partitions = ensure_TPset(tps)  # copy
        xtps.difference_update(self._paused_partitions)
        return xtps

    @abc.abstractmethod
    async def _commit(
            self,
            offsets: Mapping[TP, int]) -> bool:  # pragma: no cover
        ...

    @abc.abstractmethod
    async def seek_to_committed(self) -> Mapping[TP, int]:
        """Seek all partitions to their committed offsets."""
        ...

    async def seek(self, partition: TP, offset: int) -> None:
        """Seek partition to specific offset."""
        self.log.dev('SEEK %r -> %r', partition, offset)
        # reset livelock detection
        self._last_batch.pop(partition, None)
        await self._seek(partition, offset)
        # set new read offset so we will reread messages
        self._read_offset[ensure_TP(partition)] = offset if offset else None

    @abc.abstractmethod
    async def _seek(self, partition: TP, offset: int) -> None:
        ...

    def stop_flow(self) -> None:
        """Block consumer from processing any more messages."""
        self.flow_active = False
        self.can_resume_flow.clear()

    def resume_flow(self) -> None:
        """Allow consumer to process messages."""
        self.flow_active = True
        self.can_resume_flow.set()

    def pause_partitions(self, tps: Iterable[TP]) -> None:
        """Pause fetching from partitions."""
        tpset = ensure_TPset(tps)
        self._get_active_partitions().difference_update(tpset)
        self._paused_partitions.update(tpset)

    def resume_partitions(self, tps: Iterable[TP]) -> None:
        """Resume fetching from partitions."""
        tpset = ensure_TPset(tps)
        self._get_active_partitions().update(tps)
        self._paused_partitions.difference_update(tpset)

    @abc.abstractmethod
    def _new_topicpartition(
            self, topic: str, partition: int) -> TP:  # pragma: no cover
        ...

    def _is_changelog_tp(self, tp: TP) -> bool:
        # True when the partition belongs to a table changelog topic.
        return tp.topic in self.app.tables.changelog_topics

    @Service.transitions_to(CONSUMER_PARTITIONS_REVOKED)
    async def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        """Call during rebalancing when partitions are being revoked."""
        # NOTE:
        # The ConsumerRebalanceListener is responsible for calling
        # app.on_rebalance_start(), and this must have happened
        # before we get to this point (see aiokafka implementation).
        span = self.app._start_span_from_rebalancing('on_partitions_revoked')
        T = traced_from_parent_span(span)
        with span:
            # see comment in on_partitions_assigned
            # remove revoked partitions from active + paused tps.
            if self._active_partitions is not None:
                self._active_partitions.difference_update(revoked)
            self._paused_partitions.difference_update(revoked)
            await T(self._on_partitions_revoked, partitions=revoked)(
                revoked)

    @Service.transitions_to(CONSUMER_PARTITIONS_ASSIGNED)
    async def on_partitions_assigned(self, assigned: Set[TP]) -> None:
        """Call during rebalancing when partitions are being assigned."""
        span = self.app._start_span_from_rebalancing('on_partitions_assigned')
        T = traced_from_parent_span(span)
        with span:
            # remove recently revoked tps from set of paused tps.
            self._paused_partitions.intersection_update(assigned)
            # cache set of assigned partitions
            self._set_active_tps(assigned)
            # start callback chain of assigned callbacks.
            # need to copy set at this point, since we cannot have
            # the callbacks mutate our active list.
            self._last_batch.clear()
            await T(self._on_partitions_assigned, partitions=assigned)(
                assigned)
        self.app.on_rebalance_return()

    @abc.abstractmethod
    async def _getmany(self,
                       active_partitions: Optional[Set[TP]],
                       timeout: float) -> RecordMap:
        ...

    async def getmany(self,
                      timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
        """Fetch batch of messages from server."""
        # records' contain mapping from TP to list of messages.
        # if there are two agents, consuming from topics t1 and t2,
        # normal order of iteration would be to process each
        # tp in the dict:
        #    for tp. messages in records.items():
        #        for message in messages:
        #           yield tp, message
        #
        # The problem with this, is if we have prefetched 16k records
        # for one partition, the other partitions won't even start processing
        # before those 16k records are completed.
        #
        # So we try round-robin between the tps instead:
        #
        #    iterators: Dict[TP, Iterator] = {
        #        tp: iter(messages)
        #        for tp, messages in records.items()
        #    }
        #    while iterators:
        #        for tp, messages in iterators.items():
        #            yield tp, next(messages)
        #            # remove from iterators if empty.
        #
        # The problem with this implementation is that
        # the records mapping is ordered by TP, so records.keys()
        # will look like this:
        #
        #  TP(topic='bar', partition=0)
        #  TP(topic='bar', partition=1)
        #  TP(topic='bar', partition=2)
        #  TP(topic='bar', partition=3)
        #  TP(topic='foo', partition=0)
        #  TP(topic='foo', partition=1)
        #  TP(topic='foo', partition=2)
        #  TP(topic='foo', partition=3)
        #
        # If there are 100 partitions for each topic,
        # it will process 100 items in the first topic, then 100 items
        # in the other topic, but even worse if partition counts
        # vary greatly, t1 has 1000 partitions and t2
        # has 1 partition, then t2 will end up being starved most of the time.
        #
        # We solve this by going round-robin through each topic.
        records, active_partitions = await self._wait_next_records(timeout)
        if records is None or self.should_stop:
            return
        records_it = self.scheduler.iterate(records)
        to_message = self._to_message  # localize
        if self.flow_active:
            for tp, record in records_it:
                if not self.flow_active:
                    break
                if active_partitions is None or tp in active_partitions:
                    highwater_mark = self.highwater(tp)
                    self.app.monitor.track_tp_end_offset(tp, highwater_mark)
                    # convert timestamp to seconds from int milliseconds.
                    yield tp, to_message(tp, record)

    async def _wait_next_records(
            self, timeout: float) -> Tuple[Optional[RecordMap],
                                           Optional[Set[TP]]]:
        # Block while flow is stopped, then fetch the next batch of
        # records for the currently active partitions.
        if not self.flow_active:
            await self.wait(self.can_resume_flow)
        # Implementation for the Fetcher service.
        is_client_only = self.app.client_only
        active_partitions: Optional[Set[TP]]
        if is_client_only:
            active_partitions = None
        else:
            active_partitions = self._get_active_partitions()
        records: RecordMap = {}
        if is_client_only or active_partitions:
            # Fetch records only if active partitions to avoid the risk of
            # fetching all partitions in the beginning when none of the
            # partitions is paused/resumed.
            records = await self._getmany(
                active_partitions=active_partitions,
                timeout=timeout,
            )
        else:
            # We should still release to the event loop
            await self.sleep(1)
        return records, active_partitions

    @abc.abstractmethod
    def _to_message(self, tp: TP, record: Any) -> ConsumerMessage:
        ...

    def track_message(self, message: Message) -> None:
        """Track message and mark it as pending ack."""
        # add to set of pending messages that must be acked for graceful
        # shutdown.  This is called by transport.Conductor,
        # before delivering messages to streams.
        self._unacked_messages.add(message)
        # call sensors
        self._on_message_in(message.tp, message.offset, message)

    def ack(self, message: Message) -> bool:
        """Mark message as being acknowledged by stream."""
        if not message.acked:
            message.acked = True
            tp = message.tp
            offset = message.offset
            if self.app.topics.acks_enabled_for(message.topic):
                committed = self._committed_offset[tp]
                try:
                    # Only record the ack if it is ahead of what has
                    # already been committed for this partition.
                    if committed is None or offset > committed:
                        acked_index = self._acked_index[tp]
                        if offset not in acked_index:
                            self._unacked_messages.discard(message)
                            acked_index.add(offset)
                            acked_for_tp = self._acked[tp]
                            acked_for_tp.append(offset)
                            self._n_acked += 1
                            return True
                finally:
                    # Always wake up anybody blocked in wait_empty().
                    notify(self._waiting_for_ack)
        return False

    async def _wait_for_ack(self, timeout: float) -> None:
        # arm future so that `ack()` can wake us up
        self._waiting_for_ack = asyncio.Future(loop=self.loop)
        try:
            # wait for `ack()` to wake us up
            # NOTE(review): the ``timeout`` argument is not used here;
            # the wait is hard-coded to one second — confirm intended.
            await asyncio.wait_for(
                self._waiting_for_ack, loop=self.loop, timeout=1)
        except (asyncio.TimeoutError,
                asyncio.CancelledError):  # pragma: no cover
            pass
        finally:
            self._waiting_for_ack = None

    @Service.transitions_to(CONSUMER_WAIT_EMPTY)
    async def wait_empty(self) -> None:
        """Wait for all messages that started processing to be acked."""
        wait_count = 0
        T = traced_from_parent_span()
        while not self.should_stop and self._unacked_messages:
            wait_count += 1
            if not wait_count % 10:  # pragma: no cover
                remaining = [(m.refcount, m) for m in self._unacked_messages]
                self.log.warning('wait_empty: Waiting for %r tasks', remaining)
            self.log.dev('STILL WAITING FOR ALL STREAMS TO FINISH')
            self.log.dev('WAITING FOR %r EVENTS', len(self._unacked_messages))
            # Collect garbage so dead messages drop out of the WeakSet.
            gc.collect()
            await T(self.commit)()
            if not self._unacked_messages:
                break
            await T(self._wait_for_ack)(timeout=1)
        self.log.dev('COMMITTING AGAIN AFTER STREAMS DONE')
        await T(self.commit_and_end_transactions)()

    async def commit_and_end_transactions(self) -> None:
        """Commit all safe offsets and end transaction."""
        await self.commit(start_new_transaction=False)

    async def on_stop(self) -> None:
        """Call when consumer is stopping."""
        if self.app.conf.stream_wait_empty:
            await self.wait_empty()
        else:
            await self.commit_and_end_transactions()
        self._last_batch.clear()

    @Service.task
    async def _commit_handler(self) -> None:
        # Background task committing offsets every ``commit_interval``
        # seconds (first commit is delayed by one interval).
        interval = self.commit_interval
        await self.sleep(interval)
        async for sleep_time in self.itertimer(interval, name='commit'):
            await self.commit()

    @Service.task
    async def _commit_livelock_detector(self) -> None:  # pragma: no cover
        # Background task warning when the committed offset has not
        # advanced for longer than the soft timeout.
        soft_timeout = self.commit_livelock_soft_timeout
        interval: float = self.commit_interval * 2.5
        acks_enabled_for = self.app.topics.acks_enabled_for
        await self.sleep(interval)
        async for sleep_time in self.itertimer(interval, name='livelock'):
            for tp, last_batch_time in self._last_batch.items():
                if last_batch_time and acks_enabled_for(tp.topic):
                    s_since_batch = monotonic() - last_batch_time
                    if s_since_batch > soft_timeout:
                        self.log.warning(
                            'Possible livelock: '
                            'COMMIT OFFSET NOT ADVANCING FOR %r', tp)

    async def commit(self, topics: TPorTopicSet = None,
                     start_new_transaction: bool = True) -> bool:
        """Maybe commit the offset for all or specific topics.

        Arguments:
            topics: Set containing topics and/or TopicPartitions to commit.
        """
        if self.app.client_only:
            # client only cannot commit as consumer does not have group_id
            return False
        if await self.maybe_wait_for_commit_to_finish():
            # original commit finished, return False as we did not commit
            return False
        self._commit_fut = asyncio.Future(loop=self.loop)
        try:
            return await self.force_commit(
                topics,
                start_new_transaction=start_new_transaction,
            )
        finally:
            # set commit_fut to None so that next call will commit.
            fut, self._commit_fut = self._commit_fut, None
            # notify followers that the commit is done.
            notify(fut)

    async def maybe_wait_for_commit_to_finish(self) -> bool:
        """Wait for any existing commit operation to finish."""
        # Only one coroutine allowed to commit at a time,
        # and other coroutines should wait for the original commit to finish
        # then do nothing.
        if self._commit_fut is not None:
            # something is already committing so wait for that future.
            try:
                await self._commit_fut
            except asyncio.CancelledError:
                # if future is cancelled we have to start new commit
                pass
            else:
                return True
        return False

    @Service.transitions_to(CONSUMER_COMMITTING)
    async def force_commit(self,
                           topics: TPorTopicSet = None,
                           start_new_transaction: bool = True) -> bool:
        """Force offset commit."""
        sensor_state = self.app.sensors.on_commit_initiated(self)
        # Go over the ack list in each topic/partition
        commit_tps = list(self._filter_tps_with_pending_acks(topics))
        did_commit = await self._commit_tps(
            commit_tps, start_new_transaction=start_new_transaction)
        self.app.sensors.on_commit_completed(self, sensor_state)
        return did_commit

    async def _commit_tps(self,
                          tps: Iterable[TP],
                          start_new_transaction: bool) -> bool:
        # Find committable offsets for ``tps`` and commit them,
        # first flushing any messages attached to those offsets.
        commit_offsets = self._filter_committable_offsets(tps)
        if commit_offsets:
            try:
                # send all messages attached to the new offset
                await self._handle_attached(commit_offsets)
            except ProducerSendError as exc:
                await self.crash(exc)
            else:
                return await self._commit_offsets(
                    commit_offsets,
                    start_new_transaction=start_new_transaction)
        return False

    def _filter_committable_offsets(self, tps: Iterable[TP]) -> Dict[TP, int]:
        commit_offsets = {}
        for tp in tps:
            # Find the latest offset we can commit in this tp
            offset = self._new_offset(tp)
            # check if we can commit to this offset
            if offset is not None and self._should_commit(tp, offset):
                commit_offsets[tp] = offset
        return commit_offsets

    async def _handle_attached(self, commit_offsets: Mapping[TP, int]) -> None:
        for tp, offset in commit_offsets.items():
            app = cast(_App, self.app)
            attachments = app._attachments
            producer = app.producer
            # Start publishing the messages and return a list of pending
            # futures.
            pending = await attachments.publish_for_tp_offset(tp, offset)
            # then we wait for either
            #  1) all the attached messages to be published, or
            #  2) the producer crashing
            #
            # If the producer crashes we will not be able to send any messages
            # and it only crashes when there's an irrecoverable error.
            #
            # If we cannot commit it means the events will be processed again,
            # so conforms to at-least-once semantics.
            if pending:
                await producer.wait_many(pending)

    async def _commit_offsets(self, offsets: Mapping[TP, int],
                              start_new_transaction: bool = True) -> bool:
        table = terminal.logtable(
            [(str(tp), str(offset))
             for tp, offset in offsets.items()],
            title='Commit Offsets',
            headers=['TP', 'Offset'],
        )
        self.log.dev('COMMITTING OFFSETS:\n%s', table)
        assignment = self.assignment()
        # Split requested offsets into partitions we still own
        # (committable) and partitions revoked by a rebalance.
        committable_offsets: Dict[TP, int] = {}
        revoked: Dict[TP, int] = {}
        for tp, offset in offsets.items():
            if tp in assignment:
                committable_offsets[tp] = offset
            else:
                revoked[tp] = offset
        if revoked:
            self.log.info(
                'Discarded commit for revoked partitions that '
                'will be eventually processed again: %r',
                revoked,
            )
        if not committable_offsets:
            return False
        # flight_recorder logs progress markers if the commit stalls.
        with flight_recorder(self.log, timeout=300.0) as on_timeout:
            did_commit = False
            on_timeout.info('+consumer.commit()')
            if self.in_transaction:
                did_commit = await self.transactions.commit(
                    committable_offsets,
                    start_new_transaction=start_new_transaction,
                )
            else:
                did_commit = await self._commit(committable_offsets)
            on_timeout.info('-consumer.commit()')
            if did_commit:
                on_timeout.info('+tables.on_commit')
                self.app.tables.on_commit(committable_offsets)
                on_timeout.info('-tables.on_commit')
        self._committed_offset.update(committable_offsets)
        self.app.monitor.on_tp_commit(committable_offsets)
        # Reset livelock tracking for everything we just committed.
        for tp in offsets:
            self._last_batch.pop(tp, None)
        return did_commit

    def _filter_tps_with_pending_acks(
            self, topics: TPorTopicSet = None) -> Iterator[TP]:
        # Yield partitions with acked offsets, optionally restricted to
        # the given topics/TPs.
        return (tp for tp in self._acked
                if topics is None or tp in topics or tp.topic in topics)

    def _should_commit(self, tp: TP, offset: int) -> bool:
        # Commit only when the offset advances past the committed one.
        committed = self._committed_offset[tp]
        return committed is None or bool(offset) and offset > committed

    def _new_offset(self, tp: TP) -> Optional[int]:
        # get the new offset for this tp, by going through
        # its list of acked messages.
        acked = self._acked[tp]

        # We iterate over it until we find a gap
        # then return the offset before that.
        # For example if acked[tp] is:
        #   1 2 3 4 5 6 7 8 9
        # the return value will be: 9
        # If acked[tp] is:
        #  34 35 36 40 41 42 43 44
        #          ^--- gap
        # the return value will be: 36
        if acked:
            max_offset = max(acked)
            gap_for_tp = self._gap[tp]
            if gap_for_tp:
                # Fold known gap offsets (skipped by the broker) into the
                # acked list so they do not block the commit point.
                gap_index = next((i for i, x in enumerate(gap_for_tp)
                                  if x > max_offset), len(gap_for_tp))
                gaps = gap_for_tp[:gap_index]
                acked.extend(gaps)
                gap_for_tp[:gap_index] = []
                acked.sort()
            # Note: acked is always kept sorted.
            # find first list of consecutive numbers
            batch = next(consecutive_numbers(acked))
            # remove them from the list to clean up.
            acked[:len(batch) - 1] = []
            self._acked_index[tp].difference_update(batch)
            # return the highest commit offset
            return batch[-1]
        return None

    async def on_task_error(self, exc: BaseException) -> None:
        """Call when processing a message failed."""
        await self.commit()

    def _add_gap(self, tp: TP, offset_from: int, offset_to: int) -> None:
        # Record offsets in [offset_from, offset_to) that were skipped in
        # the incoming stream, so _new_offset can commit past them.
        committed = self._committed_offset[tp]
        gap_for_tp = self._gap[tp]
        for offset in range(offset_from, offset_to):
            if committed is None or offset > committed:
                gap_for_tp.append(offset)

    async def _drain_messages(
            self, fetcher: ServiceT) -> None:  # pragma: no cover
        # This is the background thread started by Fetcher, used to
        # constantly read messages using Consumer.getmany.
        # It takes Fetcher as argument, because we must be able to
        # stop it using `await Fetcher.stop()`.
        callback = self.callback
        getmany = self.getmany
        consumer_should_stop = self._stopped.is_set
        fetcher_should_stop = fetcher._stopped.is_set

        # Localize frequently used attributes for the hot loop.
        get_read_offset = self._read_offset.__getitem__
        set_read_offset = self._read_offset.__setitem__
        get_commit_offset = self._committed_offset.__getitem__
        flag_consumer_fetching = CONSUMER_FETCHING
        set_flag = self.diag.set_flag
        unset_flag = self.diag.unset_flag
        commit_every = self._commit_every
        acks_enabled_for = self.app.topics.acks_enabled_for

        try:
            while not (consumer_should_stop() or fetcher_should_stop()):
                set_flag(flag_consumer_fetching)
                ait = cast(AsyncIterator, getmany(timeout=5.0))
                last_batch = self._last_batch

                # Sleeping because sometimes getmany is called in a loop
                # never releasing to the event loop
                await self.sleep(0)

                if not self.should_stop:
                    async for tp, message in ait:
                        offset = message.offset
                        r_offset = get_read_offset(tp)
                        committed_offset = get_commit_offset(tp)
                        if committed_offset != r_offset:
                            last_batch[tp] = monotonic()
                        if r_offset is None or offset > r_offset:
                            gap = offset - (r_offset or 0)
                            # We have a gap in incoming messages
                            if gap > 1 and r_offset:
                                acks_enabled = acks_enabled_for(message.topic)
                                if acks_enabled:
                                    self._add_gap(tp, r_offset + 1, offset)
                            if commit_every is not None:
                                if self._n_acked >= commit_every:
                                    self._n_acked = 0
                                    await self.commit()
                            await callback(message)
                            set_read_offset(tp, offset)
                        else:
                            self.log.dev('DROPPED MESSAGE ROFF %r: k=%r v=%r',
                                         offset, message.key, message.value)
                unset_flag(flag_consumer_fetching)
        except self.consumer_stopped_errors:
            if self.transport.app.should_stop:
                # we're already stopping so ignore
                self.log.info('Broker stopped consumer, shutting down...')
                return
            raise
        except asyncio.CancelledError:
            if self.transport.app.should_stop:
                # we're already stopping so ignore
                self.log.info('Consumer shutting down for user cancel.')
                return
            raise
        except Exception as exc:
            self.log.exception('Drain messages raised: %r', exc)
            raise
        finally:
            unset_flag(flag_consumer_fetching)

    def close(self) -> None:
        """Close consumer for graceful shutdown."""
        ...

    @property
    def unacked(self) -> Set[Message]:
        """Return the set of currently unacknowledged messages."""
        return cast(Set[Message], self._unacked_messages)
class ConsumerThread(QueueServiceThread):
    """Consumer implementation running in a dedicated thread.

    Defines the abstract API the owning :class:`ThreadDelegateConsumer`
    forwards its requests to, plus the rebalancing callbacks that relay
    partition changes back to the main thread.
    """

    #: The app (via ``self.transport.app``).
    app: AppT

    #: The consumer that delegates its work to this thread.
    consumer: 'ThreadDelegateConsumer'

    #: The transport that created the consumer.
    transport: TransportT

    def __init__(self, consumer: ConsumerT, **kwargs: Any) -> None:
        self.consumer = consumer
        self.transport = self.consumer.transport
        self.app = self.transport.app
        super().__init__(**kwargs)

    @abc.abstractmethod
    async def subscribe(self, topics: Iterable[str]) -> None:
        """Reset subscription (requires rebalance)."""
        ...

    @abc.abstractmethod
    async def seek_to_committed(self) -> Mapping[TP, int]:
        """Seek all partitions to their committed offsets."""
        ...

    @abc.abstractmethod
    async def commit(self, tps: Mapping[TP, int]) -> bool:
        """Commit offsets in topic partitions."""
        ...

    @abc.abstractmethod
    async def position(self, tp: TP) -> Optional[int]:
        """Return the current offset for partition."""
        ...

    @abc.abstractmethod
    async def seek_to_beginning(self, *partitions: TP) -> None:
        """Seek to the earliest offsets available for partitions."""
        ...

    @abc.abstractmethod
    async def seek_wait(self, partitions: Mapping[TP, int]) -> None:
        """Seek partitions to specific offsets and wait."""
        ...

    @abc.abstractmethod
    def seek(self, partition: TP, offset: int) -> None:
        """Seek partition to specific offset."""
        ...

    @abc.abstractmethod
    def assignment(self) -> Set[TP]:
        """Return the current assignment."""
        ...

    @abc.abstractmethod
    def highwater(self, tp: TP) -> int:
        """Return the last available offset in partition."""
        ...

    @abc.abstractmethod
    def topic_partitions(self, topic: str) -> Optional[int]:
        """Return number of configured partitions for topic by name."""
        ...

    @abc.abstractmethod
    async def earliest_offsets(self, *partitions: TP) -> Mapping[TP, int]:
        """Return the earliest available offset for list of partitions."""
        ...

    @abc.abstractmethod
    async def highwaters(self, *partitions: TP) -> Mapping[TP, int]:
        """Return the last available offset for list of partitions."""
        ...

    @abc.abstractmethod
    async def getmany(self,
                      active_partitions: Optional[Set[TP]],
                      timeout: float) -> RecordMap:
        """Fetch batch of messages from server."""
        ...

    @abc.abstractmethod
    async def create_topic(self,
                           topic: str,
                           partitions: int,
                           replication: int,
                           *,
                           # PEP 484: defaults of None require explicit
                           # Optional[...] annotations.
                           config: Optional[Mapping[str, Any]] = None,
                           timeout: Seconds = 30.0,
                           retention: Optional[Seconds] = None,
                           compacting: Optional[bool] = None,
                           deleting: Optional[bool] = None,
                           ensure_created: bool = False) -> None:
        """Create/declare topic on server."""
        ...

    async def on_partitions_revoked(
            self, revoked: Set[TP]) -> None:
        """Call on rebalance when partitions are being revoked."""
        # Relay the callback to the owning consumer on its own loop.
        await self.consumer.threadsafe_partitions_revoked(
            self.thread_loop, revoked)

    async def on_partitions_assigned(
            self, assigned: Set[TP]) -> None:
        """Call on rebalance when partitions are being assigned."""
        # Relay the callback to the owning consumer on its own loop.
        await self.consumer.threadsafe_partitions_assigned(
            self.thread_loop, assigned)

    @abc.abstractmethod
    def key_partition(self,
                      topic: str,
                      key: Optional[bytes],
                      partition: Optional[int] = None) -> Optional[int]:
        """Hash key to determine partition number."""
        ...
class ThreadDelegateConsumer(Consumer):
    """Consumer that delegates all broker operations to a separate thread.

    Every public consumer operation is forwarded to an embedded
    :class:`ConsumerThread`; rebalancing callbacks coming from that
    thread are marshalled back to the main thread via a method queue.
    """

    #: The thread that owns the real (blocking) consumer.
    _thread: ConsumerThread

    #: Main thread method queue.
    #: The consumer is running in a separate thread, and so we send
    #: requests to it via a queue.
    #: Sometimes the thread needs to call code owned by the main thread,
    #: such as App.on_partitions_revoked, and in that case the thread
    #: uses this method queue.
    _method_queue: MethodQueue

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._method_queue = MethodQueue(loop=self.loop, beacon=self.beacon)
        self.add_dependency(self._method_queue)
        self._thread = self._new_consumer_thread()
        self.add_dependency(self._thread)

    @abc.abstractmethod
    def _new_consumer_thread(self) -> ConsumerThread:
        """Create the thread that runs the actual consumer."""
        ...

    async def threadsafe_partitions_revoked(
            self,
            receiver_loop: asyncio.AbstractEventLoop,
            revoked: Set[TP]) -> None:
        """Call rebalancing callback in a thread-safe manner."""
        promise = await self._method_queue.call(
            receiver_loop.create_future(),
            self.on_partitions_revoked,
            revoked,
        )
        # wait for main-thread to finish processing request
        await promise

    async def threadsafe_partitions_assigned(
            self,
            receiver_loop: asyncio.AbstractEventLoop,
            assigned: Set[TP]) -> None:
        """Call rebalancing callback in a thread-safe manner."""
        promise = await self._method_queue.call(
            receiver_loop.create_future(),
            self.on_partitions_assigned,
            assigned,
        )
        # wait for main-thread to finish processing request
        await promise

    async def _getmany(self,
                       active_partitions: Optional[Set[TP]],
                       timeout: float) -> RecordMap:
        return await self._thread.getmany(active_partitions, timeout)

    async def subscribe(self, topics: Iterable[str]) -> None:
        """Reset subscription (requires rebalance)."""
        await self._thread.subscribe(topics=topics)

    async def seek_to_committed(self) -> Mapping[TP, int]:
        """Seek all partitions to the committed offset."""
        return await self._thread.seek_to_committed()

    async def position(self, tp: TP) -> Optional[int]:
        """Return the current position for partition."""
        return await self._thread.position(tp)

    async def seek_wait(self, partitions: Mapping[TP, int]) -> None:
        """Seek partitions to specific offsets and wait."""
        return await self._thread.seek_wait(partitions)

    async def _seek(self, partition: TP, offset: int) -> None:
        self._thread.seek(partition, offset)

    def assignment(self) -> Set[TP]:
        """Return the current assignment."""
        return self._thread.assignment()

    def highwater(self, tp: TP) -> int:
        """Return the last available offset for specific partition."""
        return self._thread.highwater(tp)

    def topic_partitions(self, topic: str) -> Optional[int]:
        """Return the number of partitions configured for topic by name."""
        return self._thread.topic_partitions(topic)

    async def earliest_offsets(self, *partitions: TP) -> Mapping[TP, int]:
        """Return the earliest offsets for a list of partitions."""
        return await self._thread.earliest_offsets(*partitions)

    async def highwaters(self, *partitions: TP) -> Mapping[TP, int]:
        """Return the last offset for a list of partitions."""
        return await self._thread.highwaters(*partitions)

    async def _commit(self, offsets: Mapping[TP, int]) -> bool:
        return await self._thread.commit(offsets)

    def close(self) -> None:
        """Close consumer for graceful shutdown."""
        self._thread.close()

    def key_partition(self,
                      topic: str,
                      key: Optional[bytes],
                      # PEP 484: default of None requires Optional[int].
                      partition: Optional[int] = None) -> Optional[int]:
        """Hash key to determine partition number."""
        return self._thread.key_partition(topic, key, partition=partition)