Improved registry

This commit is contained in:
Syrus Akbary 2017-02-18 14:02:31 -08:00
parent 0ec8d2c828
commit b06e33ddd7
9 changed files with 108 additions and 14 deletions

View File

@ -6,10 +6,15 @@ from graphene_django import DjangoConnectionField, DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
from ...tests.models import Reporter from ...tests.models import Reporter
from ...registry import reset_global_registry
from ..middleware import DjangoDebugMiddleware from ..middleware import DjangoDebugMiddleware
from ..types import DjangoDebug from ..types import DjangoDebug
def setup_function(function):
reset_global_registry()
class context(object): class context(object):
pass pass

View File

@ -9,6 +9,7 @@ from graphene_django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField) GlobalIDMultipleChoiceField)
from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.registry import Registry, reset_global_registry
pytestmark = [] pytestmark = []
@ -24,6 +25,8 @@ pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
reset_global_registry()
class ArticleNode(DjangoObjectType): class ArticleNode(DjangoObjectType):
class Meta: class Meta:
@ -47,6 +50,10 @@ if DJANGO_FILTER_INSTALLED:
# schema = Schema() # schema = Schema()
@pytest.fixture
def _registry():
return Registry()
def get_args(field): def get_args(field):
return field.args return field.args
@ -134,26 +141,28 @@ def test_filter_shortcut_filterset_extra_meta():
assert 'headline' not in field.filterset_class.get_fields() assert 'headline' not in field.filterset_class.get_fields()
def test_filter_filterset_information_on_meta(): def test_filter_filterset_information_on_meta(_registry):
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
filter_fields = ['first_name', 'articles'] filter_fields = ['first_name', 'articles']
registry = _registry
field = DjangoFilterConnectionField(ReporterFilterNode) field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'first_name', 'articles') assert_arguments(field, 'first_name', 'articles')
assert_not_orderable(field) assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related(): def test_filter_filterset_information_on_meta_related(_registry):
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
filter_fields = ['first_name', 'articles'] filter_fields = ['first_name', 'articles']
registry = _registry
class ArticleFilterNode(DjangoObjectType): class ArticleFilterNode(DjangoObjectType):
@ -161,6 +170,7 @@ def test_filter_filterset_information_on_meta_related():
model = Article model = Article
interfaces = (Node, ) interfaces = (Node, )
filter_fields = ['headline', 'reporter'] filter_fields = ['headline', 'reporter']
registry = _registry
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -174,13 +184,14 @@ def test_filter_filterset_information_on_meta_related():
assert_not_orderable(articles_field) assert_not_orderable(articles_field)
def test_filter_filterset_related_results(): def test_filter_filterset_related_results(_registry):
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
filter_fields = ['first_name', 'articles'] filter_fields = ['first_name', 'articles']
registry = _registry
class ArticleFilterNode(DjangoObjectType): class ArticleFilterNode(DjangoObjectType):
@ -188,6 +199,7 @@ def test_filter_filterset_related_results():
interfaces = (Node, ) interfaces = (Node, )
model = Article model = Article
filter_fields = ['headline', 'reporter'] filter_fields = ['headline', 'reporter']
registry = _registry
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -315,7 +327,7 @@ def test_global_id_multiple_field_explicit_reverse():
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_filter_filterset_related_results(): def test_filter_filterset_related_results(_registry):
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
@ -324,6 +336,7 @@ def test_filter_filterset_related_results():
filter_fields = { filter_fields = {
'first_name': ['icontains'] 'first_name': ['icontains']
} }
registry = _registry
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)

View File

@ -6,15 +6,16 @@ class Registry(object):
def register(self, cls): def register(self, cls):
from .types import DjangoObjectType from .types import DjangoObjectType
model = cls._meta.model
assert self._registry.get(model, cls) == cls, (
'Django Model "{}.{}" already associated with {}. '
'You can use a different registry for {} or skip the global Registry with "{}.Meta.skip_global_registry = True".'
).format(model._meta.app_label, model._meta.object_name, repr(self.get_type_for_model(cls._meta.model)), repr(cls), cls)
assert issubclass( assert issubclass(
cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format( cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
cls.__name__) cls.__name__)
assert cls._meta.registry == self, 'Registry for a Model have to match.' assert cls._meta.registry == self, 'Registry for a Model have to match.'
# assert self.get_type_for_model(cls._meta.model) == cls, ( self._registry[cls._meta.model] = cls
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# )
if not getattr(cls._meta, 'skip_registry', False):
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model): def get_type_for_model(self, model):
return self._registry.get(model) return self._registry.get(model)

View File

