"""Streams."""
import asyncio
import reprlib
import typing
import weakref
from asyncio import CancelledError
from contextvars import ContextVar
from time import monotonic
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Callable,
Iterable,
Iterator,
List,
Mapping,
MutableSequence,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from mode import Seconds, Service, get_logger, want_seconds
from mode.utils.aiter import aenumerate, aiter
from mode.utils.compat import current_task
from mode.utils.futures import maybe_async, notify
from mode.utils.objects import cached_property
from mode.utils.types.trees import NodeT
from . import joins
from .exceptions import ImproperlyConfigured
from .types import AppT, ConsumerT, EventT, K, ModelArg, ModelT, TP, TopicT
from .types.joins import JoinT
from .types.models import FieldDescriptorT
from .types.streams import (
GroupByKeyArg,
JoinableT,
Processor,
StreamT,
T,
T_co,
T_contra,
)
from .types.topics import ChannelT
from .types.tuples import Message
# Public names of this module.
__all__ = [
'Stream',
'current_event',
]
logger = get_logger(__name__)
# Annotation-only declaration, guarded so that only static type checkers
# see it.
if typing.TYPE_CHECKING: # pragma: no cover
_current_event: ContextVar[weakref.ReferenceType[EventT]]
# Context variable holding a weak reference to the event being processed.
# It is set in Stream.__aiter__ and read by current_event().
_current_event = ContextVar('current_event')
def current_event() -> Optional[EventT]:
    """Return the event currently being processed, or None.

    The event is looked up through the ``_current_event`` context
    variable, which stores a weak reference.  ``None`` is returned both
    when the variable is unset and when the weak reference is dead.
    """
    ref = _current_event.get(None)  # type: ignore
    if ref is None:
        return None
    return ref()
async def maybe_forward(value: Any, channel: ChannelT) -> Any:
    """Deliver `value` to `channel` and return `value`.

    Events are delivered with their own ``forward`` method; anything
    else is delivered with ``channel.send(value=...)``.
    """
    if not isinstance(value, EventT):
        await channel.send(value=value)
        return value
    await value.forward(channel)
    return value
class _LinkedListDirection(NamedTuple):
    """Describes one direction in which the stream linked list is walked."""

    # Name of the attribute that links to the neighbouring stream.
    attr: str
    # Callable returning the neighbouring stream of a node (or None).
    getter: Callable[[StreamT], Optional[StreamT]]


# Towards derived streams (``_next``).
_LinkedListDirectionFwd = _LinkedListDirection(
    attr='_next',
    getter=lambda node: node._next,
)
# Towards the stream this one was derived from (``_prev``).
_LinkedListDirectionBwd = _LinkedListDirection(
    attr='_prev',
    getter=lambda node: node._prev,
)
[docs]class Stream(StreamT[T_co], Service):
"""A stream: async iterator processing events in channels/topics."""
logger = logger
mundane_level = 'debug'
# Processor callbacks applied, in order, to every value received.
_processors: MutableSequence[Processor]
_anext_started = False
# True while a background drainer consumes this stream (see _enable_passive).
_passive = False
# Set when iteration starts; through/group_by/combine then return self.
_finalized = False
# Set by the passive drainer just before it starts iterating the stream.
_passive_started: asyncio.Event
def __init__(self,
             channel: AsyncIterator[T_co],
             *,
             app: AppT,
             processors: Iterable[Processor[T]] = None,
             combined: List[JoinableT] = None,
             on_start: Callable = None,
             join_strategy: JoinT = None,
             beacon: NodeT = None,
             concurrency_index: int = None,
             prev: StreamT = None,
             active_partitions: Set[TP] = None,
             enable_acks: bool = True,
             loop: asyncio.AbstractEventLoop = None) -> None:
    """Create a stream iterating over `channel`.

    Arguments:
        channel: Async iterator the stream reads values from.
        app: The app this stream belongs to.
        processors: Initial processor callbacks (copied into a list).
        combined: Nodes combined with this stream.
        on_start: Awaitable callback invoked from :meth:`on_start`.
        join_strategy: Strategy applied by :meth:`on_merge`.
        beacon: Passed on to the ``Service`` initializer.
        concurrency_index: Stored as-is; ``through``/``group_by``
            refuse to run when it is set.
        prev: The stream this stream was derived from.
        active_partitions: Stored as-is.
        enable_acks: Whether events are acked while iterating.
        loop: Passed on to the ``Service`` initializer.
    """
    Service.__init__(self, loop=loop, beacon=beacon)
    self.app = app
    self.channel = channel

    # Buffer sized from the app configuration.
    self.outbox = self.app.FlowControlQueue(
        maxsize=self.app.conf.stream_buffer_maxsize,
        loop=self.loop,
        clear_on_resume=True,
    )
    self._passive_started = asyncio.Event(loop=self.loop)

    self.join_strategy = join_strategy
    self.combined = [] if combined is None else combined
    self.concurrency_index = concurrency_index
    self._prev = prev
    self.active_partitions = active_partitions
    self.enable_acks = enable_acks
    self._processors = [] if not processors else list(processors)
    self._on_start = on_start

    # Remember the task that created the stream, when there is one.
    owner = current_task(loop=self.loop)
    if owner is not None:
        self.task_owner = owner

    # Cache the sensor callbacks used while iterating and acking.
    self._on_stream_event_in = self.app.sensors.on_stream_event_in
    self._on_stream_event_out = self.app.sensors.on_stream_event_out
    self._on_message_out = self.app.sensors.on_message_out
