rename DjangoRestFrameworkApi to Api; update Api to be more extendible; add some more tests

This commit is contained in:
Craig Blaszczyk 2011-12-22 11:52:23 +00:00
parent 3eb6fe5ad6
commit c213ae21a6
4 changed files with 192 additions and 25 deletions

View File

@ -2,14 +2,13 @@ __version__ = '0.2.3'
VERSION = __version__ # synonym
from djangorestframework.builtins import DjangoRestFrameworkApi
from django.utils.importlib import import_module
from djangorestframework.builtins import Api
import imp
__all__ = ('autodiscover','api', '__version__', 'VERSION')
api = DjangoRestFrameworkApi()
api = Api()
def autodiscover():
"""
@ -22,7 +21,7 @@ def autodiscover():
from django.utils.importlib import import_module
for app in settings.INSTALLED_APPS:
# Attempt to import the app's gargoyle module.
# Attempt to import the app's api module.
before_import_registry = copy.copy(api._registry)
try:
import_module('%s.api' % app)
@ -30,4 +29,5 @@ def autodiscover():
# Reset the model registry to the state before the last import as
# this import will have to reoccur on the next request and this
# could raise NotRegistered and AlreadyRegistered exceptions
# (see https://code.djangoproject.com/ticket/8245)
api._registry = before_import_registry

View File

@ -3,7 +3,7 @@ from collections import defaultdict
class ApiEntry(object):
"""
Hold information about a Resource in the api
Hold a list of urlpatterns for a given Resource in the API
"""
def __init__(self, resource, view, name, namespace=None):
@ -28,6 +28,12 @@ class ApiEntry(object):
)
)
elif issubclass(self.view, InstanceMixin):
# This regex pattern is intentionally designed to match primary
# keys which are integers, letters or both.
# An improvement would be to infer the right primary key regex from
# the model in the resource, to prevent matching non-numeric
# primary keys in the URL when the model can only have numeric
# primary keys
urlpatterns = patterns('',
url(r'^%s/(?P<pk>[0-9a-zA-Z]+)/$' % (namespaced_name),
self.view.as_view(resource=self.resource),
@ -40,13 +46,15 @@ class ApiEntry(object):
def urls(self):
return self.get_urls(), 'api', self.namespace
class DjangoRestFrameworkApi(object):
class Api(object):
app_name = 'api'
namespace = 'api'
api_entry_class = ApiEntry
def __init__(self, *args, **kwargs):
def __init__(self, api_entry_class=None):
self._registry = defaultdict(lambda: defaultdict(list))
super(DjangoRestFrameworkApi, self).__init__(*args, **kwargs)
if api_entry_class is not None:
self.api_entry_class = api_entry_class
def register(self, view, resource, namespace=None, name=None):
"""
@ -63,13 +71,18 @@ class DjangoRestFrameworkApi(object):
resource.api_name = name
api_entry = ApiEntry(resource, view, name, namespace)
api_entry = self.get_api_entry(
resource=resource, view=view, name=name, namespace=namespace
)
self._registry[namespace][name].append(api_entry)
@property
def urls(self):
return self.get_urls(), self.app_name, self.namespace
def get_api_entry(self, resource, view, name, namespace):
return self.api_entry_class(resource, view, name, namespace)
def get_urls(self):
"""
Return all of the urls for this API

View File

@ -1,38 +1,50 @@
from django.test.testcases import TestCase
from djangorestframework.builtins import DjangoRestFrameworkApi
from djangorestframework.builtins import Api, ApiEntry
from djangorestframework.resources import Resource, ModelResource
from djangorestframework.tests.models import Company, Employee
from django.conf.urls.defaults import patterns, url, include
from djangorestframework.views import ListOrCreateModelView, InstanceModelView,\
ListModelView
from django.core.urlresolvers import reverse
from django.core.urlresolvers import reverse, NoReverseMatch
import random
import string
from mock import Mock
__all__ = ('ApiTestCase',)
__all__ = ('ApiTestCase', 'ApiEntryTestCase')
class CompanyResource(ModelResource):
model = Company
class EmployeeResource(ModelResource):
model = Employee
class UrlConfModule(object):
def __init__(self, api):
self.api = api
class DummyUrlConfModule(object):
def __init__(self, object_with_urls):
self._object_with_urls = object_with_urls
def _get_urlpatterns(self):
return patterns('',
url(r'^', include(self.api.urls)),
)
urlpatterns = property(_get_urlpatterns)
@property
def urlpatterns(self):
urlpatterns = patterns('',
url(r'^', include(self._object_with_urls.urls)),
)
return urlpatterns
class CustomApiEntry(ApiEntry):
def __init__(self, *args, **kwargs):
super(CustomApiEntry, self).__init__(*args, **kwargs)
self.name = 'custom'
class ApiTestCase(TestCase):
def setUp(self):
self.api = DjangoRestFrameworkApi()
self.urlconfmodule = UrlConfModule(self.api)
self.api = Api()
self.urlconfmodule = DummyUrlConfModule(self.api)
def test_list_view(self):
# Check that the URL gets registered
@ -82,4 +94,145 @@ class ApiTestCase(TestCase):
reverse(
'api:abcdef:company_change', urlconf=self.urlconfmodule,
kwargs={'pk':company.id},
)
)
def test_custom_api_entry_class_1(self):
"""
Ensure that an Api object which has a custom `api_entry_class` passed
to the constructor
"""
self.api = Api(api_entry_class=CustomApiEntry)
self.urlconfmodule = DummyUrlConfModule(self.api)
# Check that the URL gets registered
self.api.register(ListModelView, CompanyResource)
reverse('api:custom', urlconf=self.urlconfmodule)
self.assertRaises(
NoReverseMatch, reverse, 'api:company', urlconf=self.urlconfmodule
)
def test_custom_api_entry_class_2(self):
"""
Ensure that an Api object which has a custom `api_entry_class` assigned
to it uses it
"""
self.api = Api()
self.api.api_entry_class = CustomApiEntry
self.urlconfmodule = DummyUrlConfModule(self.api)
# Check that the URL gets registered
self.api.register(ListModelView, CompanyResource)
reverse('api:custom', urlconf=self.urlconfmodule)
self.assertRaises(
NoReverseMatch, reverse, 'api:company', urlconf=self.urlconfmodule
)
class ApiEntryTestCase(TestCase):
"""
Test the ApiEntry class
"""
def test_with_different_name(self):
"""
Ensure that the passed in name is used in the returned URL
"""
name = ''.join(random.choice(string.letters) for i in xrange(10))
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name=name
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('%s/' % (name)) is not None)
def test_list_model_view(self):
"""
Ensure that using a ListModelView returns only a url all objects
"""
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name='company'
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('company/') is not None)
self.assert_(urls[0].resolve('company/10/') is None)
self.assert_(urls[0].resolve('company/dasdsad/') is None)
def test_reverse_by_name_list_model_view(self):
"""
Ensure the created ListModelView URL patterns can be reversed by name
"""
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name='company'
)
# Setup the dummy urlconf module
urlconfmodule = DummyUrlConfModule(api_entry)
# Check that the URL gets registered with a name
reverse('company', urlconf=urlconfmodule)
def test_reverse_by_name_isntance_model_view(self):
"""
Ensure the created ListModelView URL patterns can be reversed by name
"""
api_entry = ApiEntry(
resource=CompanyResource, view=InstanceModelView, name='company'
)
# Setup the dummy urlconf module
urlconfmodule = DummyUrlConfModule(api_entry)
# Check that the URL gets registered with a name
reverse('company_change', kwargs={'pk': '10'}, urlconf=urlconfmodule)
reverse('company_change', kwargs={'pk': 'aaaaa'}, urlconf=urlconfmodule)
def test_instance_model_view(self):
"""
Ensure that using an InstanceModelView returns a url which requires a
primary key
"""
api_entry = ApiEntry(
resource=CompanyResource, view=InstanceModelView, name='company'
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('company/') is None)
self.assert_(urls[0].resolve('company/10/') is not None)
self.assert_(urls[0].resolve('company/abcde/') is not None)
def test_namespaced_names(self):
"""
Ensure that when a namespace gets passed into the ApiEntry, it is
reflected in the returned URL
"""
namespace = ''.join(random.choice(string.letters) for i in xrange(10))
# test list model view
api_entry = ApiEntry(
resource=CompanyResource, view=ListModelView, name='company',
namespace=namespace
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('%s/company/' % (namespace)) is not None)
self.assert_(urls[0].resolve('%s/company/10/' % (namespace)) is None)
self.assert_(
urls[0].resolve('%s/company/dasdsad/' % (namespace)) is None
)
# Test instance model view
api_entry = ApiEntry(
resource=CompanyResource, view=InstanceModelView, name='company',
namespace=namespace
)
urls = api_entry.get_urls()
self.assertEqual(len(urls), 1)
self.assert_(urls[0].resolve('%s/company/' % (namespace)) is None)
self.assert_(
urls[0].resolve('%s/company/10/' % (namespace)) is not None
)
self.assert_(
urls[0].resolve('%s/company/abcde/' % (namespace)) is not None
)

View File

@ -4,3 +4,4 @@
Django==1.2.4
wsgiref==0.1.2
coverage==3.4
mock