Skip to content

Commit 550b372

Browse files
author
Maxence
committed
Merge from upstream + allow list of drf_routers
1 parent fc9179f commit 550b372

File tree

9 files changed

+86
-83
lines changed

9 files changed

+86
-83
lines changed

demo/project/accounts/urls.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1+
import django
12
from django.conf.urls import url
23
from project.accounts import views
34

5+
from rest_framework.routers import SimpleRouter
46

5-
urlpatterns = [
6-
url(r'^test/$', views.TestView.as_view(), name="test-view"),
7+
account_router = SimpleRouter()
8+
account_router.register('user-model-viewsets', views.UserModelViewset, base_name='account')
79

10+
account_urlpatterns = [
11+
url(r'^test/$', views.TestView.as_view(), name="test-view"),
812
url(r'^login/$', views.LoginView.as_view(), name="login"),
913
url(r'^register/$', views.UserRegistrationView.as_view(), name="register"),
1014
url(r'^reset-password/$', view=views.PasswordResetView.as_view(), name="reset-password"),
1115
url(r'^reset-password/confirm/$', views.PasswordResetConfirmView.as_view(), name="reset-password-confirm"),
12-
1316
url(r'^user/profile/$', views.UserProfileView.as_view(), name="profile"),
17+
] + account_router.urls
1418

15-
]
19+
# Django 1.9 Support for the app_name argument is deprecated
20+
# https://docs.djangoproject.com/en/1.9/ref/urls/#include
21+
django_version = django.VERSION
22+
if django.VERSION[:2] >= (1, 9, ):
23+
account_urlpatterns = (account_urlpatterns, 'accounts', )

