From 4f43804d42d521753ce0dee55941dd1e7f22f755 Mon Sep 17 00:00:00 2001 From: George Hickman Date: Mon, 1 Apr 2013 16:15:05 +0100 Subject: [PATCH] Move cache lookup code to a pluggable backend ETagCacheLookup is the first example implementation. --- rest_framework/cache_lookups.py | 42 +++++++++++++++++++++++++++++++ rest_framework/mixins.py | 15 ++++------- rest_framework/settings.py | 2 ++ rest_framework/views.py | 44 ++++++++++++++++++++++++++++----- 4 files changed, 87 insertions(+), 16 deletions(-) create mode 100644 rest_framework/cache_lookups.py diff --git a/rest_framework/cache_lookups.py b/rest_framework/cache_lookups.py new file mode 100644 index 000000000..4d51c426e --- /dev/null +++ b/rest_framework/cache_lookups.py @@ -0,0 +1,42 @@ +""" +Provides a set of pluggable cache policies. +""" +from django.core.cache import cache + + +class BaseCacheLookup(object): + def get_header(self, obj): + return {} + + def resource_unchanged(self, request): + """ + Return `False` if resource has changed, `True` otherwise. + """ + return False + + +class ETagCacheLookup(BaseCacheLookup): + """ + """ + etag_variable = 'etag' + + @staticmethod + def get_cache_key(cls, pk): + class_name = cls.__class__.__name__ + return 'etag-{}-{}'.format(class_name, pk) + + def get_etag(self, obj): + return getattr(obj, self.etag_variable) + + def get_header(self, obj): + key = self.get_cache_key(obj, 'pk') + etag = self.get_etag(obj) + cache.set(key, etag) + return {'ETag': etag} + + def resource_unchanged(self, request, key): + etag = cache.get(key) + header = request.META.get('HTTP_IF_NONE_MATCH') + if etag is not None and etag == header: + return True + return False diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ada821754..a0fc51e35 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -8,7 +8,6 @@ from __future__ import unicode_literals from django.http import Http404 from rest_framework import status -from rest_framework.exceptions import PreconditionFailed from rest_framework.response import Response from rest_framework.request import clone_request @@ -101,11 +100,9 @@ class RetrieveModelMixin(object): queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) self.object = self.get_object(filtered_queryset) - headers = {} - if getattr(self, 'use_etags', False): - if self.get_etag(self.object) == request.META.get('HTTP_IF_NONE_MATCH'): - return Response(status=304) - headers.update({'ETag': self.get_etag(self.object)}) + + headers = self.get_cache_lookup_headers(self.object) + serializer = self.get_serializer(self.object) return Response(serializer.data, headers=headers) @@ -127,8 +124,7 @@ class UpdateModelMixin(object): created = True success_status_code = status.HTTP_201_CREATED else: - if getattr(self, 'use_etags', False) and self.object.etag != self.etag_header: - raise PreconditionFailed + self.cache_precondition_check(self.object, request) created = False success_status_code = status.HTTP_200_OK @@ -172,7 +168,6 @@ class DestroyModelMixin(object): """ def destroy(self, request, *args, **kwargs): obj = self.get_object() - if self.get_etag(obj) != self.etag_header: - raise PreconditionFailed + self.cache_precondition_check(obj, request) obj.delete() return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index eede0c5a0..be687b86a 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -47,6 +47,8 @@ DEFAULTS = { ), 'DEFAULT_THROTTLE_CLASSES': ( ), + 'DEFAULT_CACHE_LOOKUP_CLASSES': ( + ), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', diff --git a/rest_framework/views.py b/rest_framework/views.py index 36f5a92fa..8e38d7e6a 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -60,6 +60,7 @@ class APIView(View): throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + cache_lookup_classes = api_settings.DEFAULT_CACHE_LOOKUP_CLASSES @classmethod def as_view(cls, **initkwargs): @@ -241,6 +242,12 @@ class APIView(View): self._negotiator = self.content_negotiation_class() return self._negotiator + def get_cache_lookups(self): + """ + Instantiates and returns the list of cache lookups that this view requires. + """ + return [cache_lookup() for cache_lookup in self.cache_lookup_classes] + # API policy implementation methods def perform_content_negotiation(self, request, force=False): @@ -294,6 +301,33 @@ class APIView(View): if not throttle.allow_request(request, self): self.throttled(request, throttle.wait()) + def check_preemptive_cache(self, request): + for cache_lookup in self.get_cache_lookups(): + cache_key = cache_lookup.get_cache_key(self.model, self.kwargs['pk']) + if cache_lookup.resource_unchanged(request, cache_key): + return Response(status=304) + + def get_cache_lookup_headers(self, obj): + headers = {} + for cache_lookup in self.get_cache_lookups(): + headers.update(cache_lookup.get_header(obj)) + return headers + + def check_update_validity(self, request): + """ + """ + # TODO add setting to cover + # * raise IfMatchMissing when it's missing (if it's there, carry on) + # * continue regardless + if request.META.get('HTTP_IF_MATCH') is None: + raise exceptions.IfMatchMissing + + def cache_precondition_check(self, obj, request): + header = request.META.get('HTTP_IF_MATCH') + for cache_lookup in self.get_cache_lookups(): + if cache_lookup.get_etag(obj) != header: + raise exceptions.PreconditionFailed + # Dispatch methods def initialize_request(self, request, *args, **kargs): @@ -318,6 +352,7 @@ class APIView(View): self.perform_authentication(request) self.check_permissions(request) self.check_throttles(request) + self.check_preemptive_cache(request) # Perform content negotiation and store the accepted info on the request neg = self.perform_content_negotiation(request) @@ -398,10 +433,9 @@ class APIView(View): else: handler = self.http_method_not_allowed - if getattr(self, 'use_etags', False) and request.method.lower() in ('put', 'delete'): - self.etag_header = request.META.get('HTTP_IF_MATCH') - if self.etag_header is None: - raise exceptions.IfMatchMissing + if request.method.lower() in ('put', 'delete'): + # FIXME this method name isn't obvious + self.check_update_validity(request) response = handler(request, *args, **kwargs) @@ -418,5 +452,3 @@ class APIView(View): a less useful default implementation. """ return Response(self.metadata(request), status=status.HTTP_200_OK) - def get_etag(self, obj): - return getattr(obj, self.etag_var)