def get_active_stream(self) -> StreamT:
    """Return the stream currently being consumed from.

    Methods such as ``group_by`` and ``through`` derive a new stream
    and link it to this one through the ``_next`` attribute.  The
    active stream is the last node of that linked list, e.g.::

        >>> @app.agent()
        ... async def agent(a):
        ...     b = a.group_by(Withdrawal.account_id)
        ...     c = b.through('backup_topic')
        ...     async for value in c:
        ...         ...

    Here ``a.get_active_stream()`` returns ``c``.
    """
    chain = list(self._iter_ll_forwards())
    return chain.pop()
def get_root_stream(self) -> StreamT:
    """Return the first stream of the chain this stream belongs to.

    Follows the ``_prev`` links and returns the last node reached.
    """
    ancestors = list(self._iter_ll_backwards())
    return ancestors.pop()
def _iter_ll_forwards(self) -> Iterator[StreamT]:
    """Iterate over this stream and the streams linked through ``_next``."""
    direction = _LinkedListDirectionFwd
    return self._iter_ll(direction)
def _iter_ll_backwards(self) -> Iterator[StreamT]:
    """Iterate over this stream and the streams linked through ``_prev``."""
    direction = _LinkedListDirectionBwd
    return self._iter_ll(direction)
def _iter_ll(self, dir_: _LinkedListDirection) -> Iterator[StreamT]:
    """Walk the linked list of streams, starting at this stream.

    Arguments:
        dir_: Direction to walk in; its ``getter`` returns the next
            node to visit and its ``attr`` names the link attribute.

    Yields:
        Every stream in the chain, this stream first.

    Raises:
        RuntimeError: if a stream is reached twice (the list has a loop).
    """
    node: Optional[StreamT] = self
    seen: Set[StreamT] = set()
    while node:
        if node in seen:
            # f-string, so the message names the attribute at fault.
            raise RuntimeError(
                f'Loop in Stream.{dir_.attr}: Call support!')
        seen.add(node)
        yield node
        node = dir_.getter(node)
async def _send_to_outbox(self, value: T_contra) -> None:
    """Put `value` on the outbox queue; do nothing if there is no outbox."""
    outbox = self.outbox
    if outbox is None:
        return
    await outbox.put(value)
def add_processor(self, processor: Processor[T]) -> None:
    """Register a callback applied to every value received.

    A processor takes the value as its single argument and returns
    the (possibly modified) value.  It may be a regular or an async
    function.  Example::

        def double(value: int) -> int:
            return value * 2

        stream.add_processor(double)
    """
    registered = self._processors
    registered.append(processor)
def info(self) -> Mapping[str, Any]:
    """Return the settings of this stream as a dictionary.

    The keys match keyword arguments of the constructor; :meth:`clone`
    uses this mapping to build a copy of the stream.
    """
    return dict(
        app=self.app,
        channel=self.channel,
        processors=self._processors,
        on_start=self._on_start,
        loop=self.loop,
        combined=self.combined,
        beacon=self.beacon,
        concurrency_index=self.concurrency_index,
        prev=self._prev,
        active_partitions=self.active_partitions,
    )
