Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import mypy.errorcodes as codes
from mypy import message_registry
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
from mypy.nodes import DictExpr, IntExpr, StrExpr, TypeInfo, UnaryExpr
from mypy.plugin import (
AttributeContext,
ClassDefContext,
Expand Down Expand Up @@ -47,7 +47,12 @@
dataclass_tag_callback,
replace_function_sig_callback,
)
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
from mypy.plugins.enums import (
enum_member_callback,
enum_name_callback,
enum_new_callback,
enum_value_callback,
)
from mypy.plugins.functools import (
functools_total_ordering_maker_callback,
functools_total_ordering_makers,
Expand Down Expand Up @@ -104,6 +109,12 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return partial_new_callback
elif fullname == "enum.member":
return enum_member_callback
elif (
(st := self.lookup_fully_qualified(fullname))
and isinstance(st.node, TypeInfo)
and getattr(st.node, "is_enum", False)
):
return enum_new_callback
return None

def get_function_signature_hook(
Expand Down
61 changes: 61 additions & 0 deletions mypy/plugins/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
LiteralType,
ProperType,
Type,
TypeVarType,
UnionType,
get_proper_type,
is_named_instance,
)
Expand Down Expand Up @@ -297,3 +299,62 @@ def _extract_underlying_field_name(typ: Type) -> str | None:
# as a string.
assert isinstance(underlying_literal.value, str)
return underlying_literal.value


def enum_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""This plugin refines the return type of `__new__`, ensuring reconstructed
Enums are idempotent.

By default, mypy will infer that `Foo(Foo.x)` is of type `Foo`. This plugin
ensures types are not loosened, meaning with this plugin enabled
`Foo(Foo.x)` is of type `Literal[Foo.x]?`.

This means with this plugin:
```
reveal_type(Foo(Foo.x)) # mypy reveals Literal[Foo.x]?
```

This plugin works by adjusting the return type of `__new__` to be the given
argument type, if and only if `__new__` comes from `enum.Enum`.

This plugin supports arguments that are Final, Literial, Union of Literials
and generic TypeVars.
"""
base_ret = ctx.default_return_type
enum_inst = get_proper_type(base_ret)
if not isinstance(enum_inst, Instance):
return base_ret

info: TypeInfo = enum_inst.type
if not info.is_enum:
return base_ret

if _implements_new(info):
return base_ret

if not ctx.args or not ctx.args[0] or not ctx.arg_types or not ctx.arg_types[0]:
return base_ret

arg0_t = get_proper_type(ctx.arg_types[0][0])

if isinstance(arg0_t, Instance) and arg0_t.type is info:
return arg0_t
elif isinstance(arg0_t, LiteralType) and arg0_t.fallback.type is info:
return arg0_t
elif isinstance(arg0_t, UnionType):

def is_memeber(given_t: ProperType) -> bool:
return (isinstance(given_t, Instance) and given_t.type is info) or (
isinstance(given_t, LiteralType) and given_t.fallback.type is info
)

items = [get_proper_type(it) for it in arg0_t.items]
if items and all(is_memeber(item) for item in items):
return arg0_t
elif (isinstance(arg0_t, TypeVarType)) and isinstance(
upperbound_t := get_proper_type(arg0_t.upper_bound), Instance
):
if upperbound_t.type is info:
return arg0_t

return base_ret
Loading
Loading