rename DjangoRestFrameworkApi to Api; update Api to be more extendible; add some more tests

This commit is contained in:
Craig Blaszczyk 2011-12-22 11:52:23 +00:00
parent 3eb6fe5ad6
commit c213ae21a6
4 changed files with 192 additions and 25 deletions

View File

@ -2,14 +2,13 @@ __version__ = '0.2.3'
VERSION = __version__ # synonym VERSION = __version__ # synonym
from djangorestframework.builtins import DjangoRestFrameworkApi from djangorestframework.builtins import Api
from django.utils.importlib import import_module
import imp import imp
__all__ = ('autodiscover','api', '__version__', 'VERSION') __all__ = ('autodiscover','api', '__version__', 'VERSION')
api = DjangoRestFrameworkApi() api = Api()
def autodiscover(): def autodiscover():
""" """
@ -22,7 +21,7 @@ def autodiscover():
from django.utils.importlib import import_module from django.utils.importlib import import_module
for app in settings.INSTALLED_APPS: for app in settings.INSTALLED_APPS:
# Attempt to import the app's gargoyle module. # Attempt to import the app's api module.
before_import_registry = copy.copy(api._registry) before_import_registry = copy.copy(api._registry)
try: try:
import_module('%s.api' % app) import_module('%s.api' % app)
@ -30,4 +29,5 @@ def autodiscover():
# Reset the model registry to the state before the last import as # Reset the model registry to the state before the last import as
# this import will have to reoccur on the next request and this # this import will have to reoccur on the next request and this
# could raise NotRegistered and AlreadyRegistered exceptions # could raise NotRegistered and AlreadyRegistered exceptions
# (see https://code.djangoproject.com/ticket/8245)
api._registry = before_import_registry api._registry = before_import_registry

View File

@ -3,7 +3,7 @@ from collections import defaultdict
class ApiEntry(object): class ApiEntry(object):
""" """
Hold information about a Resource in the api Hold a list of urlpatterns for a given Resource in the API
""" """
def __init__(self, resource, view, name, namespace=None): def __init__(self, resource, view, name, namespace=None):
@ -28,6 +28,12 @@ class ApiEntry(object):
) )
) )
elif issubclass(self.view, InstanceMixin): elif issubclass(self.view, InstanceMixin):
# This regex pattern is intentionally designed to match primary
# keys which are integers, letters or both.
# An improvement would be to infer the right primary key regex from
# the model in the resource, to prevent matching non-numeric
# primary keys in the URL when the model can only have numeric
# primary keys
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^%s/(?P<pk>[0-9a-zA-Z]+)/$' % (namespaced_name), url(r'^%s/(?P<pk>[0-9a-zA-Z]+)/$' % (namespaced_name),
self.view.as_view(resource=self.resource), self.view.as_view(resource=self.resource),
@ -40,13 +46,15 @@ class ApiEntry(object):
def urls(self): def urls(self):
return self.get_urls(), 'api', self.namespace return self.get_urls(), 'api', self.namespace
class DjangoRestFrameworkApi(object): class Api(object):
app_name = 'api' app_name = 'api'
namespace = 'api' namespace = 'api'
api_entry_class = ApiEntry
def __init__(self, *args, **kwargs): def __init__(self, api_entry_class=None):
self._registry = defaultdict(lambda: defaultdict(list)) self._registry = defaultdict(lambda: defaultdict(list))
super(DjangoRestFrameworkApi, self).__init__(*args, **kwargs) if api_entry_class is not None:
self.api_entry_class = api_entry_class
def register(self, view, resource, namespace=None, name=None): def register(self, view, resource, namespace=None, name=None):
""" """
@ -63,13 +71,18 @@ class DjangoRestFrameworkApi(object):
resource.api_name = name resource.api_name = name
api_entry = ApiEntry(resource, view, name, namespace) api_entry = self.get_api_entry(
resource=resource, view=view, name=name, namespace=namespace
)
self._registry[namespace][name].append(api_entry) self._registry[namespace][name].append(api_entry)
@property @property
def urls(self): def urls(self):
return self.get_urls(), self.app_name, self.namespace return self.get_urls(), self.app_name, self.namespace
def get_api_entry(self, resource, view, name, namespace):
return self.api_entry_class(resource, view, name, namespace)
def get_urls(self): def get_urls(self):
""" """
Return all of the urls for this API Return all of the urls for this API

View File