def clone(self, **kwargs: Any) -> StreamT:
    """Create a copy of this stream.

    The copy is built from :meth:`info`, with any keyword arguments
    given here taking precedence.

    Notes:
        If the copy is meant to supersede this stream (as done by
        ``group_by``/``through``), use :meth:`_chain` instead: it also
        sets ``_next`` so :meth:`get_active_stream` finds the copy.
    """
    settings = dict(self.info())
    settings.update(kwargs)
    return self.__class__(**settings)
def _chain(self, **kwargs: Any) -> StreamT:
    """Derive a stream that supersedes this one and link it as ``_next``.

    The derived stream takes over this stream's processors, has this
    stream as ``prev`` and starts this stream from its ``on_start``.
    """
    assert not self._finalized
    # The processors move to the new (active) stream ...
    moved = list(self._processors)
    successor = self.clone(
        on_start=self.maybe_start,
        prev=self,
        processors=moved,
        **kwargs,
    )
    self._next = successor
    # ... and are removed from this one.
    self._processors.clear()
    return successor
def noack(self) -> 'StreamT':
    """Return a clone of this stream created with ``enable_acks=False``.

    The clone is also stored as this stream's ``_next``.
    """
    unacked = self.clone(enable_acks=False)
    self._next = unacked
    return unacked
async def items(self) -> AsyncIterator[Tuple[K, T_co]]:
    """Iterate over the stream as ``(key, value)`` pairs.

    Examples:
        .. sourcecode:: python

            @app.agent(topic)
            async def mytask(stream):
                async for key, value in stream.items():
                    print(key, value)
    """
    async for ev in self.events():
        pair = (ev.key, cast(T_co, ev.value))
        yield pair
async def events(self) -> AsyncIterable[EventT]:
    """Iterate over the stream as events rather than values.

    Values that did not arrive as an event (``current_event`` is None
    for them) are skipped, so this is only useful when the stream
    iterates over a channel, or an iterable of events.
    """
    async for _ in self:  # noqa: F841
        event = self.current_event
        if event is None:
            continue
        yield event
[docs] async def take(self, max_: int,
within: Seconds) -> AsyncIterable[Sequence[T_co]]:
"""Buffer n values at a time and yield a list of buffered values.
Arguments:
max_: Maximum number of values in each yielded list.
within: Timeout for when we give up waiting for another value,
and process the values we have.
Warning: If there's no timeout (i.e. `within` is falsy),
the agent is likely to stall and block buffered events for an
unreasonable length of time(!).
Notes:
The events collected for a batch are acked when the generator
resumes (or is closed) after yielding that batch.
"""
buffer: List[T_co] = []
events: List[EventT] = []
# localize bound methods used by the processor below.
buffer_add = buffer.append
event_add = events.append
buffer_size = buffer.__len__
buffer_full = asyncio.Event(loop=self.loop)
buffer_consumed = asyncio.Event(loop=self.loop)
# a falsy `within` means wait without a timeout.
timeout = want_seconds(within) if within else None
buffer_consuming: Optional[asyncio.Future] = None
channel_it = aiter(self.channel)
# We add this processor to populate the buffer, and the stream
# is passively consumed in the background (enable_passive below).
async def add_to_buffer(value: T) -> T:
# buffer_consuming is set while the loop below is handing out the
# buffer: wait for it to finish before adding more.
nonlocal buffer_consuming
if buffer_consuming is not None:
try:
await buffer_consuming
finally:
buffer_consuming = None
buffer_add(cast(T_co, value))
# remember the event (if any) so it can be acked with its batch.
event = self.current_event
if event is not None:
event_add(event)
if buffer_size() >= max_:
# signal that the buffer is full and should be emptied.
buffer_full.set()
# strict wait for buffer to be consumed after buffer full.
# (if max_ is 1000, we are not allowed to return 1001 values.)
buffer_consumed.clear()
await self.wait(buffer_consumed)
return value
self.add_processor(add_to_buffer)
self._enable_passive(cast(ChannelT, channel_it))
while not self.should_stop:
# wait until buffer full, or timeout
await self.wait_for_stopped(buffer_full, timeout=timeout)
if buffer:
# make sure the background drainer does not add new items to
# the buffer while we read.
buffer_consuming = self.loop.create_future()
try:
# yield a copy: the buffer itself is cleared below.
yield list(buffer)
finally:
buffer.clear()
for event in events:
await self.ack(event)
events.clear()
# allow writing to buffer again
notify(buffer_consuming)
buffer_full.clear()
buffer_consumed.set()
def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]:
    """Enumerate the values received on this stream.

    Works like the built-in ``enumerate``, but for async iteration:
    the result is consumed with ``async for``.

    Arguments:
        start: Passed on to ``aenumerate`` as the start argument.
    """
    numbered = aenumerate(self, start)
    return numbered
