|
| 1 | +import json |
1 | 2 | from typing import Any |
| 3 | +from typing import Type |
2 | 4 | from typing import TYPE_CHECKING |
3 | 5 |
|
4 | 6 | from flask import current_app |
| 7 | +from flask import Flask |
5 | 8 |
|
6 | 9 | from flask_jwt_extended.exceptions import RevokedTokenError |
7 | 10 | from flask_jwt_extended.exceptions import UserClaimsVerificationError |
8 | 11 | from flask_jwt_extended.exceptions import WrongTokenError |
9 | 12 |
|
| 13 | +try: |
| 14 | + from flask.json.provider import DefaultJSONProvider |
| 15 | + |
| 16 | + HAS_JSON_PROVIDER = True |
| 17 | +except ModuleNotFoundError: # pragma: no cover |
| 18 | + # The flask.json.provider module was added in Flask 2.2. |
| 19 | + # Further details are handled in get_json_encoder. |
| 20 | + HAS_JSON_PROVIDER = False |
| 21 | + |
| 22 | + |
10 | 23 | if TYPE_CHECKING: # pragma: no cover |
11 | 24 | from flask_jwt_extended import JWTManager |
12 | 25 |
|
@@ -51,3 +64,31 @@ def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None: |
51 | 64 | if not jwt_manager._token_verification_callback(jwt_header, jwt_data): |
52 | 65 | error_msg = "User claims verification failed" |
53 | 66 | raise UserClaimsVerificationError(error_msg, jwt_header, jwt_data) |
| 67 | + |
| 68 | + |
| 69 | +def get_json_encoder(app: Flask) -> Type[json.JSONEncoder]: |
| 70 | + """Get the JSON Encoder for the provided flask app |
| 71 | +
|
| 72 | + Starting with flask version 2.2 the flask application provides a |
| 73 | + interface to register a custom JSON Encoder/Decoder under the json_provider_class. |
| 74 | + As this interface is not compatible with the standard JSONEncoder, the `default` |
| 75 | + method of the class is wrapped. |
| 76 | +
|
| 77 | + Lookup Order: |
| 78 | + - app.json_encoder - For Flask < 2.2 |
| 79 | + - app.json_provider_class.default |
| 80 | + - flask.json.provider.DefaultJSONProvider.default |
| 81 | +
|
| 82 | + """ |
| 83 | + if not HAS_JSON_PROVIDER: # pragma: no cover |
| 84 | + return app.json_encoder |
| 85 | + |
| 86 | + # If the registered JSON provider does not implement a default classmethod |
| 87 | + # use the method defined by the DefaultJSONProvider |
| 88 | + default = getattr(app.json_provider_class, "default", DefaultJSONProvider.default) |
| 89 | + |
| 90 | + class JSONEncoder(json.JSONEncoder): |
| 91 | + def default(self, o: Any) -> Any: |
| 92 | + return default(o) # pragma: no cover |
| 93 | + |
| 94 | + return JSONEncoder |
0 commit comments