@ -1,38 +1,50 @@
from django.test.testcases import TestCase from django.test.testcases import TestCase
from djangorestframework.builtins import DjangoRestFrameworkApi from djangorestframework.builtins import Api, ApiEntry
from djangorestframework.resources import Resource, ModelResource from djangorestframework.resources import Resource, ModelResource
from djangorestframework.tests.models import Company, Employee from djangorestframework.tests.models import Company, Employee
from django.conf.urls.defaults import patterns, url, include from django.conf.urls.defaults import patterns, url, include
from djangorestframework.views import ListOrCreateModelView, InstanceModelView,\ from djangorestframework.views import ListOrCreateModelView, InstanceModelView,\
ListModelView ListModelView
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse, NoReverseMatch
import random
import string
from mock import Mock
__all__ = ('ApiTestCase',) __all__ = ('ApiTestCase', 'ApiEntryTestCase')
class CompanyResource(ModelResource): class CompanyResource(ModelResource):
model = Company model = Company
class EmployeeResource(ModelResource): class EmployeeResource(ModelResource):
model = Employee model = Employee
class UrlConfModule(object):
def __init__(self, api): class DummyUrlConfModule(object):
self.api = api
def __init__(self, object_with_urls):
self._object_with_urls = object_with_urls
def _get_urlpatterns(self): @property
return patterns('', def urlpatterns(self):
url(r'^', include(self.api.urls)), urlpatterns = patterns('',
) url(r'^', include(self._object_with_urls.urls)),
)
urlpatterns = property(_get_urlpatterns) return urlpatterns
class CustomApiEntry(ApiEntry):
def __init__(self, *args, **kwargs):
super(CustomApiEntry, self).__init__(*args, **kwargs)
self.name = 'custom'
class ApiTestCase(TestCase): class ApiTestCase(TestCase):
def setUp(self): def setUp(self):
self.api = DjangoRestFrameworkApi() self.api = Api()
self.urlconfmodule = UrlConfModule(self.api) self.urlconfmodule = DummyUrlConfModule(self.api)
def test_list_view(self): def test_list_view(self):
# Check that the URL gets registered # Check that the URL gets registered
@ -82,4 +94,145 @@ class ApiTestCase(TestCase):
reverse( reverse(
'api:abcdef:company_change', urlconf=self.urlconfmodule, 'api:abcdef:company_change', urlconf=self.urlconfmodule,
kwargs={'pk':company.id}, kwargs={'pk':company.id},
) )
def test_custom_api_entry_class_1(self):
"""
Ensure that an Api object which has a custom `api_entry_class` passed
to the constructor
"""
self.api = Api(api_entry_class=CustomApiEntry)
self.urlconfmodule = DummyUrlConfModule(self.api)
# Check that the URL gets registered
self.api.register(ListModelView, CompanyResource)
reverse('api:custom', urlconf=self.urlconfmodule)
self.assertRaises(
NoReverseMatch, reverse, 'api:company', urlconf=self.urlconfmodule
)
def test_custom_api_entry_class_2(self):
"""
Ensure that an Api object which has a custom `api_entry_class` assigned
to it uses it
"""
self.api = Api()
self.api.api_entry_class = CustomApiEntry
self.urlconfmodule = DummyUrlConfModule(self.api)
# Check that the URL gets registered
self.api.register(ListModelView, CompanyResource)
reverse('api:custom', urlconf=self.urlconfmodule)
self.assertRaises(
NoReverseMatch, reverse, 'api:company', urlconf=self.urlconfmodule
)
class ApiEntryTestCase(TestCase):
"""
Test the ApiEntry class
"""
def test_with_different_name(self):
"""
Ensure that the passed in name is used in the returned URL
"""
name = ''.join(random.choice(string.letters) for i in xrange(10))
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name=name
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('%s/' % (name)) is not None)
def test_list_model_view(self):
"""
Ensure that using a ListModelView returns only a url all objects
"""
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name='company'
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('company/') is not None)
self.assert_(urls[0].resolve('company/10/') is None)
self.assert_(urls[0].resolve('company/dasdsad/') is None)
def test_reverse_by_name_list_model_view(self):
"""
Ensure the created ListModelView URL patterns can be reversed by name
"""
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name='company'
)
# Setup the dummy urlconf module
urlconfmodule = DummyUrlConfModule(api_entry)
# Check that the URL gets registered with a name
reverse('company', urlconf=urlconfmodule)
def test_reverse_by_name_isntance_model_view(self):
"""
Ensure the created ListModelView URL patterns can be reversed by name
"""
api_entry = ApiEntry(
resource=CompanyResource, view=InstanceModelView, name='company'
)
# Setup the dummy urlconf module
urlconfmodule = DummyUrlConfModule(api_entry)
# Check that the URL gets registered with a name
reverse('company_change', kwargs={'pk': '10'}, urlconf=urlconfmodule)
reverse('company_change', kwargs={'pk': 'aaaaa'}, urlconf=urlconfmodule)
def test_instance_model_view(self):
"""
Ensure that using an InstanceModelView returns a url which requires a
primary key
"""
api_entry = ApiEntry(
resource=CompanyResource, view=InstanceModelView, name='company'
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('company/') is None)
self.assert_(urls[0].resolve('company/10/') is not None)
self.assert_(urls[0].resolve('company/abcde/') is not None)
def test_namespaced_names(self):
"""
Ensure that when a namespace gets passed into the ApiEntry, it is
reflected in the returned URL
"""
namespace = ''.join(random.choice(string.letters) for i in xrange(10))
# test list model view
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name='company',
namespace=namespace
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('%s/company/' % (namespace)) is not None)
self.assert_(urls[0].resolve('%s/company/10/' % (namespace)) is None)
self.assert_(
urls[0].resolve('%s/company/dasdsad/' % (namespace)) is None
)
# Test instance model view
api_entry = ApiEntry(
resource=CompanyResource, view=InstanceModelView, name='company',
namespace=namespace
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('%s/company/' % (namespace)) is None)
self.assert_(
urls[0].resolve('%s/company/10/' % (namespace)) is not None
)
self.assert_(
urls[0].resolve('%s/company/abcde/' % (namespace)) is not None
)

View File

@ -4,3 +4,4 @@
Django==1.2.4 Django==1.2.4
wsgiref==0.1.2 wsgiref==0.1.2
coverage==3.4 coverage==3.4
mock