def through(self, channel: Union[str, ChannelT]) -> StreamT:
    """Forward values in this stream to a channel.

    Send messages received on this stream to another channel,
    and return a new stream that consumes from that channel.

    Arguments:
        channel: Channel to forward to, or the name of a topic that
            is derived from this stream with :meth:`derive_topic`.

    Raises:
        ImproperlyConfigured: if ``concurrency_index`` is set, or if
            this stream is already using group_by/through.

    Notes:
        The messages are forwarded after any processors have been
        applied.

    Example:
        .. sourcecode:: python

            topic = app.topic('foo')

            @app.agent(topic)
            async def mytask(stream):
                async for value in stream.through(app.topic('bar')):
                    # value was first received in topic 'foo',
                    # then forwarded and consumed from topic 'bar'
                    print(value)
    """
    if self._finalized:
        # When an agent restarts the same stream object is reused, and
        # it has already been chained by an earlier through() call.
        # Iterating set the finalized flag, so this call is a no-op.
        return self
    if self.concurrency_index is not None:
        raise ImproperlyConfigured(
            'Agent with concurrency>1 cannot use stream.through!')
    # A separate name keeps the narrowed type apart from the
    # `channel` argument, which may be a str (for mypy).
    if isinstance(channel, str):
        channelchannel = cast(ChannelT, self.derive_topic(channel))
    else:
        channelchannel = channel
    channel_it = aiter(channelchannel)
    if self._next is not None:
        raise ImproperlyConfigured(
            'Stream is already using group_by/through')
    through = self._chain(channel=channel_it)

    async def forward(value: T) -> T:
        # Forward the whole event when there is one.  A value that did
        # not arrive as an event has no current event: forward the
        # value itself in that case, instead of forwarding None.
        event = self.current_event
        return await maybe_forward(
            event if event is not None else value, channelchannel)

    self.add_processor(forward)
    self._enable_passive(cast(ChannelT, channel_it), declare=True)
    return through
def _enable_passive(self, channel: ChannelT, *,
                    declare: bool = False) -> None:
    """Start draining this stream in the background, at most once.

    Does nothing when the stream is already marked as passive;
    otherwise marks it and schedules :meth:`_passive_drainer`.
    """
    if self._passive:
        return
    self._passive = True
    drainer = self._passive_drainer(channel, declare)
    self.add_future(drainer)
async def _passive_drainer(self, channel: ChannelT,
                           declare: bool = False) -> None:
    """Consume this stream in the background, discarding the values.

    Arguments:
        channel: Destination channel; it is declared first when
            `declare` is true, and receives any exception raised
            while iterating.
        declare: Whether to call ``channel.maybe_declare()`` first.
    """
    try:
        if declare:
            await channel.maybe_declare()
        self._passive_started.set()
        try:
            # Only the processors' side effects matter here.
            async for _ in self:  # noqa
                pass
        except BaseException as error:
            # Hand the exception to the final destination channel
            # (e.g. the one created by through/group_by).
            await channel.throw(error)
    finally:
        self._channel_stop_iteration(channel)
        self._passive = False
def _channel_stop_iteration(self, channel: Any) -> None:
    """Call ``channel.on_stop_iteration()`` if the channel defines it."""
    # A private sentinel tells "attribute missing" apart from any
    # value the attribute could have.
    missing = object()
    callback = getattr(channel, 'on_stop_iteration', missing)
    if callback is not missing:
        callback()
