Skip to content

Commit 3f3f934

Browse files
authored
fixed request context dependency (#144)
* fixed request context dependency * fixed failing tests * 5.3.9
1 parent 414fb96 commit 3f3f934

File tree

2 files changed

+39
-24
lines changed

2 files changed

+39
-24
lines changed

ninja_jwt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Django Ninja JWT - JSON Web Token for Django-Ninja"""
22

3-
__version__ = "5.3.7"
3+
__version__ = "5.3.9"

ninja_jwt/schema.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from ninja.schema import DjangoGetter
1212
from ninja_extra import service_resolver
1313
from ninja_extra.context import RouteContext
14-
from pydantic import ConfigDict, model_validator
14+
from pydantic import ConfigDict, ValidationInfo, model_validator
15+
from pydantic.main import BaseModel
1516

1617
import ninja_jwt.exceptions as exceptions
1718
from ninja_jwt.utils import token_error
@@ -28,14 +29,21 @@
2829

2930

3031
class SchemaInputService:
31-
def __init__(self, values: SCHEMA_INPUT, model_config: ConfigDict) -> None:
32+
def __init__(
33+
self,
34+
values: SCHEMA_INPUT,
35+
model_config: ConfigDict,
36+
request: Optional[HttpRequest] = None,
37+
) -> None:
3238
self.model_config = model_config
3339
self.values = values
3440

41+
self._request: Optional[HttpRequest] = request
42+
3543
def get_request(self) -> HttpRequest:
36-
if self.model_config.get("extra") == "forbid":
44+
if self.model_config.get("extra") == "forbid" and self._request is None:
3745
return service_resolver(RouteContext).request
38-
return self.values._context.get("request")
46+
return self._request
3947

4048
def get_values(self) -> Dict:
4149
if self.model_config.get("extra") == "forbid":
@@ -75,7 +83,7 @@ def check_user_authentication_rule(self) -> None:
7583
)
7684

7785
@classmethod
78-
def validate_values(cls, request: HttpRequest, values: Dict) -> Dict:
86+
def validate_values(cls, values: Dict) -> Dict:
7987
if user_name_field not in values and "password" not in values:
8088
raise exceptions.ValidationError(
8189
{
@@ -92,16 +100,16 @@ def validate_values(cls, request: HttpRequest, values: Dict) -> Dict:
92100
if not values.get("password"):
93101
raise exceptions.ValidationError({"password": "password is required"})
94102

95-
_user = authenticate(request, **values)
96-
cls._user = _user
103+
return values
104+
105+
def authenticate(self, request: HttpRequest, credentials: Dict) -> None:
106+
self._user = authenticate(request, **credentials)
97107

98-
if not (_user is not None and _user.is_active):
108+
if not (self._user is not None and self._user.is_active):
99109
raise exceptions.AuthenticationFailed(
100-
cls._default_error_messages["no_active_account"]
110+
self._default_error_messages["no_active_account"]
101111
)
102112

103-
return values
104-
105113
def output_schema(self) -> Schema:
106114
warnings.warn(
107115
"output_schema() is deprecated in favor of to_response_schema()",
@@ -119,36 +127,45 @@ def get_token(cls, user: AbstractUser) -> Dict:
119127

120128
class TokenObtainInputSchemaBase(ModelSchema, TokenInputSchemaMixin):
121129
class Config:
122-
# extra = "allow"
130+
# extra = "forbid"
123131
model = get_user_model()
124132
model_fields = ["password", user_name_field]
125-
extra = "forbid"
126133

127134
@model_validator(mode="before")
128135
def validate_inputs(cls, values: SCHEMA_INPUT) -> dict:
129136
schema_input = SchemaInputService(values, cls.model_config)
130137
input_values = schema_input.get_values()
131-
request = schema_input.get_request()
132138

133139
if isinstance(input_values, dict):
134-
values.update(cls.validate_values(request=request, values=input_values))
135-
return values
140+
cls.validate_values(values=input_values)
136141
return values
137142

138143
@model_validator(mode="after")
139-
def post_validate(cls, values: Dict) -> dict:
140-
return cls.post_validate_schema(values)
144+
def post_validate(
145+
cls, values: "TokenObtainInputSchemaBase", info: ValidationInfo
146+
) -> BaseModel:
147+
schema_input = SchemaInputService(
148+
values.model_dump(), cls.model_config, info.context.get("request")
149+
)
150+
151+
credentials = schema_input.get_values()
152+
request = schema_input.get_request()
153+
154+
values.authenticate(request, credentials)
155+
cls.post_validate_schema(values)
156+
157+
return values
141158

142159
@classmethod
143-
def post_validate_schema(cls, values: Dict) -> dict:
160+
def post_validate_schema(cls, values: "TokenObtainInputSchemaBase") -> None:
144161
"""
145162
This is a post validate process which is common for any token generating schema.
146163
:param values:
147164
:return:
148165
"""
149166
# get_token can return values that wants to apply to `OutputSchema`
150167

151-
data = cls.get_token(cls._user)
168+
data = cls.get_token(values._user)
152169

153170
if not isinstance(data, dict):
154171
raise Exception("`get_token` must return a `typing.Dict` type.")
@@ -158,9 +175,7 @@ def post_validate_schema(cls, values: Dict) -> dict:
158175
values.__dict__.update(token_data=data)
159176

160177
if api_settings.UPDATE_LAST_LOGIN:
161-
update_last_login(None, cls._user)
162-
163-
return values
178+
update_last_login(None, values._user)
164179

165180
def get_response_schema_init_kwargs(self) -> dict:
166181
return dict(

0 commit comments

Comments
 (0)