DEV Community

Wu Haotian

Posted on • Originally published at blog.whtsky.me on

Decorator type gymnastics in Python

You need to type-hint your decorators

Say you have a simple decorator for adding logging before calling a function.

```python
import logging


def add_log(f):
    def wrapper(*args, **kwargs):
        logging.info("Called f!")
        return f(*args, **kwargs)

    return wrapper


@add_log
def two_sum(a, b):
    return a + b
```

One day you decide to add type hints to this module. Adding them to two_sum is easy:

```python
def two_sum(a: int, b: int) -> int:
    return a + b
```

But you also need to add type hints to your decorator (add_log in this case); otherwise the wrapped function's type degrades to Any. You can verify this with mypy's reveal_type:

```python
import logging


def add_log(f):
    def wrapper(*args, **kwargs):  # type: ignore
        logging.info("Called f!")
        return f(*args, **kwargs)

    return wrapper


def two_sum(a: int, b: int) -> int:
    return a + b


reveal_type(two_sum)  # Revealed type is 'def (a: builtins.int, b: builtins.int) -> builtins.int'
reveal_type(add_log(two_sum))  # Revealed type is 'Any'
```
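mypy's Any is a static-analysis result, and mypy does not read runtime attributes to compute it. Still, a loosely analogous loss is visible at runtime, which may help build intuition: the inner wrapper declares no annotations of its own, so the function add_log returns carries none either. A minimal standard-library sketch:

```python
import logging


def add_log(f):
    def wrapper(*args, **kwargs):  # type: ignore
        logging.info("Called f!")
        return f(*args, **kwargs)

    return wrapper


def two_sum(a: int, b: int) -> int:
    return a + b


# The undecorated function keeps its annotations...
print(two_sum.__annotations__)  # {'a': <class 'int'>, 'b': <class 'int'>, 'return': <class 'int'>}
# ...but the wrapper returned by add_log has none of its own.
print(add_log(two_sum).__annotations__)  # {}
```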

Simple type hints for simple decorators

Let's add type hints to this simple decorator. For a simple decorator that doesn't modify the function's arguments or return value (like add_log above), a TypeVar does the job well.

```python
from typing import TypeVar, Callable, cast
import logging

TCallable = TypeVar("TCallable", bound=Callable)


def add_log(f: TCallable) -> TCallable:
    def wrapper(*args, **kwargs):  # type: ignore
        logging.info("Called f!")
        return f(*args, **kwargs)

    return cast(TCallable, wrapper)


@add_log
def two_sum(a: int, b: int) -> int:
    return a + b


reveal_type(two_sum)  # Revealed type is 'def (a: builtins.int, b: builtins.int) -> builtins.int'
reveal_type(add_log(two_sum))  # Revealed type is 'def (a: builtins.int, b: builtins.int) -> builtins.int'
```
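cast only changes what the type checker believes; at runtime the decorated two_sum is still the inner function named wrapper. A possible variant, sketched below, also applies functools.wraps (a standard-library addition not used in the original) so the runtime metadata lines up with the preserved static type. The bound is spelled out as Callable[..., Any], which is what a bare Callable means anyway, and avoids mypy's "missing type parameters" complaint under stricter settings.

```python
import functools
import logging
from typing import Any, Callable, TypeVar, cast

TCallable = TypeVar("TCallable", bound=Callable[..., Any])


def add_log(f: TCallable) -> TCallable:
    @functools.wraps(f)  # copies f's metadata (__name__, __doc__, annotations, ...) onto wrapper
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        logging.info("Called f!")
        return f(*args, **kwargs)

    # cast() does nothing at runtime; it only tells the type checker
    # that wrapper has the same type as f.
    return cast(TCallable, wrapper)


@add_log
def two_sum(a: int, b: int) -> int:
    return a + b


print(two_sum(2, 3))  # 5
print(two_sum.__name__)  # two_sum
```

functools.wraps also sets __wrapped__, which points back at the original, undecorated function.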

Hard type hints for hard decorators

But what if the decorator modifies the arguments and/or the return value? There's no easy way to type the arguments, but at least you can type the return value correctly:

```python
from typing import TypeVar, Callable, Awaitable, cast

R = TypeVar('R')


def sync_to_async(f: Callable[..., R]) -> Callable[..., Awaitable[R]]:
    async def wrapper(*args, **kwargs):  # type: ignore
        return f(*args, **kwargs)

    return cast(Callable[..., Awaitable[R]], wrapper)


@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b


reveal_type(two_sum)  # Revealed type is 'def (*Any, **Any) -> typing.Awaitable[builtins.int*]'
reveal_type(two_sum(2, "3"))  # Revealed type is 'typing.Awaitable[builtins.int*]' (the str argument is not caught)
```
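Because the decorated function now returns an awaitable, callers must await it inside a coroutine or hand it to an event loop. A small usage sketch follows; the decorator is copied from above, and the main coroutine plus asyncio.run are illustrative additions. Note that wrapper still calls f synchronously, so a slow f would block the event loop; the decorator changes only the calling convention.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar, cast

R = TypeVar("R")


def sync_to_async(f: Callable[..., R]) -> Callable[..., Awaitable[R]]:
    async def wrapper(*args, **kwargs):  # type: ignore
        # f runs synchronously inside the coroutine body.
        return f(*args, **kwargs)

    return cast(Callable[..., Awaitable[R]], wrapper)


@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b


async def main() -> int:
    # two_sum(2, 3) returns a coroutine; awaiting it yields the int.
    return await two_sum(2, 3)


print(asyncio.run(main()))  # 5
```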

The dark way for typing arguments & return values

You can do some type gymnastics: generate a bunch of TypeVars and write one @overload per number of arguments.

```python
from typing import TypeVar, Callable, Awaitable, overload

A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
D = TypeVar('D')
E = TypeVar('E')
RV = TypeVar('RV')


@overload
def sync_to_async(f: Callable[[A], RV]) -> Callable[[A], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A, B], RV]) -> Callable[[A, B], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A, B, C], RV]) -> Callable[[A, B, C], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A, B, C, D], RV]) -> Callable[[A, B, C, D], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A, B, C, D, E], RV]) -> Callable[[A, B, C, D, E], Awaitable[RV]]: ...
def sync_to_async(f):
    async def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b


@sync_to_async
def do_log(content: str) -> None:
    print(content)


reveal_type(two_sum)  # Revealed type is 'def (builtins.int*, builtins.int*) -> typing.Awaitable[builtins.int*]'
reveal_type(do_log)  # Revealed type is 'def (builtins.str*) -> typing.Awaitable[None]'
```

If you insist on going this way, here's the snippet I used to generate the TypeVars and overloads:

```python
gymnastics = 5

for i in range(gymnastics):
    char = chr(i + 65)
    print(f"{char} = TypeVar('{char}')")
print("RV = TypeVar('RV')")

for i in range(gymnastics):
    chars = [chr(n + 65) for n in range(i + 1)]
    args = ", ".join(chars)
    print(f"""@overload
def sync_to_async(f: Callable[[{args}], RV]) -> Callable[[{args}], Awaitable[RV]]: ...""")
```
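A slightly extended variant of the generator is sketched below; generate_overloads is a hypothetical helper name, not something from the original. It assembles the whole module as one string (imports, TypeVars, overloads and the untyped implementation from the overload example) and runs compile() on it, so a typo in the template shows up as a SyntaxError right away instead of when the file is imported.

```python
def generate_overloads(gymnastics: int = 5) -> str:
    """Return source for sync_to_async with one @overload per arity 1..gymnastics.

    Assumes gymnastics <= 26, since TypeVar names are single letters A..Z.
    """
    lines = ["from typing import Awaitable, Callable, TypeVar, overload", ""]
    letters = [chr(n + 65) for n in range(gymnastics)]  # 'A', 'B', 'C', ...
    for char in letters:
        lines.append(f"{char} = TypeVar('{char}')")
    lines.append("RV = TypeVar('RV')")
    for i in range(gymnastics):
        args = ", ".join(letters[: i + 1])
        lines.append("")
        lines.append("@overload")
        lines.append(
            f"def sync_to_async(f: Callable[[{args}], RV]) -> Callable[[{args}], Awaitable[RV]]: ..."
        )
    # The single untyped implementation shared by all overloads.
    lines += [
        "",
        "def sync_to_async(f):",
        "    async def wrapper(*args, **kwargs):",
        "        return f(*args, **kwargs)",
        "    return wrapper",
    ]
    return "\n".join(lines) + "\n"


source = generate_overloads(5)
compile(source, "<generated>", "exec")  # raises SyntaxError if the template is broken
print(source)
```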
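One side effect of the overload approach is visible in its revealed types: they list parameter types only, with no names, because a Callable[[A, B], RV] type cannot carry parameter names. The trimmed-down sketch below (two overloads instead of five) illustrates this. The keyword call runs fine at runtime, but mypy should reject it, since the overloaded signature has no parameter called a or b; the type: ignore comment marks that line.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar, overload

A = TypeVar("A")
B = TypeVar("B")
RV = TypeVar("RV")


@overload
def sync_to_async(f: Callable[[A], RV]) -> Callable[[A], Awaitable[RV]]: ...
@overload
def sync_to_async(f: Callable[[A, B], RV]) -> Callable[[A, B], Awaitable[RV]]: ...
def sync_to_async(f):
    # Untyped implementation: callers only ever see the @overload signatures.
    async def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


@sync_to_async
def two_sum(a: int, b: int) -> int:
    return a + b


async def main() -> None:
    print(await two_sum(2, 3))  # accepted by mypy; prints 5
    # Also prints 5 at runtime, but the checked signature has no names "a" or "b".
    print(await two_sum(a=2, b=3))  # type: ignore


asyncio.run(main())
```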

The Future: PEP-612

PEP-612 defines ParamSpec and Concatenate, which make type-hinting decorators much easier:

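```python
from typing import Callable, Concatenate, ParamSpec, TypeVar  # Python 3.10+

P = ParamSpec('P')
R = TypeVar('R')


class Context:
    """Placeholder for whatever context object the decorator injects."""


def with_context(f: Callable[Concatenate[Context, P], R]) -> Callable[P, R]:
    def inner(*args: P.args, **kwargs: P.kwargs) -> R:
        context = Context()  # placeholder: obtain the real context however your app does
        return f(context, *args, **kwargs)

    return inner


@with_context
def request(context: Context) -> int:
    return 42


print(request())  # 42; callers no longer pass `context` themselves
```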

Sadly, PEP-612 is not widely supported yet; at the time of writing, mypy does not fully support it.

Fin

May the type be with you.