@ -11,12 +11,13 @@ from graphene.types.json import JSONString
from ..compat import (ArrayField, HStoreField, JSONField, MissingType, from ..compat import (ArrayField, HStoreField, JSONField, MissingType,
RangeField, UUIDField, DurationField) RangeField, UUIDField, DurationField)
from ..converter import convert_django_field, convert_django_field_with_choices from ..converter import convert_django_field, convert_django_field_with_choices
from ..registry import Registry from ..registry import Registry, reset_global_registry
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import Article, Film, FilmDetails, Reporter from .models import Article, Film, FilmDetails, Reporter
# from graphene.core.types.custom_scalars import DateTime, Time, JSONString def setup_function(function):
reset_global_registry()
def assert_conversion(django_field, graphene_field, *args, **kwargs): def assert_conversion(django_field, graphene_field, *args, **kwargs):

View File

@ -12,11 +12,16 @@ from ..utils import DJANGO_FILTER_INSTALLED
from ..compat import MissingType, RangeField from ..compat import MissingType, RangeField
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from ..registry import reset_global_registry
from .models import Article, Reporter from .models import Article, Reporter
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
def setup_function(function):
reset_global_registry()
def test_should_query_only_fields(): def test_should_query_only_fields():
with raises(Exception): with raises(Exception):
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):

View File

@ -0,0 +1,56 @@
from pytest import raises
from ..registry import Registry, get_global_registry, reset_global_registry
from ..types import DjangoObjectType
from .models import Reporter as ReporterModel
def setup_function(function):
reset_global_registry()
def test_registry_basic():
global_registry = get_global_registry()
class Reporter(DjangoObjectType):
'''Reporter description'''
class Meta:
model = ReporterModel
assert Reporter._meta.registry == global_registry
assert global_registry.get_type_for_model(ReporterModel) == Reporter
def test_registry_multiple_types():
class Reporter(DjangoObjectType):
'''Reporter description'''
class Meta:
model = ReporterModel
with raises(Exception) as exc_info:
class Reporter2(DjangoObjectType):
'''Reporter2 description'''
class Meta:
model = ReporterModel
assert str(exc_info.value) == (
'Django Model "tests.Reporter" already associated with <class \'graphene_django.tests.test_registry.Reporter\'>. '
'You can use a different registry for <class \'graphene_django.tests.test_registry.Reporter2\'> '
'or skip the global Registry with "Reporter2.Meta.skip_global_registry = True".'
)
def test_registry_multiple_types_dont_collision_if_skip_global_registry():
class Reporter(DjangoObjectType):
'''Reporter description'''
class Meta:
model = ReporterModel
class Reporter2(DjangoObjectType):
'''Reporter2 description'''
class Meta:
model = ReporterModel
skip_global_registry = True
assert Reporter._meta.registry != Reporter2._meta.registry
assert Reporter2._meta.registry != get_global_registry()

View File

@ -1,10 +1,14 @@
from py.test import raises from py.test import raises
from ..registry import Registry from ..registry import Registry, reset_global_registry
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import Reporter from .models import Reporter
def setup_function(function):
reset_global_registry()
def test_should_raise_if_no_model(): def test_should_raise_if_no_model():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character1(DjangoObjectType): class Character1(DjangoObjectType):

View File

@ -3,11 +3,12 @@ from mock import patch
from graphene import Interface, ObjectType, Schema from graphene import Interface, ObjectType, Schema
from graphene.relay import Node from graphene.relay import Node
from ..registry import reset_global_registry from ..registry import Registry, reset_global_registry
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import Article as ArticleModel from .models import Article as ArticleModel
from .models import Reporter as ReporterModel from .models import Reporter as ReporterModel
reset_global_registry() reset_global_registry()

View File

@ -58,7 +58,7 @@ class DjangoObjectTypeMeta(ObjectTypeMeta):
only_fields=(), only_fields=(),
exclude_fields=(), exclude_fields=(),
interfaces=(), interfaces=(),
skip_registry=False, skip_global_registry=False,
registry=None registry=None
) )
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
@ -72,6 +72,14 @@ class DjangoObjectTypeMeta(ObjectTypeMeta):
attrs.pop('Meta', None), attrs.pop('Meta', None),
**defaults **defaults
) )
# If the DjangoObjectType wants to skip the registry
# we will automatically create one, so the model is isolated
# there.
if options.skip_global_registry:
assert not options.registry, (
"The attribute skip_global_registry requires have an empty registry in {}.Meta"
).format(name)
options.registry = Registry()
if not options.registry: if not options.registry:
options.registry = get_global_registry() options.registry = get_global_registry()
assert isinstance(options.registry, Registry), ( assert isinstance(options.registry, Registry), (