demo/project/accounts/views.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from rest_framework.permissions import AllowAny
66
from rest_framework.response import Response
77
from rest_framework.views import APIView
8+
from rest_framework.viewsets import ModelViewSet
89
from project.accounts.models import User
910
from project.accounts.serializers import (
1011
UserRegistrationSerializer, UserProfileSerializer, ResetPasswordSerializer, CustomAuthTokenSerializer
@@ -80,3 +81,8 @@ def post(self, request, *args, **kwargs):
8081
if not serializer.is_valid():
8182
return Response({'errors': serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
8283
return Response({"msg": "Password updated successfully."}, status=status.HTTP_200_OK)
84+
85+
86+
class UserModelViewset(ModelViewSet):
87+
queryset = User.objects.all()
88+
serializer_class = UserProfileSerializer

demo/project/organisations/urls.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
from django.conf.urls import url
33
from project.organisations import views
44

5+
from rest_framework.routers import SimpleRouter
6+
from .views import OrganisationModelViewset
7+
8+
organisation_router = SimpleRouter()
9+
organisation_router.register('organisation-model-viewsets', OrganisationModelViewset, base_name='organisation')
510

611
organisations_urlpatterns = [
712
url(r'^create/$', view=views.CreateOrganisationView.as_view(), name="create"),
813
url(r'^(?P<slug>[\w-]+)/$', view=views.RetrieveOrganisationView.as_view(), name="organisation"),
914
url(r'^(?P<slug>[\w-]+)/members/$', view=views.OrganisationMembersView.as_view(), name="members"),
1015
url(r'^(?P<slug>[\w-]+)/leave/$', view=views.LeaveOrganisationView.as_view(), name="leave")
11-
]
16+
] + organisation_router.urls
1217

1318
members_urlpatterns = [
1419
url(r'^(?P<slug>[\w-]+)/members/$', view=views.OrganisationMembersView.as_view(), name="members"),
@@ -18,5 +23,5 @@
1823
# https://docs.djangoproject.com/en/1.9/ref/urls/#include
1924
django_version = django.VERSION
2025
if django.VERSION[:2] >= (1, 9, ):
21-
organisations_urlpatterns = (organisations_urlpatterns, 'organisations_app', )
22-
members_urlpatterns = (members_urlpatterns, 'organisations_app', )
26+
organisations_urlpatterns = (organisations_urlpatterns, 'organisations', )
27+
members_urlpatterns = (members_urlpatterns, 'organisations', )

demo/project/organisations/views.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from rest_framework import generics, status
22
from rest_framework.response import Response
3+
from rest_framework.viewsets import ModelViewSet
34
from project.organisations.models import Organisation, Membership
45
from project.organisations.serializers import (
56
CreateOrganisationSerializer, OrganisationMembersSerializer, RetrieveOrganisationSerializer
@@ -34,3 +35,8 @@ def delete(self, request, *args, **kwargs):
3435
instance = self.get_object()
3536
self.perform_destroy(instance)
3637
return Response(status=status.HTTP_204_NO_CONTENT)
38+
39+
40+
class OrganisationModelViewset(ModelViewSet):
41+
queryset = Organisation.objects.all()
42+
serializer_class = OrganisationMembersSerializer

demo/project/urls.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,35 @@
1616
import django
1717
from django.conf.urls import include, url
1818
from django.contrib import admin
19-
from .organisations.urls import organisations_urlpatterns, members_urlpatterns
19+
from rest_framework_docs.views import DRFDocsView
20+
from .accounts.urls import account_urlpatterns, account_router
21+
from .organisations.urls import organisations_urlpatterns, members_urlpatterns, organisation_router
2022

2123
urlpatterns = [
2224
url(r'^admin/', include(admin.site.urls)),
23-
url(r'^docs/', include('rest_framework_docs.urls')),
24-
25-
# API
26-
url(r'^accounts/', view=include('project.accounts.urls', namespace='accounts')),
2725
]
2826

2927
# Django 1.9 Support for the app_name argument is deprecated
3028
# https://docs.djangoproject.com/en/1.9/ref/urls/#include
3129
django_version = django.VERSION
3230
if django.VERSION[:2] >= (1, 9, ):
3331
urlpatterns.extend([
32+
url(r'^accounts/', view=include(account_urlpatterns, namespace='accounts')),
3433
url(r'^organisations/', view=include(organisations_urlpatterns, namespace='organisations')),
3534
url(r'^members/', view=include(members_urlpatterns, namespace='members')),
3635
])
3736
else:
3837
urlpatterns.extend([
38+
url(r'^accounts/', view=include(account_urlpatterns, namespace='accounts', app_name='account_app')),
3939
url(r'^organisations/', view=include(organisations_urlpatterns, namespace='organisations', app_name='organisations_app')),
4040
url(r'^members/', view=include(members_urlpatterns, namespace='members', app_name='organisations_app')),
4141
])
42+
43+
44+
from tests.views import LoginView
45+
routers = [account_router, organisation_router]
46+
urlpatterns.extend([
47+
url(r'^docs/(?P<filter_param>[\w-]+)/$', DRFDocsView.as_view(drf_router=routers), name='drfdocs-filter'),
48+
url(r'^docs/$', DRFDocsView.as_view(drf_router=routers), name='drfdocs'),
49+
url(r'^another-login/$', LoginView.as_view(), name="login"),
50+
])

rest_framework_docs/api_docs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ def __init__(self, drf_router=None, filter_param=None):
3030
def get_all_view_names(self, urlpatterns, parent_pattern=None, filter_param=None):
3131
for pattern in urlpatterns:
3232
if isinstance(pattern, RegexURLResolver) and (not filter_param or filter_param in [pattern.app_name, pattern.namespace]):
33-
parent_pattern = None if pattern._regex == "^" else pattern
34-
self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=pattern, filter_param=filter_param)
33+
# parent_pattern = None if pattern._regex == "^" else pattern
34+
self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=None if pattern._regex == "^" else pattern, filter_param=filter_param)
3535
elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern):
36-
api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router)
37-
self.endpoints.append(api_endpoint)
36+
if not filter_param or parent_pattern:
37+
api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router)
38+
self.endpoints.append(api_endpoint)
3839