def echo(self, *channels: Union[str, ChannelT]) -> StreamT:
    """Forward values to one or more channels.

    Unlike :meth:`through`, we don't consume from these channels.

    Arguments:
        *channels: Channels to forward to.  A string is turned into
            a topic with :meth:`derive_topic`.

    Returns:
        This stream, so calls can be chained.
    """
    _channels = [
        self.derive_topic(c) if isinstance(c, str) else c for c in channels
    ]
    if not _channels:
        # Nothing to forward to.  asyncio.wait() raises ValueError for
        # an empty collection, so installing the processor would make
        # every value fail.
        return self

    async def echoing(value: T) -> T:
        # Send to all channels concurrently; wait until all are done.
        await asyncio.wait(
            [maybe_forward(value, channel) for channel in _channels],
            loop=self.loop,
            return_when=asyncio.ALL_COMPLETED,
        )
        return value

    self.add_processor(echoing)
    return self
def group_by(self,
             key: GroupByKeyArg,
             *,
             name: str = None,
             topic: TopicT = None,
             partitions: int = None) -> StreamT:
    """Create a new stream that repartitions this stream by a new key.

    Arguments:
        key: Decides how the new key is generated: a field
            descriptor, a callable, or an async callable.
        name: Suffix used for the repartition topic.  Required when
            `key` is a callable; when `key` is a field descriptor it
            defaults to the descriptor's ``ident``.
        topic: Channel to repartition through.  When omitted, one is
            derived from this stream's channel.
        partitions: Partition count for the derived channel; falls
            back to the app's ``topic_partitions`` setting.  Only
            used when `topic` is omitted.

    Raises:
        ImproperlyConfigured: if ``concurrency_index`` is set, or if
            this stream already uses group_by/through.
        TypeError: if `key` is not a field descriptor and no `name`
            was given.

    Examples:
        Using a field descriptor, so a field of the event becomes the
        new key:

        .. sourcecode:: python

            s = withdrawals_topic.stream()
            # values in this stream are of type Withdrawal
            async for event in s.group_by(Withdrawal.account_id):
                ...

        Using an async callable to extract the new key:

        .. sourcecode:: python

            s = withdrawals_topic.stream()

            async def get_key(withdrawal):
                return await aiohttp.get(
                    f'http://e.com/resolve_account/{withdrawal.account_id}')

            async for event in s.group_by(get_key):
                ...

        Using a regular callable to extract the new key:

        .. sourcecode:: python

            s = withdrawals_topic.stream()

            def get_key(withdrawal):
                return withdrawal.account_id.upper()

            async for event in s.group_by(get_key):
                ...
    """
    if self._finalized:
        # The stream has already been iterated over (e.g. the agent
        # restarted and reuses this object): nothing to do.
        return self
    if self.concurrency_index is not None:
        raise ImproperlyConfigured(
            'Agent with concurrency>1 cannot use stream.group_by!')
    if not name:
        if not isinstance(key, FieldDescriptorT):
            raise TypeError(
                'group_by with callback must set name=topic_suffix')
        name = cast(FieldDescriptorT, key).ident

    channel: ChannelT
    if topic is None:
        suffix = '-' + self.app.conf.id + '-' + name + '-repartition'
        channel = cast(ChannelT, self.channel).derive(
            suffix=suffix,
            partitions=partitions or self.app.conf.topic_partitions,
            internal=True,
        )
    else:
        channel = topic

    format_key = self._format_key
    channel_it = aiter(channel)
    if self._next is not None:
        raise ImproperlyConfigured('Stream already uses group_by/through')
    grouped = self._chain(channel=channel_it)

    async def repartition(value: T) -> T:
        # Re-send the current event to the new channel under the
        # key computed from the value.
        source_event = self.current_event
        if source_event is None:
            raise RuntimeError(
                'Cannot repartition stream with non-topic channel')
        rekeyed = await format_key(key, value)
        await source_event.forward(channel, key=rekeyed)
        return value

    self.add_processor(repartition)
    self._enable_passive(cast(ChannelT, channel_it), declare=True)
    return grouped
