"""Agent replies: waiting for replies, sending them, etc."""
import asyncio
from collections import defaultdict
from typing import (
Any,
AsyncIterator,
MutableMapping,
MutableSet,
NamedTuple,
Optional,
)
from weakref import WeakSet
from mode import Service
from faust.types import AppT, ChannelT, TopicT
from .models import ReqRepResponse
__all__ = ['ReplyPromise', 'BarrierState', 'ReplyConsumer']
class ReplyTuple(NamedTuple):
correlation_id: str
value: Any
[docs]class ReplyPromise(asyncio.Future):
"""Reply promise can be :keyword:`await`-ed to wait until result ready."""
reply_to: str
correlation_id: str
def __init__(self, reply_to: str, correlation_id: str,
**kwargs: Any) -> None:
self.reply_to = reply_to
self.correlation_id = correlation_id
super().__init__(**kwargs)
[docs] def fulfill(self, correlation_id: str, value: Any) -> None:
"""Fulfill promise: a reply was received."""
# If it wasn't for BarrierState we would just use .set_result()
# directly, but BarrierState.fulfill requires the correlation_id
# to be sent with it. That way it can mark that part of the map
# operation as completed.
assert correlation_id == self.correlation_id
self.set_result(value)
[docs]class BarrierState(asyncio.Future):
"""State of pending/complete barrier.
A barrier is a synchronization primitive that will wait until
a group of coroutines have completed.
"""
reply_to: str
#: This is the size while the messages are being sent.
#: (it's a tentative total, added to until the total is finalized).
size: int = 0
#: This is the actual total when all messages have been sent.
#: It's set by :meth:`finalize`.
total: int = 0
#: The number of results we have received.
fulfilled: int = 0
#: Internal queue where results are added to.
_results: asyncio.Queue
#: Set of pending replies that this barrier is composed of.
pending: MutableSet[ReplyPromise]
def __init__(self, reply_to: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.reply_to = reply_to
self.pending = set()
loop: asyncio.AbstractEventLoop = self._loop
self._results = asyncio.Queue(maxsize=1000, loop=loop)
[docs] def add(self, p: ReplyPromise) -> None:
"""Add promise to barrier.
Note:
You can only add promises before the barrier is finalized
using :meth:`finalize`.
"""
self.pending.add(p)
self.size += 1
[docs] def finalize(self) -> None:
"""Finalize this barrier.
After finalization you can not grow or shrink the size
of the barrier.
"""
self.total = self.size
# The barrier may have been filled up already at this point,
if self.fulfilled >= self.total:
self.set_result(True)
self._results.put_nowait(None) # always wake-up .iterate()
[docs] def fulfill(self, correlation_id: str, value: Any) -> None:
"""Fulfill one of the promises in this barrier.
Once all promises in this barrier is fulfilled, the barrier
will be ready.
"""
# ReplyConsumer calls this whenever a new reply is received.
self._results.put_nowait(ReplyTuple(correlation_id, value))
self.fulfilled += 1
if self.total:
if self.fulfilled >= self.total:
self.set_result(True)
self._results.put_nowait(None) # always wake-up .iterate()
[docs] def get_nowait(self) -> ReplyTuple:
"""Return next reply, or raise :exc:`asyncio.QueueEmpty`."""
for _ in range(10): # remove sentinels
value = self._results.get_nowait()
if value is not None:
return value
raise asyncio.QueueEmpty()
[docs] async def iterate(self) -> AsyncIterator[ReplyTuple]:
"""Iterate over results as they arrive."""
get = self._results.get
get_nowait = self._results.get_nowait
is_done = self.done
while not is_done():
value = await get()
if value is not None:
yield value
while 1:
try:
value = get_nowait()
except asyncio.QueueEmpty:
break
else:
if value is not None:
yield value
[docs]class ReplyConsumer(Service):
"""Consumer responsible for redelegation of replies received."""
_waiting: MutableMapping[str, MutableSet[ReplyPromise]]
_fetchers: MutableMapping[str, Optional[asyncio.Future]]
def __init__(self, app: AppT, **kwargs: Any) -> None:
self.app = app
self._waiting = defaultdict(WeakSet)
self._fetchers = {}
super().__init__(**kwargs)
[docs] async def on_start(self) -> None:
"""Call when reply consumer starts."""
if self.app.conf.reply_create_topic:
await self._start_fetcher(self.app.conf.reply_to)
[docs] async def add(self, correlation_id: str, promise: ReplyPromise) -> None:
"""Register promise to start tracking when it arrives."""
reply_topic = promise.reply_to
if reply_topic not in self._fetchers:
await self._start_fetcher(reply_topic)
self._waiting[correlation_id].add(promise)
async def _start_fetcher(self, topic_name: str) -> None:
if topic_name not in self._fetchers:
# set the key as a lock, so it doesn't happen twice
self._fetchers[topic_name] = None
# declare the topic
topic = self._reply_topic(topic_name)
await topic.maybe_declare()
await self.sleep(3.0)
# then create the future
self._fetchers[topic_name] = self.add_future(
self._drain_replies(topic))
async def _drain_replies(self, channel: ChannelT) -> None:
async for reply in channel.stream():
for promise in self._waiting[reply.correlation_id]:
promise.fulfill(reply.correlation_id, reply.value)
def _reply_topic(self, topic: str) -> TopicT:
return self.app.topic(
topic,
partitions=1,
replicas=0,
deleting=True,
retention=self.app.conf.reply_expires,
value_type=ReqRepResponse,
)