3940
def _is_drf_view(self, pattern):
4041
"""

rest_framework_docs/api_endpoint.py

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,17 @@
22
import inspect
33
from django.contrib.admindocs.views import simplify_regex
44
from django.utils.encoding import force_str
5-
from rest_framework.serializers import BaseSerializer
6-
75
from rest_framework import serializers
86
from rest_framework.viewsets import ModelViewSet
97
from rest_framework_docs import SERIALIZER_FIELDS
108

119

12-
METHODS_ORDER = ['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']
13-
VIEWSET_METHODS = {
14-
'List': ['get', 'post'],
15-
'Instance': ['get', 'put', 'patch', 'delete'],
16-
}
17-
18-
1910
class ApiEndpoint(object):
2011

2112
def __init__(self, pattern, parent_pattern=None, drf_router=None):
22-
self.drf_router = drf_router
13+
self.drf_router = drf_router or []
14+
if not isinstance(self.drf_router, (list, tuple)):
15+
self.drf_router = [self.drf_router]
2316
self.pattern = pattern
2417
self.callback = pattern.callback
2518
self.docstring = self.__get_docstring__()
@@ -40,9 +33,8 @@ def __init__(self, pattern, parent_pattern=None, drf_router=None):
4033
self.serializer_class = self.__get_serializer_class__()
4134
if self.serializer_class:
4235
self.serializer = self.__get_serializer__()
43-
self.fields = self.__get_serializer_fields__(self.serializer)
36+
self.fields = self.__get_serializer_fields__()
4437
self.fields_json = self.__get_serializer_fields_json__()
45-
4638
self.permissions = self.__get_permissions_class__()
4739

4840
def __get_path__(self, parent_pattern):
@@ -51,38 +43,28 @@ def __get_path__(self, parent_pattern):
5143
return simplify_regex(self.pattern.regex.pattern)
5244

5345
def __get_allowed_methods__(self):
54-
callback_cls = self.callback.cls
55-
56-
def is_method_allowed(method_name):
57-
return hasattr(callback_cls, method_name) or (
58-
issubclass(callback_cls, ModelViewSet) and
59-
method_name in VIEWSET_METHODS.get(self.callback.suffix, []))
60-
61-
return sorted(
62-
[force_str(name).upper() for name in callback_cls.http_method_names if is_method_allowed(name)],
63-
key=lambda e: METHODS_ORDER.index(e))
6446

6547
viewset_methods = []
66-
if self.drf_router:
67-
for prefix, viewset, basename in self.drf_router.registry:
48+
for router in self.drf_router:
49+
for prefix, viewset, basename in router.registry:
6850
if self.callback.cls != viewset:
6951
continue
7052

71-
lookup = self.drf_router.get_lookup_regex(viewset)
72-
routes = self.drf_router.get_routes(viewset)
53+
lookup = router.get_lookup_regex(viewset)
54+
routes = router.get_routes(viewset)
7355

7456
for route in routes:
7557

7658
# Only actions which actually exist on the viewset will be bound
77-
mapping = self.drf_router.get_method_map(viewset, route.mapping)
59+
mapping = router.get_method_map(viewset, route.mapping)
7860
if not mapping:
7961
continue
8062

8163
# Build the url pattern
8264
regex = route.url.format(
8365
prefix=prefix,
8466
lookup=lookup,
85-
trailing_slash=self.drf_router.trailing_slash
67+
trailing_slash=router.trailing_slash
8668
)
8769
if self.pattern.regex.pattern == regex:
8870
funcs, viewset_methods = zip(
@@ -92,7 +74,8 @@ def is_method_allowed(method_name):
9274
if len(set(funcs)) == 1:
9375
self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0]))
9476

95-
view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]
77+
view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if
78+
hasattr(self.callback.cls, m)]
9679
return viewset_methods + view_methods
9780

9881
def __get_docstring__(self):
@@ -102,22 +85,6 @@ def __get_permissions_class__(self):
10285
for perm_class in self.pattern.callback.cls.permission_classes:
10386
return perm_class.__name__
10487

105-
def __get_serializer_fields__(self):
106-
fields = []
107-
serializer = None
108-
109-
if hasattr(self.callback.cls, 'serializer_class'):
110-
serializer = self.callback.cls.serializer_class
111-
112-
elif hasattr(self.callback.cls, 'get_serializer_class'):
113-
serializer = self.callback.cls.get_serializer_class(self.pattern.callback.cls())
114-
115-
if hasattr(serializer, 'get_fields'):
116-
try:
117-
fields = self.__get_fields__(serializer)
118-
except KeyError as e:
119-
self.errors = e
120-
fields = []
12188
def __get_serializer__(self):
12289
try:
12390
return self.serializer_class()

tests/tests.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ def test_index_view_with_endpoints(self):
2828
response = self.client.get(reverse('drfdocs'))
2929

3030
self.assertEqual(response.status_code, 200)
31-
self.assertEqual(len(response.context["endpoints"]), 15)
31+
self.assertEqual(len(response.context["endpoints"]), 16)
3232

3333
# Test the login view
34-
endpoint = response.context["endpoints"][1]
34+
endpoint = response.context["endpoints"][4]
3535
self.assertEqual(endpoint.name_parent, "accounts")
3636
self.assertEqual(endpoint.allowed_methods, ['POST', 'OPTIONS'])
3737
self.assertEqual(endpoint.path, "/accounts/login/")
@@ -41,7 +41,7 @@ def test_index_view_with_endpoints(self):
4141
self.assertTrue(endpoint.fields[0]["required"])
4242

4343
# The view "OrganisationErroredView" (organisations/(?P<slug>[\w-]+)/errored/) should contain an error.
44-
endpoint = response.context["endpoints"][8]
44+
endpoint = response.context["endpoints"][12]
4545
self.assertEqual(str(endpoint.errors), "'test_value'")
4646

4747
def test_index_search_with_endpoints(self):
@@ -81,10 +81,10 @@ def test_index_view_with_existent_namespace(self):
8181
# Test 'organisations' namespace
8282
response = self.client.get(reverse('drfdocs-filter', args=['organisations']))
8383
self.assertEqual(response.status_code, 200)
84-
self.assertEqual(len(response.context["endpoints"]), 3)
84+
self.assertEqual(len(response.context["endpoints"]), 5)
8585

8686
# The view "OrganisationErroredView" (organisations/(?P<slug>[\w-]+)/errored/) should contain an error.
87-
self.assertEqual(str(response.context["endpoints"][0].errors), "'test_value'")
87+
self.assertEqual(str(response.context["endpoints"][1].errors), "'test_value'")
8888

8989
# Test 'members' namespace
9090
response = self.client.get(reverse('drfdocs-filter', args=['members']))
@@ -107,9 +107,9 @@ def test_index_view_with_existent_app_name(self):
107107
# Test 'organisations_app' app_name
108108
response = self.client.get(reverse('drfdocs-filter', args=['organisations_app']))
109109
self.assertEqual(response.status_code, 200)
110-
self.assertEqual(len(response.context["endpoints"]), 4)
110+
self.assertEqual(len(response.context["endpoints"]), 6)
111111
parents_name = [e.name_parent for e in response.context["endpoints"]]
112-
self.assertEquals(parents_name.count('organisations'), 3)
112+
self.assertEquals(parents_name.count('organisations'), 5)
113113
self.assertEquals(parents_name.count('members'), 1)
114114

115115
def test_index_search_with_existent_app_name(self):
@@ -118,7 +118,7 @@ def test_index_search_with_existent_app_name(self):
118118
self.assertEqual(response.status_code, 200)
119119
self.assertEqual(len(response.context["endpoints"]), 1)
120120
self.assertEqual(response.context["endpoints"][0].path, "/organisations/create/")
121-
self.assertEqual(len(response.context["endpoints"][0].fields), 2)
121+
self.assertEqual(len(response.context["endpoints"][0].fields), 3)
122122

123123
def test_index_view_with_non_existent_namespace_or_app_name(self):
124124
"""
@@ -130,14 +130,10 @@ def test_index_view_with_non_existent_namespace_or_app_name(self):
130130