async def _format_key(self, key: GroupByKeyArg, value: T_contra) -> str:
    """Compute the new key for `value`.

    A field descriptor reads the key from the value; anything else is
    called with the value, and the result awaited if necessary.
    """
    if not isinstance(key, FieldDescriptorT):
        return await maybe_async(cast(Callable, key)(value))
    descriptor = cast(FieldDescriptorT, key)
    return descriptor.getattr(cast(ModelT, value))
def derive_topic(self,
                 name: str,
                 *,
                 key_type: ModelArg = None,
                 value_type: ModelArg = None,
                 prefix: str = '',
                 suffix: str = '') -> TopicT:
    """Create a topic description derived from this stream's topic.

    Arguments:
        name: Topic name.
        key_type: Specific key type to use for this topic.
            If not set, the key type of this stream will be used.
        value_type: Specific value type to use for this topic.
            If not set, the value type of this stream will be used.
        prefix: Passed on to the channel's ``derive_topic``.
        suffix: Passed on to the channel's ``derive_topic``.

    Raises:
        ValueError: if the stream channel is not a topic.
    """
    source = self.channel
    if not isinstance(source, TopicT):
        raise ValueError('Cannot derive topic from non-topic channel.')
    return source.derive_topic(
        topics=[name],
        key_type=key_type,
        value_type=value_type,
        prefix=prefix,
        suffix=suffix,
    )
async def throw(self, exc: BaseException) -> None:
    """Pass `exc` on to this stream's channel through its ``throw()``."""
    target = cast(ChannelT, self.channel)
    await target.throw(exc)
def combine(self, *nodes: JoinableT, **kwargs: Any) -> StreamT:
    """Combine this stream with other nodes into a new stream.

    A combined stream is made of several streams that all share the
    same outbox.  Its ``on_merge`` callback can then process values
    from all of them; joins use this to consolidate multiple values
    into one.

    Returns:
        The new combined stream, or this stream unchanged if it has
        already been iterated over.
    """
    if self._finalized:
        # Already iterated over (e.g. reused after an agent restart).
        return self
    members = self.combined + list(nodes)
    merged = self._chain(combined=members)
    for member in merged.combined:
        member.contribute_to_stream(merged)
    return merged
def contribute_to_stream(self, active: StreamT) -> None:
    """Make this stream use the outbox of the `active` stream."""
    shared_outbox = active.outbox
    self.outbox = shared_outbox
async def remove_from_stream(self, stream: StreamT) -> None:
    """Stop this stream; `stream` itself is not used."""
    await self.stop()
def join(self, *fields: FieldDescriptorT) -> StreamT:
    """Return a clone of this stream using ``joins.RightJoin`` on `fields`."""
    strategy = joins.RightJoin(stream=self, fields=fields)
    return self._join(strategy)
def left_join(self, *fields: FieldDescriptorT) -> StreamT:
    """Return a clone of this stream using ``joins.LeftJoin`` on `fields`."""
    strategy = joins.LeftJoin(stream=self, fields=fields)
    return self._join(strategy)
def inner_join(self, *fields: FieldDescriptorT) -> StreamT:
    """Return a clone of this stream using ``joins.InnerJoin`` on `fields`."""
    strategy = joins.InnerJoin(stream=self, fields=fields)
    return self._join(strategy)
def outer_join(self, *fields: FieldDescriptorT) -> StreamT:
    """Return a clone of this stream using ``joins.OuterJoin`` on `fields`."""
    strategy = joins.OuterJoin(stream=self, fields=fields)
    return self._join(strategy)
def _join(self, join_strategy: JoinT) -> StreamT:
    """Return a clone of this stream that uses `join_strategy`."""
    joined = self.clone(join_strategy=join_strategy)
    return joined
async def on_merge(self, value: T = None) -> Optional[T]:
    """Run `value` through the join strategy, if one is configured.

    The strategy's ``process`` method may return None to swallow the
    value, e.g. to build a merged event out of this and a later event.
    Without a join strategy the value is returned unchanged.
    """
    # TODO for joining streams
    strategy = self.join_strategy
    if not strategy:
        return value
    return await strategy.process(value)
