diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index db41539fd..48fb2a392 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -16,6 +16,7 @@ from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.utils import formatting from rest_framework.utils.field_mapping import ClassLookupDict +from rest_framework.utils.model_meta import _get_pk from rest_framework.views import APIView @@ -35,6 +36,11 @@ types_lookup = ClassLookupDict({ }) +def get_pk_name(model): + meta = model._meta.concrete_model._meta + return _get_pk(meta).name + + def as_query_fields(items): """ Take a list of Fields and plain strings. @@ -196,6 +202,9 @@ class SchemaGenerator(object): 'delete': 'destroy', } endpoint_inspector_cls = EndpointInspector + # 'pk' isn't great as an externally exposed name for an identifier, + # so by default we prefer to use the actual model field name for schemas. + coerce_pk = True def __init__(self, title=None, url=None, patterns=None, urlconf=None): assert coreapi, '`coreapi` must be installed for schema support.' @@ -230,6 +239,7 @@ class SchemaGenerator(object): links = OrderedDict() for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) + path = self.coerce_path(path, method, view) if not self.should_include_view(path, method, view): continue link = self.get_link(path, method, view) @@ -280,6 +290,16 @@ class SchemaGenerator(object): return False return True + def coerce_path(self, path, method, view): + if not self.coerce_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + # Methods for generating each individual `Link` instance... def get_link(self, path, method, view): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index c8b40fbb4..1914994a7 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -103,10 +103,10 @@ class TestRouterGeneratedSchema(TestCase): ) }, 'retrieve': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ) } @@ -142,19 +142,19 @@ class TestRouterGeneratedSchema(TestCase): ] ), 'retrieve': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ), 'custom_action': coreapi.Link( - url='/example/{pk}/custom_action/', + url='/example/{id}/custom_action/', action='post', encoding='application/json', description='A description of custom action.', fields=[ - coreapi.Field('pk', required=True, location='path'), + coreapi.Field('id', required=True, location='path'), coreapi.Field('c', required=True, location='form', type='string'), coreapi.Field('d', required=False, location='form', type='string'), ] @@ -174,30 +174,30 @@ class TestRouterGeneratedSchema(TestCase): ) }, 'update': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='put', encoding='application/json', fields=[ - coreapi.Field('pk', required=True, location='path'), + coreapi.Field('id', required=True, location='path'), coreapi.Field('a', required=True, location='form', type='string', description='A field description'), coreapi.Field('b', required=False, location='form', type='string') ] ), 'partial_update': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='patch', encoding='application/json', fields=[ - coreapi.Field('pk', required=True, location='path'), + coreapi.Field('id', required=True, location='path'), coreapi.Field('a', required=False, location='form', type='string', description='A field description'), coreapi.Field('b', required=False, location='form', type='string') ] ), 'destroy': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='delete', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ) } @@ -254,18 +254,18 @@ class TestSchemaGenerator(TestCase): fields=[] ), 'retrieve': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ), 'sub': { 'list': coreapi.Link( - url='/example/{pk}/sub/', + url='/example/{id}/sub/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ) }