Coerce schema 'pk' in path to actual field name

This commit is contained in:
Tom Christie 2016-10-06 16:22:03 +01:00
parent 7edee804aa
commit b44ab76d2c
2 changed files with 36 additions and 16 deletions

View File

@ -16,6 +16,7 @@ from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from rest_framework.utils.field_mapping import ClassLookupDict
from rest_framework.utils.model_meta import _get_pk
from rest_framework.views import APIView
@ -35,6 +36,11 @@ types_lookup = ClassLookupDict({
})
def get_pk_name(model):
meta = model._meta.concrete_model._meta
return _get_pk(meta).name
def as_query_fields(items):
"""
Take a list of Fields and plain strings.
@ -196,6 +202,9 @@ class SchemaGenerator(object):
'delete': 'destroy',
}
endpoint_inspector_cls = EndpointInspector
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
coerce_pk = True
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
@ -230,6 +239,7 @@ class SchemaGenerator(object):
links = OrderedDict()
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
path = self.coerce_path(path, method, view)
if not self.should_include_view(path, method, view):
continue
link = self.get_link(path, method, view)
@ -280,6 +290,16 @@ class SchemaGenerator(object):
return False
return True
def coerce_path(self, path, method, view):
if not self.coerce_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
# Methods for generating each individual `Link` instance...
def get_link(self, path, method, view):

View File

@ -103,10 +103,10 @@ class TestRouterGeneratedSchema(TestCase):
)
},
'retrieve': coreapi.Link(
url='/example/{pk}/',
url='/example/{id}/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
coreapi.Field('id', required=True, location='path')
]
)
}
@ -142,19 +142,19 @@ class TestRouterGeneratedSchema(TestCase):
]
),
'retrieve': coreapi.Link(
url='/example/{pk}/',
url='/example/{id}/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
coreapi.Field('id', required=True, location='path')
]
),
'custom_action': coreapi.Link(
url='/example/{pk}/custom_action/',
url='/example/{id}/custom_action/',
action='post',
encoding='application/json',
description='A description of custom action.',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('id', required=True, location='path'),
coreapi.Field('c', required=True, location='form', type='string'),
coreapi.Field('d', required=False, location='form', type='string'),
]
@ -174,30 +174,30 @@ class TestRouterGeneratedSchema(TestCase):
)
},
'update': coreapi.Link(
url='/example/{pk}/',
url='/example/{id}/',
action='put',
encoding='application/json',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('id', required=True, location='path'),
coreapi.Field('a', required=True, location='form', type='string', description='A field description'),
coreapi.Field('b', required=False, location='form', type='string')
]
),
'partial_update': coreapi.Link(
url='/example/{pk}/',
url='/example/{id}/',
action='patch',
encoding='application/json',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('id', required=True, location='path'),
coreapi.Field('a', required=False, location='form', type='string', description='A field description'),
coreapi.Field('b', required=False, location='form', type='string')
]
),
'destroy': coreapi.Link(
url='/example/{pk}/',
url='/example/{id}/',
action='delete',
fields=[
coreapi.Field('pk', required=True, location='path')
coreapi.Field('id', required=True, location='path')
]
)
}
@ -254,18 +254,18 @@ class TestSchemaGenerator(TestCase):
fields=[]
),
'retrieve': coreapi.Link(
url='/example/{pk}/',
url='/example/{id}/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
coreapi.Field('id', required=True, location='path')
]
),
'sub': {
'list': coreapi.Link(
url='/example/{pk}/sub/',
url='/example/{id}/sub/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
coreapi.Field('id', required=True, location='path')
]
)
}