async def send(self, value: T_contra) -> None:
    """Send value into stream locally (bypasses topic).

    Raises:
        NotImplementedError: if the stream's channel is not a channel
            object (so the value cannot be put on it).
    """
    target = self.channel
    if not isinstance(target, ChannelT):
        raise NotImplementedError(
            'Cannot send to non-topic channel stream.')
    await cast(ChannelT, target).put(value)
async def on_start(self) -> None:
    """Run the ``on_start`` callback and wait for the passive drainer.

    The callback given to the constructor is awaited when set.  If the
    stream is passive, this then waits until ``_passive_started`` is set.
    """
    callback = self._on_start
    if callback:
        await callback()
    if not self._passive:
        return
    await self._passive_started.wait()
async def stop(self) -> None:
    """Stop this stream and all related streams.

    Related streams are the ones created by ``through``/``group_by``
    etc.: every stream reachable forwards from the root stream.
    """
    root = cast(Stream, self.get_root_stream())
    for related in root._iter_ll_forwards():
        await Service.stop(related)
async def on_stop(self) -> None:
    """Reset the passive state and detach from all combined nodes."""
    self._passive = False
    self._passive_started.clear()
    for node in self.combined:
        await node.remove_from_stream(self)
def __iter__(self) -> Any:
    """Return self; ``__next__`` then rejects synchronous iteration."""
    return self
def __next__(self) -> Any:
    """Always raise: a stream can only be iterated with ``async for``."""
    reason = 'Streams are asynchronous: use `async for`'
    raise NotImplementedError(reason)
async def __aiter__(self) -> AsyncIterator:
"""Iterate over the stream, yielding one processed value at a time.
Marks the stream as finalized and starts it if needed.  Every value
read from the channel goes through the processors and ``on_merge``
before it is yielded; a value for which ``on_merge`` returns None is
not yielded.  When the consumer is done with a yielded value, the event
it came from (if any) is acked, unless acks are disabled.
"""
self._finalized = True
loop = self.loop
await self.maybe_start()
# localize attributes used in the loop below.
on_merge = self.on_merge
on_stream_event_out = self._on_stream_event_out
on_message_out = self._on_message_out
# get from channel
channel = self.channel
if isinstance(channel, ChannelT):
chan_is_channel = True
chan = cast(ChannelT, self.channel)
chan_queue = chan.queue
chan_queue_empty = chan_queue.empty
chan_errors = chan_queue._errors
chan_quick_get = chan_queue.get_nowait
else:
# not a channel: only the slow path (``__anext__``) is available.
chan_is_channel = False
chan_queue = None
chan_queue_empty = None
chan_errors = None
chan_quick_get = None
chan_slow_get = channel.__anext__
# Topic description -> processors
processors = self._processors
# Sensor: on_stream_event_in
on_stream_event_in = self._on_stream_event_in
# localize global variables
create_ref = weakref.ref
_maybe_async = maybe_async
event_cls = EventT
_current_event_contextvar = _current_event
ack_exceptions = self.app.conf.stream_ack_exceptions
ack_cancelled_tasks = self.app.conf.stream_ack_cancelled_tasks
consumer: ConsumerT = self.app.consumer
unacked: Set[Message] = consumer._unacked_messages
add_unacked: Callable[[Message], None] = unacked.add
acking_topics: Set[str] = self.app.topics._acking_topics
on_message_in = self.app.sensors.on_message_in
sleep = asyncio.sleep
try:
while not self.should_stop:
event = None
do_ack = self.enable_acks # set to False to not ack event.
# wait for next message
value: Any = None
# we iterate until on_merge gives value.
while value is None:
# give other tasks a chance to run between messages.
await sleep(0, loop=loop)
# get message from channel
# This inlines ThrowableQueue.get for performance:
# We selectively call `await Q.get`/`Q.get_nowait`,
# and prefer the latter if the queue is non-empty.
channel_value: Any
if chan_is_channel:
# errors queued on the channel are raised first.
if chan_errors:
raise chan_errors.popleft()
if chan_queue_empty():
channel_value = await chan_slow_get()
else:
channel_value = chan_quick_get()
else:
# chan is an AsyncIterable
channel_value = await chan_slow_get()
if isinstance(channel_value, event_cls):
event = channel_value
message = event.message
topic = message.topic
tp = message.tp
offset = message.offset
if topic in acking_topics and not message.tracked:
message.tracked = True
# This inlines Consumer.track_message(message)
add_unacked(message)
on_message_in(message.tp, message.offset, message)
# XXX ugh this should be in the consumer somehow
if consumer._last_batch is None:
# set last_batch received timestamp if not
# already set. The commit livelock monitor
# uses this to check how long it takes from
# receiving a message until we commit it
# (we reset _last_batch to None in .commit()).
consumer._last_batch = monotonic()
# call Sensors
on_stream_event_in(tp, offset, self, event)
# set task-local current_event
_current_event_contextvar.set(create_ref(event))
# set Stream._current_event
self.current_event = event
# Stream yields Event.value
value = event.value
else:
# plain value: there is no event to expose or ack.
value = channel_value
self.current_event = None
# reduce using processors
for processor in processors:
value = await _maybe_async(processor(value))
value = await on_merge(value)
try:
yield value
except CancelledError:
if not ack_cancelled_tasks:
do_ack = False
raise
except Exception:
if not ack_exceptions:
do_ack = False
raise
except GeneratorExit:
raise # consumer did `break`
except BaseException:
# e.g. SystemExit/KeyboardInterrupt
if not ack_cancelled_tasks:
do_ack = False
raise
finally:
self.current_event = None
if do_ack and event is not None:
# This inlines self.ack
last_stream_to_ack = event.ack()
message = event.message
tp = event.message.tp
offset = event.message.offset
on_stream_event_out(tp, offset, self, event)
if last_stream_to_ack:
on_message_out(tp, offset, message)
except StopAsyncIteration:
# We are not allowed to propagate StopAsyncIteration in __aiter__
# (if we do, it'll be converted to RuntimeError by CPython).
# It can be raised when streaming over a list:
# async for value in app.stream([1, 2, 3, 4]):
# ...
# To support that, we just return here and that will stop
# the iteration.
return
finally:
# notify the channel (if it supports it) that iteration ended.
self._channel_stop_iteration(channel)
async def __anext__(self) -> T:  # pragma: no cover
    """Placeholder that does nothing and returns None.

    Iteration itself is implemented by the ``__aiter__`` generator.
    """
