diff --git a/rest_framework/caches.py b/rest_framework/caches.py new file mode 100644 index 000000000..8ee96e197 --- /dev/null +++ b/rest_framework/caches.py @@ -0,0 +1,65 @@ +""" +Provides a set of pluggable cache policies. +""" +from django.core.cache import cache +from django.core.exceptions import ImproperlyConfigured +from rest_framework.exceptions import PreconditionFailed + + +class BaseCacheLookup(object): + def get_request_header(self): + raise ImproperlyConfigured('Subclass must implement `get_request_header`.') + + def get_update_header(self): + raise ImproperlyConfigured('Subclass must implement `get_update_header`.') + + def get_response_header(self, obj): + raise ImproperlyConfigured('Subclass must impelement `get_response_header`.') + + def precondition_check(self, obj, request): + raise ImproperlyConfigured('Subclass must implement `precondition_check`.') + + def resource_unchanged(self, request): + raise ImproperlyConfigured('Subclass must implement `resource_unchanged`.') + + +class ETagCacheLookup(BaseCacheLookup): + """ + """ + etag_variable = 'etag' + request_header = 'HTTP_IF_NONE_MATCH' + update_header = 'HTTP_IF_MATCH' + + @staticmethod + def get_cache_key(cls, pk): + try: + class_name = cls.__name__ # class + except AttributeError: + class_name = cls.__class__.__name__ # instance + return 'etag-{}-{}'.format(class_name, pk) + + def get_etag(self, obj): + return getattr(obj, self.etag_variable) + + def get_request_header(self): + return self.request_header + + def get_update_header(self): + return self.update_header + + def get_response_header(self, obj): + key = self.get_cache_key(obj, 'pk') + etag = self.get_etag(obj) + cache.set(key, etag) + return {'ETag': etag} + + def precondition_check(self, obj, request): + if self.get_etag(obj) != request.META.get(self.get_update_header()): + raise PreconditionFailed + + def resource_unchanged(self, request, key): + etag = cache.get(key) + header = request.META.get(self.get_request_header()) + if etag is not None and etag == header: + return True + return False diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 0c96ecdd5..9ceb64fe1 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -88,6 +88,14 @@ class Throttled(APIException): self.detail = detail or self.default_detail +class PreconditionFailed(APIException): + status_code = status.HTTP_412_PRECONDITION_FAILED + default_detail = 'Object has been updated since you retrieved it.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + class ConfigurationError(Exception): """ Indicates an internal server error. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7d9a6e654..b041f55ef 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -97,11 +97,18 @@ class RetrieveModelMixin(object): Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): + cached_object = self.check_preemptive_cache(request) + if cached_object: + return cached_object + queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) self.object = self.get_object(filtered_queryset) + + headers = self.get_cache_lookup_response_headers(self.object) + serializer = self.get_serializer(self.object) - return Response(serializer.data) + return Response(serializer.data, headers=headers) class UpdateModelMixin(object): @@ -122,6 +129,7 @@ class UpdateModelMixin(object): save_kwargs = {'force_insert': True} success_status_code = status.HTTP_201_CREATED else: + self.cache_precondition_check(self.object, request) created = False save_kwargs = {'force_update': True} success_status_code = status.HTTP_200_OK @@ -166,5 +174,6 @@ class DestroyModelMixin(object): """ def destroy(self, request, *args, **kwargs): obj = self.get_object() + self.cache_precondition_check(obj, request) obj.delete() return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index eede0c5a0..be687b86a 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -47,6 +47,8 @@ DEFAULTS = { ), 'DEFAULT_THROTTLE_CLASSES': ( ), + 'DEFAULT_CACHE_LOOKUP_CLASSES': ( + ), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', diff --git a/rest_framework/views.py b/rest_framework/views.py index 81cbdcbb2..59b15e998 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -60,6 +60,7 @@ class APIView(View): throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + cache_lookup_classes = api_settings.DEFAULT_CACHE_LOOKUP_CLASSES @classmethod def as_view(cls, **initkwargs): @@ -241,6 +242,12 @@ class APIView(View): self._negotiator = self.content_negotiation_class() return self._negotiator + def get_cache_lookups(self): + """ + Instantiates and returns the list of cache lookups that this view requires. + """ + return [cache_lookup() for cache_lookup in self.cache_lookup_classes] + # API policy implementation methods def perform_content_negotiation(self, request, force=False): @@ -294,6 +301,22 @@ class APIView(View): if not throttle.allow_request(request, self): self.throttled(request, throttle.wait()) + def check_preemptive_cache(self, request): + for cache_lookup in self.get_cache_lookups(): + cache_key = cache_lookup.get_cache_key(self.model, self.kwargs['pk']) + if cache_lookup.resource_unchanged(request, cache_key): + return Response(status=304) + + def get_cache_lookup_response_headers(self, obj): + headers = {} + for cache_lookup in self.get_cache_lookups(): + headers.update(cache_lookup.get_response_header(obj)) + return headers + + def cache_precondition_check(self, obj, request): + for cache_lookup in self.get_cache_lookups(): + cache_lookup.precondition_check(obj, request) + # Dispatch methods def initialize_request(self, request, *args, **kargs):