diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 629f92b0d..3d46251b9 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -365,9 +365,7 @@ class SchemaGenerator(object): """ Given a callback, return an actual view instance. """ - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) + view = callback.cls(**getattr(callback, 'initkwargs', {})) view.args = () view.kwargs = {} view.format_kwarg = None diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 89a1fc93a..5c9659a57 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -7,6 +7,7 @@ See schemas.__init__.py for package overview. import re import warnings from collections import OrderedDict +from weakref import WeakKeyDictionary from django.db import models from django.utils.encoding import force_text, smart_text @@ -128,6 +129,10 @@ class ViewInspector(object): Provide subclass for per-view schema generation """ + + def __init__(self): + self.instance_schemas = WeakKeyDictionary() + def __get__(self, instance, owner): """ Enables `ViewInspector` as a Python _Descriptor_. @@ -144,9 +149,16 @@ class ViewInspector(object): See: https://docs.python.org/3/howto/descriptor.html for info on descriptor usage. """ + if instance in self.instance_schemas: + return self.instance_schemas[instance] + self.view = instance return self + def __set__(self, instance, other): + self.instance_schemas[instance] = other + other.view = instance + @property def view(self): """View property.""" @@ -189,6 +201,7 @@ class AutoSchema(ViewInspector): * `manual_fields`: list of `coreapi.Field` instances that will be added to auto-generated fields, overwriting on `Field.name` """ + super(AutoSchema, self).__init__() if manual_fields is None: manual_fields = [] self._manual_fields = manual_fields @@ -455,6 +468,7 @@ class ManualSchema(ViewInspector): * `fields`: list of `coreapi.Field` instances. * `descripton`: String description for view. Optional. """ + super(ManualSchema, self).__init__() assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" self._fields = fields self._description = description @@ -474,9 +488,13 @@ class ManualSchema(ViewInspector): ) -class DefaultSchema(object): +class DefaultSchema(ViewInspector): """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" def __get__(self, instance, owner): + result = super(DefaultSchema, self).__get__(instance, owner) + if not isinstance(result, DefaultSchema): + return result + inspector_class = api_settings.DEFAULT_SCHEMA_CLASS assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" inspector = inspector_class()