async def ack(self, event: EventT) -> bool:
    """Ack event.

    This decreases the reference count of the event's message by one.
    Once the count reaches zero the worker commits the offset, so the
    message will not be seen by a worker again.

    Arguments:
        event: Event to ack.

    Returns:
        The result of ``event.ack()``: true when this was the last
        stream to ack the message.
    """
    # WARNING: This function is duplicated in __aiter__
    was_last = event.ack()
    msg = event.message
    tp, offset = msg.tp, msg.offset
    self._on_stream_event_out(tp, offset, self, event)
    if was_last:
        self._on_message_out(tp, offset, msg)
    return was_last
def __and__(self, other: Any) -> Any:
    """Support ``stream & other``: combine this stream with `other`."""
    combined = self.combine(self, other)
    return combined
def __copy__(self) -> Any:
    """Support ``copy.copy(stream)`` by returning a clone."""
    duplicate = self.clone()
    return duplicate
def _repr_info(self) -> str:
    """Return an abbreviated repr of the combined nodes, or the channel."""
    described = self.combined if self.combined else self.channel
    return reprlib.repr(described)
@property
def label(self) -> str:
    """Textual description of the stream, used in graphs."""
    kind = type(self).__name__
    channel = self._repr_channel()
    return f'{kind}: {channel}'
def _repr_channel(self) -> str:
    """Return an abbreviated repr of the stream's channel."""
    channel = self.channel
    return reprlib.repr(channel)
@cached_property
def shortlabel(self) -> str:
    """Short description of the stream, used by statsd to generate ids.

    ``str(channel)`` gives the topic name when the channel is a topic,
    so the label looks like ``"Stream: Channel: <ANON>"`` for a channel
    or ``"Stream: Topic: withdrawals"`` for a topic.
    """
    channel_desc = self._human_channel()
    return f'Stream: {channel_desc}'
def _human_channel(self) -> str:
    """Describe the channel, or all combined nodes joined with ``&``."""
    if not self.combined:
        return f'{type(self.channel).__name__}: {self.channel}'
    parts = [node._human_channel() for node in self.combined]
    return '&'.join(parts)