66from collections .abc import Mapping
77from collections .abc import Sequence
88from collections .abc import Sized
9- from contextlib import AbstractContextManager
109from decimal import Decimal
1110import math
1211from numbers import Complex
1312import pprint
1413import re
1514import sys
16- from types import TracebackType
1715from typing import Any
18- from typing import cast
19- from typing import final
20- from typing import get_args
21- from typing import get_origin
2216from typing import overload
2317from typing import TYPE_CHECKING
2418from typing import TypeVar
2519
2620import _pytest ._code
2721from _pytest .outcomes import fail
22+ from _pytest .raises_group import BaseExcT_co_default
23+ from _pytest .raises_group import RaisesExc
2824
2925
3026if sys .version_info < (3 , 11 ):
31- from exceptiongroup import BaseExceptionGroup
32- from exceptiongroup import ExceptionGroup
27+ pass
3328
3429if TYPE_CHECKING :
3530 from numpy import ndarray
@@ -791,15 +786,29 @@ def _as_numpy_array(obj: object) -> ndarray | None:
791786
792787# builtin pytest.raises helper
793788
794- E = TypeVar ("E" , bound = BaseException )
789+ E = TypeVar ("E" , bound = BaseException , default = BaseException )
795790
796791
797792@overload
798793def raises (
799794 expected_exception : type [E ] | tuple [type [E ], ...],
800795 * ,
801796 match : str | re .Pattern [str ] | None = ...,
802- ) -> RaisesContext [E ]: ...
797+ check : Callable [[BaseExcT_co_default ], bool ] = ...,
798+ ) -> RaisesExc [E ]: ...
799+
800+
801+ @overload
802+ def raises (
803+ * ,
804+ match : str | re .Pattern [str ],
805+ # If exception_type is not provided, check() must do any typechecks itself.
806+ check : Callable [[BaseException ], bool ] = ...,
807+ ) -> RaisesExc [BaseException ]: ...
808+
809+
810+ @overload
811+ def raises (* , check : Callable [[BaseException ], bool ]) -> RaisesExc [BaseException ]: ...
803812
804813
805814@overload
@@ -812,8 +821,10 @@ def raises(
812821
813822
814823def raises (
815- expected_exception : type [E ] | tuple [type [E ], ...], * args : Any , ** kwargs : Any
816- ) -> RaisesContext [E ] | _pytest ._code .ExceptionInfo [E ]:
824+ expected_exception : type [E ] | tuple [type [E ], ...] | None = None ,
825+ * args : Any ,
826+ ** kwargs : Any ,
827+ ) -> RaisesExc [BaseException ] | _pytest ._code .ExceptionInfo [E ]:
817828 r"""Assert that a code block/function call raises an exception type, or one of its subclasses.
818829
819830 :param expected_exception:
@@ -960,117 +971,38 @@ def raises(
960971 """
961972 __tracebackhide__ = True
962973
974+ if not args :
975+ if set (kwargs ) - {"match" , "check" , "expected_exception" }:
976+ msg = "Unexpected keyword arguments passed to pytest.raises: "
977+ msg += ", " .join (sorted (kwargs ))
978+ msg += "\n Use context-manager form instead?"
979+ raise TypeError (msg )
980+
981+ if expected_exception is None :
982+ return RaisesExc (** kwargs )
983+ return RaisesExc (expected_exception , ** kwargs )
984+
963985 if not expected_exception :
964986 raise ValueError (
965987 f"Expected an exception type or a tuple of exception types, but got `{ expected_exception !r} `. "
966988 f"Raising exceptions is already understood as failing the test, so you don't need "
967989 f"any special code to say 'this should never raise an exception'."
968990 )
969-
970- expected_exceptions : tuple [type [E ], ...]
971- origin_exc : type [E ] | None = get_origin (expected_exception )
972- if isinstance (expected_exception , type ):
973- expected_exceptions = (expected_exception ,)
974- elif origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
975- expected_exceptions = (cast (type [E ], expected_exception ),)
976- else :
977- expected_exceptions = expected_exception
978-
979- def validate_exc (exc : type [E ]) -> type [E ]:
980- __tracebackhide__ = True
981- origin_exc : type [E ] | None = get_origin (exc )
982- if origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
983- exc_type = get_args (exc )[0 ]
984- if (
985- issubclass (origin_exc , ExceptionGroup ) and exc_type in (Exception , Any )
986- ) or (
987- issubclass (origin_exc , BaseExceptionGroup )
988- and exc_type in (BaseException , Any )
989- ):
990- return cast (type [E ], origin_exc )
991- else :
992- raise ValueError (
993- f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
994- f"are accepted as generic types but got `{ exc } `. "
995- f"As `raises` will catch all instances of the specified group regardless of the "
996- f"generic argument specific nested exceptions has to be checked "
997- f"with `ExceptionInfo.group_contains()`"
998- )
999-
1000- elif not isinstance (exc , type ) or not issubclass (exc , BaseException ):
1001- msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
1002- not_a = exc .__name__ if isinstance (exc , type ) else type (exc ).__name__
1003- raise TypeError (msg .format (not_a ))
1004- else :
1005- return exc
1006-
1007- expected_exceptions = tuple (validate_exc (exc ) for exc in expected_exceptions )
1008-
1009- message = f"DID NOT RAISE { expected_exception } "
1010-
1011- if not args :
1012- match : str | re .Pattern [str ] | None = kwargs .pop ("match" , None )
1013- if kwargs :
1014- msg = "Unexpected keyword arguments passed to pytest.raises: "
1015- msg += ", " .join (sorted (kwargs ))
1016- msg += "\n Use context-manager form instead?"
1017- raise TypeError (msg )
1018- return RaisesContext (expected_exceptions , message , match )
1019- else :
1020- func = args [0 ]
1021- if not callable (func ):
1022- raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
1023- try :
1024- func (* args [1 :], ** kwargs )
1025- except expected_exceptions as e :
1026- return _pytest ._code .ExceptionInfo .from_exception (e )
1027- fail (message )
1028-
1029-
1030- # This doesn't work with mypy for now. Use fail.Exception instead.
1031- raises .Exception = fail .Exception # type: ignore
1032-
1033-
1034- @final
1035- class RaisesContext (AbstractContextManager [_pytest ._code .ExceptionInfo [E ]]):
1036- def __init__ (
1037- self ,
1038- expected_exception : type [E ] | tuple [type [E ], ...],
1039- message : str ,
1040- match_expr : str | re .Pattern [str ] | None = None ,
1041- ) -> None :
1042- self .expected_exception = expected_exception
1043- self .message = message
1044- self .match_expr = match_expr
1045- self .excinfo : _pytest ._code .ExceptionInfo [E ] | None = None
1046- if self .match_expr is not None :
1047- re_error = None
1048- try :
1049- re .compile (self .match_expr )
1050- except re .error as e :
1051- re_error = e
1052- if re_error is not None :
1053- fail (f"Invalid regex pattern provided to 'match': { re_error } " )
1054-
1055- def __enter__ (self ) -> _pytest ._code .ExceptionInfo [E ]:
1056- self .excinfo = _pytest ._code .ExceptionInfo .for_later ()
1057- return self .excinfo
1058-
1059- def __exit__ (
1060- self ,
1061- exc_type : type [BaseException ] | None ,
1062- exc_val : BaseException | None ,
1063- exc_tb : TracebackType | None ,
1064- ) -> bool :
1065- __tracebackhide__ = True
1066- if exc_type is None :
1067- fail (self .message )
1068- assert self .excinfo is not None
1069- if not issubclass (exc_type , self .expected_exception ):
1070- return False
1071- # Cast to narrow the exception type now that it's verified.
1072- exc_info = cast (tuple [type [E ], E , TracebackType ], (exc_type , exc_val , exc_tb ))
1073- self .excinfo .fill_unfilled (exc_info )
1074- if self .match_expr is not None :
1075- self .excinfo .match (self .match_expr )
1076- return True
991+ func = args [0 ]
992+ if not callable (func ):
993+ raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
994+ with RaisesExc (expected_exception ) as excinfo :
995+ func (* args [1 :], ** kwargs )
996+ try :
997+ return excinfo
998+ finally :
999+ del excinfo
1000+
1001+
1002+ # note: RaisesExc/RaisesGroup uses fail() internally, so this alias
1003+ # indicates (to [internal] plugins?) that `pytest.raises` will
1004+ # raise `_pytest.outcomes.Failed`, where
1005+ # `outcomes.Failed is outcomes.fail.Exception is raises.Exception`
1006+ # note: this is *not* the same as `_pytest.main.Failed`
1007+ # note: mypy does not recognize this attribute
1008+ raises .Exception = fail .Exception # type: ignore[attr-defined]
0 commit comments