"""Redis cache backend."""
import socket
import typing

from enum import Enum
from typing import Any, Mapping, Optional, Type, Union

from mode.utils.compat import want_bytes
from mode.utils.objects import cached_property
from yarl import URL

from faust.exceptions import ImproperlyConfigured
from faust.types import AppT
from . import base

    import aredis
    import aredis.exceptions
except ImportError:  # pragma: no cover
    aredis = None  # noqa

if typing.TYPE_CHECKING:  # pragma: no cover
    from aredis import StrictRedis as _RedisClientT
else:
    # Runtime placeholder so annotations that mention _RedisClientT resolve
    # even when aredis is not installed; only type checkers see the real
    # aredis client class.
    class _RedisClientT: ...  # noqa

class RedisScheme(Enum):
    """Types of Redis configurations.

    Each member's value is the URL scheme that selects it: ``redis://``
    for a single server, ``rediscluster://`` for a Redis Cluster.
    """

    SINGLE_NODE = 'redis'
    CLUSTER = 'rediscluster'
class CacheBackend(base.CacheBackend):
    """Backend for cache operations using Redis.

    The client class is chosen from the backend URL scheme: ``redis://``
    builds an ``aredis.StrictRedis`` and ``rediscluster://`` builds an
    ``aredis.StrictRedisCluster``. Host, port, password and database number
    are taken from the URL. Timeouts and pool sizes may be given as keyword
    arguments or as URL query parameters; a non-empty query parameter takes
    precedence over the keyword argument.
    """

    connect_timeout: Optional[float]
    stream_timeout: Optional[float]
    max_connections: Optional[int]
    max_connections_per_node: Optional[int]

    # Created lazily by connect(); None until then.
    _client: Optional[_RedisClientT] = None
    # URL scheme -> client class, built per instance by _init_schemes().
    _client_by_scheme: Mapping[str, Type[_RedisClientT]]

    if aredis is None:  # pragma: no cover
        # No driver exceptions to classify; on_start() refuses to run.
        ...
    else:
        # Driver/OS exception classes grouped for base.CacheBackend's error
        # handling; how each group is treated is decided in the base class.
        operational_errors = (
            socket.error,
            IOError,
            OSError,
            aredis.exceptions.ConnectionError,
            aredis.exceptions.TimeoutError,
        )
        invalidating_errors = (
            aredis.exceptions.DataError,
            aredis.exceptions.InvalidResponse,
            aredis.exceptions.ResponseError,
        )
        irrecoverable_errors = (
            aredis.exceptions.AuthenticationError,
        )

    def __init__(self,
                 app: AppT,
                 url: Union[URL, str],
                 *,
                 connect_timeout: Optional[float] = None,
                 stream_timeout: Optional[float] = None,
                 max_connections: Optional[int] = None,
                 max_connections_per_node: Optional[int] = None,
                 **kwargs: Any) -> None:
        """Store connection settings; the client is created by connect().

        Arguments:
            app: The Faust application.
            url: Backend URL (``redis://...`` or ``rediscluster://...``).
            connect_timeout: Fallback used when the URL has no
                ``connect_timeout`` query parameter.
            stream_timeout: Fallback for the ``stream_timeout`` parameter.
            max_connections: Fallback for the ``max_connections`` parameter.
            max_connections_per_node: Fallback for the
                ``max_connections_per_node`` parameter.
            **kwargs: Forwarded to ``base.CacheBackend.__init__``.
        """
        super().__init__(app, url, **kwargs)
        self.connect_timeout = connect_timeout
        self.stream_timeout = stream_timeout
        self.max_connections = max_connections
        self.max_connections_per_node = max_connections_per_node
        self._client_by_scheme = self._init_schemes()

    def _init_schemes(self) -> Mapping[str, Type[_RedisClientT]]:
        """Map each RedisScheme value to its aredis client class.

        Returns an empty mapping when aredis is not installed.
        """
        if aredis is None:  # pragma: no cover
            return {}
        else:
            return {
                RedisScheme.SINGLE_NODE.value: aredis.StrictRedis,
                RedisScheme.CLUSTER.value: aredis.StrictRedisCluster,
            }

    async def _get(self, key: str) -> Optional[bytes]:
        """Return the value stored under ``key`` as bytes, or None."""
        value: Optional[bytes] = await self.client.get(key)
        if value is not None:
            return want_bytes(value)
        return None

    async def _set(self, key: str, value: bytes, timeout: float) -> None:
        """Store ``value`` under ``key``, expiring after ``timeout`` seconds.

        SETEX only accepts a positive whole number of seconds. The timeout
        is truncated to an int and raised to at least 1; previously any
        timeout below one second became ``SETEX key 0``, which the server
        rejects.
        """
        await self.client.setex(key, max(int(timeout), 1), value)

    async def _delete(self, key: str) -> None:
        """Remove ``key`` from Redis."""
        await self.client.delete(key)

    async def on_start(self) -> None:
        """Call when Redis backend starts.

        Raises:
            ImproperlyConfigured: if the optional aredis package is missing.
        """
        if aredis is None:
            raise ImproperlyConfigured(
                'Redis cache backend requires `pip install aredis`')
        await self.connect()

    async def connect(self) -> None:
        """Connect to Redis/Redis Cluster server.

        Creates the client on first call, then sends a PING so that an
        unreachable or misconfigured server fails here rather than on the
        first cache operation.
        """
        if self._client is None:
            self._client = self._new_client()
        await self.client.ping()

    def _new_client(self) -> _RedisClientT:
        """Build a client from ``self.url`` and its query parameters."""
        return self._client_from_url_and_query(self.url, **self.url.query)

    def _client_from_url_and_query(
            self,
            url: URL, *,
            connect_timeout: Optional[str] = None,
            stream_timeout: Optional[str] = None,
            max_connections: Optional[str] = None,
            max_connections_per_node: Optional[str] = None,
            **kwargs: Any) -> _RedisClientT:
        """Instantiate the client class selected by ``url.scheme``.

        The keyword arguments are raw URL query values (strings). Empty or
        missing values fall back to the settings given to ``__init__``.
        Unrecognised query parameters are collected in ``kwargs`` and
        ignored.

        Raises:
            KeyError: if ``url.scheme`` is not a supported scheme.
            ValueError: if a numeric query value or the URL path is invalid.
        """
        Client = self._client_by_scheme[url.scheme]
        return Client(**self._prepare_client_kwargs(
            url,
            host=url.host,
            port=url.port,
            db=self._db_from_path(url.path),
            password=url.password,
            connect_timeout=self._float_from_str(
                connect_timeout, self.connect_timeout),
            stream_timeout=self._float_from_str(
                stream_timeout, self.stream_timeout),
            max_connections=self._int_from_str(
                max_connections, self.max_connections),
            max_connections_per_node=self._int_from_str(
                max_connections_per_node, self.max_connections_per_node),
            # NOTE(review): passed for both client types; presumably only
            # meaningful for the cluster client - confirm StrictRedis
            # tolerates the extra keyword.
            skip_full_coverage_check=True,
        ))

    def _prepare_client_kwargs(self, url: URL, **kwargs: Any) -> Mapping:
        """Adapt constructor kwargs to the client type chosen by ``url``."""
        if url.scheme == RedisScheme.CLUSTER.value:
            return self._as_cluster_kwargs(**kwargs)
        return kwargs

    def _as_cluster_kwargs(self, db: Optional[int] = None,
                           **kwargs: Any) -> Mapping:
        """Return ``kwargs`` without ``db``.

        Redis Cluster does not support db as argument, so any database
        number in a ``rediscluster://`` URL path is discarded.
        """
        return kwargs

    def _int_from_str(self, val: Optional[str] = None,
                      default: Optional[int] = None) -> Optional[int]:
        """Parse ``val`` as an int, or return ``default`` if it is empty."""
        return int(val) if val else default

    def _float_from_str(self, val: Optional[str] = None,
                        default: Optional[float] = None) -> Optional[float]:
        """Parse ``val`` as a float, or return ``default`` if it is empty."""
        return float(val) if val else default

    def _db_from_path(self, path: str) -> int:
        """Return the database number encoded in a URL path such as ``/3``.

        An empty path or ``/`` selects database 0.

        Raises:
            ValueError: if the path is not a non-negative integer.
        """
        if not path or path == '/':
            return 0  # default db
        try:
            db = int(path.strip('/'))
        except ValueError as exc:
            raise ValueError(
                f'Database is int between 0 and limit - 1, not {path!r}',
            ) from exc
        if db < 0:
            # Reject here rather than letting the server fail the SELECT.
            raise ValueError(
                f'Database is int between 0 and limit - 1, not {path!r}')
        return db

    @cached_property
    def client(self) -> _RedisClientT:
        """Return Redis client instance.

        Memoized via ``cached_property``.

        Raises:
            RuntimeError: if accessed before connect() created the client.
        """
        if self._client is None:
            raise RuntimeError('Cache backend not started')
        return self._client