131131
def test_model_viewset(self):
132132
response = self.client.get(reverse('drfdocs'))
133-
134133
self.assertEqual(response.status_code, 200)
135-
136-
self.assertEqual(response.context["endpoints"][10].path, '/organisations/<slug>/')
137-
self.assertEqual(response.context['endpoints'][6].fields[2]['to_many_relation'], True)
138-
self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/')
139-
self.assertEqual(response.context["endpoints"][12].path, '/organisation-model-viewsets/<pk>/')
140-
self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'POST', 'OPTIONS'])
141-
self.assertEqual(response.context["endpoints"][12].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
142-
self.assertEqual(response.context["endpoints"][13].allowed_methods, ['POST', 'OPTIONS'])
143-
self.assertEqual(response.context["endpoints"][13].docstring, 'This is a test.')
134+
self.assertEqual(response.context["endpoints"][1].path, '/organisation-model-viewsets/')
135+
self.assertEqual(response.context["endpoints"][2].path, '/organisation-model-viewsets/<pk>/')
136+
self.assertEqual(response.context["endpoints"][1].allowed_methods, ['GET', 'POST', 'OPTIONS'])
137+
self.assertEqual(response.context["endpoints"][2].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
138+
self.assertEqual(response.context["endpoints"][4].allowed_methods, ['POST', 'OPTIONS'])
139+
self.assertEqual(response.context["endpoints"][3].docstring, 'This is a test.')

tests/urls.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from rest_framework_docs.views import DRFDocsView
88
from tests import views
99

10+
1011
accounts_urls = [
1112
url(r'^login/$', views.LoginView.as_view(), name="login"),
1213
url(r'^login2/$', views.LoginWithSerilaizerClassView.as_view(), name="login2"),
@@ -36,12 +37,14 @@
3637

3738
urlpatterns = [
3839
url(r'^admin/', include(admin.site.urls)),
39-
url(r'^docs/', DRFDocsView.as_view(drf_router=router), name='drfdocs'),
40+
41+
# url(r'^docs/', include('rest_framework_docs.urls')),
42+
url(r'^docs/(?P<filter_param>[\w-]+)/$', DRFDocsView.as_view(drf_router=router), name='drfdocs-filter'),
43+
url(r'^docs/$', DRFDocsView.as_view(drf_router=router), name='drfdocs'),
4044

4145
# API
42-
url(r'^accounts/', view=include(accounts_urls, namespace="accounts")),
43-
url(r'^accounts/', view=include(accounts_urls, namespace='accounts')),
44-
url(r'^organisations/', view=include(organisations_urls, namespace='organisations')),
46+
# url(r'^accounts/', view=include(accounts_urls, namespace="accounts")),
47+
# url(r'^organisations/', view=include(organisations_urls, namespace='organisations')),
4548
url(r'^', include(router.urls)),
4649

4750
# Endpoints without parents/namespaces
@@ -55,11 +58,13 @@
5558
organisations_urls = (organisations_urls, 'organisations_app', )
5659
members_urls = (members_urls, 'organisations_app', )
5760
urlpatterns.extend([
61+
url(r'^accounts/', view=include(accounts_urls, namespace="accounts")),
5862
url(r'^organisations/', view=include(organisations_urls, namespace='organisations')),
5963
url(r'^members/', view=include(members_urls, namespace='members')),
6064
])
6165
else:
6266
urlpatterns.extend([
67+
url(r'^accounts/', view=include(accounts_urls, namespace="accounts", app_name='accounts_app')),
6368
url(r'^organisations/', view=include(organisations_urls, namespace='organisations', app_name='organisations_app')),
6469
url(r'^members/', view=include(members_urls, namespace='members', app_name='organisations_app')),
6570
])

0 commit comments

Comments
 (0)