 import pickle
-from typing import Any, Dict, Optional, TypeVar, Union
+import sys
+from contextlib import asynccontextmanager
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
-from redis.asyncio import BlockingConnectionPool, Redis
+from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
 from redis.asyncio.cluster import RedisCluster
 from taskiq import AsyncResultBackend
 from taskiq.abc.result_backend import TaskiqResult
+from taskiq.abc.serializer import TaskiqSerializer
 
 from taskiq_redis.exceptions import (
     DuplicateExpireTimeSelectedError,
     ExpireTimeMustBeMoreThanZeroError,
     ResultIsMissingError,
 )
+from taskiq_redis.serializer import PickleSerializer
+
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:
+    _Redis: TypeAlias = Redis[bytes]
+else:
+    _Redis: TypeAlias = Redis
 
 _ReturnType = TypeVar("_ReturnType")
 
@@ -267,3 +291,142 @@ async def get_result(
         taskiq_result.log = None
 
         return taskiq_result
+
+
+class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
+    """Async result backend based on Redis Sentinel."""
+
+    def __init__(
+        self,
+        sentinels: List[Tuple[str, int]],
+        master_name: str,
+        keep_results: bool = True,
+        result_ex_time: Optional[int] = None,
+        result_px_time: Optional[int] = None,
+        min_other_sentinels: int = 0,
+        sentinel_kwargs: Optional[Any] = None,
+        serializer: Optional[TaskiqSerializer] = None,
+        **connection_kwargs: Any,
+    ) -> None:
+        """
+        Constructs a new result backend.
+
+        :param sentinels: list of sentinel host and port pairs.
+        :param master_name: sentinel master name.
+        :param keep_results: flag to not remove results from Redis after reading.
+        :param result_ex_time: expire time in seconds for result.
+        :param result_px_time: expire time in milliseconds for result.
+        :param min_other_sentinels: skip sentinels that know about fewer than
+            this many other sentinels.
+        :param sentinel_kwargs: additional arguments for the sentinel connections.
+        :param serializer: serializer for results, PickleSerializer by default.
+        :param connection_kwargs: additional arguments for the Redis connections.
+
+        :raises DuplicateExpireTimeSelectedError: if both result_ex_time
+            and result_px_time are selected.
+        :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
+            or result_px_time is less than or equal to zero.
+        """
+        self.sentinel = Sentinel(
+            sentinels=sentinels,
+            min_other_sentinels=min_other_sentinels,
+            sentinel_kwargs=sentinel_kwargs,
+            **connection_kwargs,
+        )
+        self.master_name = master_name
+        if serializer is None:
+            serializer = PickleSerializer()
+        self.serializer = serializer
+        self.keep_results = keep_results
+        self.result_ex_time = result_ex_time
+        self.result_px_time = result_px_time
+
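+        # At most one of result_ex_time (seconds) or result_px_time
+        # (milliseconds) may be set, and whichever is set must be positive.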
+        unavailable_conditions = any(
+            (
+                self.result_ex_time is not None and self.result_ex_time <= 0,
+                self.result_px_time is not None and self.result_px_time <= 0,
+            ),
+        )
+        if unavailable_conditions:
+            raise ExpireTimeMustBeMoreThanZeroError(
+                "You must select one expire time param and it must be more than zero.",
+            )
+
+        if self.result_ex_time and self.result_px_time:
+            raise DuplicateExpireTimeSelectedError(
+                "Choose either result_ex_time or result_px_time.",
+            )
+
+    @asynccontextmanager
+    async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
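+        """Acquire a connection to the current Sentinel master."""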
+        async with self.sentinel.master_for(self.master_name) as redis_conn:
+            yield redis_conn
+
+    async def set_result(
+        self,
+        task_id: str,
+        result: TaskiqResult[_ReturnType],
+    ) -> None:
+        """
+        Sets task result in redis.
+
+        Dumps the TaskiqResult instance to bytes and writes
+        it to Redis.
+
+        :param task_id: ID of the task.
+        :param result: TaskiqResult instance.
+        """
+        redis_set_params: Dict[str, Union[str, bytes, int]] = {
+            "name": task_id,
+            "value": self.serializer.dumpb(result),
+        }
+        if self.result_ex_time:
+            redis_set_params["ex"] = self.result_ex_time
+        elif self.result_px_time:
+            redis_set_params["px"] = self.result_px_time
+
+        async with self._acquire_master_conn() as redis:
+            await redis.set(**redis_set_params)  # type: ignore
+
+    async def is_result_ready(self, task_id: str) -> bool:
+        """
+        Returns whether the result is ready.
+
+        :param task_id: ID of the task.
+
+        :returns: True if the result is ready else False.
+        """
+        async with self._acquire_master_conn() as redis:
+            return bool(await redis.exists(task_id))
+
+    async def get_result(
+        self,
+        task_id: str,
+        with_logs: bool = False,
+    ) -> TaskiqResult[_ReturnType]:
+        """
+        Gets result from the task.
+
+        :param task_id: task's id.
+        :param with_logs: if True it will download task's logs.
+        :raises ResultIsMissingError: if there is no result when trying to get it.
+        :return: task's return value.
+        """
+        async with self._acquire_master_conn() as redis:
+            if self.keep_results:
+                result_value = await redis.get(
+                    name=task_id,
+                )
+            else:
+                result_value = await redis.getdel(
+                    name=task_id,
+                )
+
+        if result_value is None:
+            raise ResultIsMissingError
+
+        # Load with the configured serializer so reads match what set_result wrote.
+        taskiq_result: TaskiqResult[_ReturnType] = self.serializer.loadb(
+            result_value,
+        )
+
+        if not with_logs:
+            taskiq_result.log = None
+
+        return taskiq_result
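
Usage sketch (not part of the diff): a minimal example of how the new sentinel result backend might be wired into a broker. The sentinel addresses, the master name "mymaster", the ten-minute expiry, the `ListQueueBroker` choice, and the top-level `RedisAsyncSentinelResultBackend` import are illustrative assumptions; retrieving the result via `wait_result()` additionally assumes a taskiq worker is running for this broker.

```python
import asyncio

from taskiq_redis import ListQueueBroker, RedisAsyncSentinelResultBackend

# Placeholder Sentinel deployment: three sentinels watching a master named "mymaster".
result_backend = RedisAsyncSentinelResultBackend(
    sentinels=[("sentinel-1", 26379), ("sentinel-2", 26379), ("sentinel-3", 26379)],
    master_name="mymaster",
    result_ex_time=600,  # results expire after ten minutes
)

# Any taskiq broker can carry the backend; ListQueueBroker is used here only as an example.
broker = ListQueueBroker(url="redis://redis-master:6379").with_result_backend(result_backend)


@broker.task
async def add(a: int, b: int) -> int:
    return a + b


async def main() -> None:
    await broker.startup()
    task = await add.kiq(1, 2)
    # A running worker stores the result through the sentinel backend.
    result = await task.wait_result()
    print(result.return_value)  # 3
    await broker.shutdown()


asyncio.run(main())
```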