From ac940b93090e75d884390b35eacdfd729d860584 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 28 Sep 2015 23:29:10 -0700 Subject: [PATCH] Improved Django integration with relations --- graphene/contrib/django/converter.py | 6 ++- graphene/contrib/django/fields.py | 13 +++++- graphene/contrib/django/options.py | 1 + graphene/contrib/django/types.py | 4 +- graphene/core/fields.py | 2 + graphene/core/options.py | 10 +---- graphene/relay/__init__.py | 2 + graphene/relay/connections.py | 4 +- graphene/relay/utils.py | 8 +++- tests/contrib_django/test_schema.py | 59 +++++++++++++++++++++++----- tests/starwars_relay/schema.py | 9 +++-- 11 files changed, 88 insertions(+), 30 deletions(-) diff --git a/graphene/contrib/django/converter.py b/graphene/contrib/django/converter.py index 6611d356..663cd11c 100644 --- a/graphene/contrib/django/converter.py +++ b/graphene/contrib/django/converter.py @@ -46,7 +46,11 @@ def _(field, cls): @convert_django_field.register(models.ManyToOneRel) def _(field, cls): - return ListField(DjangoModelField(field.related_model)) + schema = cls._meta.schema + model_field = DjangoModelField(field.related_model) + if issubclass(cls, schema.relay.Node): + return schema.relay.ConnectionField(model_field) + return ListField(model_field) @convert_django_field.register(models.ForeignKey) diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 683f245a..992614c1 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -10,7 +10,7 @@ def get_type_for_model(schema, model): for _type in types: type_model = getattr(_type._meta, 'model', None) if model == type_model: - return _type._meta.type + return _type class DjangoModelField(Field): @@ -20,4 +20,13 @@ class DjangoModelField(Field): @cached_property def type(self): - return get_type_for_model(self.schema, self.model) + _type = self.get_object_type() + return _type and _type._meta.type + + def get_object_type(self): + _type = get_type_for_model(self.schema, self.model) + if not _type and self.object_type._meta.only_fields: + # We will only raise the exception if the related field is specified in only_fields + raise Exception("Field %s (%s) model not mapped in current schema" % (self, self.model._meta.object_name)) + + return _type diff --git a/graphene/contrib/django/options.py b/graphene/contrib/django/options.py index 381914f0..4560080a 100644 --- a/graphene/contrib/django/options.py +++ b/graphene/contrib/django/options.py @@ -5,6 +5,7 @@ from graphene.core.options import Options VALID_ATTRS = ('model', 'only_fields') + class DjangoOptions(Options): def __init__(self, *args, **kwargs): self.model = None diff --git a/graphene/contrib/django/types.py b/graphene/contrib/django/types.py index dbde3aa9..cd6831be 100644 --- a/graphene/contrib/django/types.py +++ b/graphene/contrib/django/types.py @@ -7,6 +7,7 @@ from graphene.contrib.django.converter import convert_django_field from graphene.relay import Node + def get_reverse_fields(model): for name, attr in model.__dict__.items(): related = getattr(attr, 'related', None) @@ -16,12 +17,11 @@ def get_reverse_fields(model): class DjangoObjectTypeMeta(ObjectTypeMeta): options_cls = DjangoOptions + def add_extra_fields(cls): if not cls._meta.model: return - only_fields = cls._meta.only_fields - # print cls._meta.model._meta._get_fields(forward=False, reverse=True, include_hidden=True) reverse_fields = tuple(get_reverse_fields(cls._meta.model)) for field in cls._meta.model._meta.fields + reverse_fields: if only_fields and field.name not in only_fields: diff --git a/graphene/core/fields.py b/graphene/core/fields.py index 254c2f02..bd27aaaf 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -48,6 +48,8 @@ class Field(object): def get_object_type(self): field_type = self.field_type _is_class = inspect.isclass(field_type) + if isinstance(field_type, Field): + return field_type.get_object_type() if _is_class and issubclass(field_type, ObjectType): return field_type elif isinstance(field_type, basestring): diff --git a/graphene/core/options.py b/graphene/core/options.py index 101c8a22..6799e517 100644 --- a/graphene/core/options.py +++ b/graphene/core/options.py @@ -11,18 +11,10 @@ class Options(object): self.local_fields = [] self.interface = False self.proxy = False - self.schema = schema + self.schema = schema or get_global_schema() self.interfaces = [] self.parents = [] self.valid_attrs = DEFAULT_NAMES - - # @property - # def schema(self): - # return self._schema or get_global_schema() - - # @schema.setter - # def schema(self, schema): - # self._schema = schema def contribute_to_class(self, cls, name): cls._meta = self diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index 4e353804..76020a69 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -13,8 +13,10 @@ from graphene.relay.relay import ( ) from graphene.env import get_global_schema +from graphene.relay.utils import setup schema = get_global_schema() +setup(schema) relay = schema.relay Node, NodeField = relay.Node, relay.NodeField diff --git a/graphene/relay/connections.py b/graphene/relay/connections.py index 83d73af5..db241bb8 100644 --- a/graphene/relay/connections.py +++ b/graphene/relay/connections.py @@ -7,7 +7,7 @@ from graphql_relay.connection.connection import ( from graphene import signals from graphene.core.fields import NativeField -from graphene.relay.utils import get_relay +from graphene.relay.utils import get_relay, setup from graphene.relay.relay import Relay @@ -28,4 +28,4 @@ def object_type_created(object_type): @signals.init_schema.connect def schema_created(schema): - setattr(schema, 'relay', Relay(schema)) + setup(schema) diff --git a/graphene/relay/utils.py b/graphene/relay/utils.py index cd23632d..2974fbab 100644 --- a/graphene/relay/utils.py +++ b/graphene/relay/utils.py @@ -1,3 +1,9 @@ - def get_relay(schema): return getattr(schema, 'relay', None) + + +def setup(schema): + from graphene.relay.relay import Relay + if not hasattr(schema, 'relay'): + return setattr(schema, 'relay', Relay(schema)) + return schema diff --git a/tests/contrib_django/test_schema.py b/tests/contrib_django/test_schema.py index 5dbf093d..34617ddf 100644 --- a/tests/contrib_django/test_schema.py +++ b/tests/contrib_django/test_schema.py @@ -25,12 +25,31 @@ def test_should_raise_if_model_is_invalid(): assert 'not a Django model' in str(excinfo.value) +def test_should_raise_if_model_is_invalid(): + with raises(Exception) as excinfo: + class ReporterTypeError(DjangoObjectType): + class Meta: + model = Reporter + only_fields = ('articles', ) + + schema = graphene.Schema(query=ReporterTypeError) + query = ''' + query ReporterQuery { + articles + } + ''' + result = schema.execute(query) + assert not result.errors + + assert 'articles (Article) model not mapped in current schema' in str(excinfo.value) + + def test_should_map_fields(): class ReporterType(DjangoObjectType): class Meta: model = Reporter - class Query(graphene.ObjectType): + class Query2(graphene.ObjectType): reporter = graphene.Field(ReporterType) def resolve_reporter(self, *args, **kwargs): @@ -52,7 +71,7 @@ def test_should_map_fields(): 'email': '' } } - Schema = graphene.Schema(query=Query) + Schema = graphene.Schema(query=Query2) result = Schema.execute(query) assert not result.errors assert result.data == expected @@ -74,15 +93,18 @@ def test_should_node(): def get_node(cls, id): return ReporterNodeType(Reporter(id=2, first_name='Cookie Monster')) + def resolve_articles(self, *args, **kwargs): + return [ArticleNodeType(Article(headline='Hi!'))] + class ArticleNodeType(DjangoNode): class Meta: model = Article @classmethod def get_node(cls, id): - return ArticleNodeType(None) + return ArticleNodeType(Article(id=1, headline='Article node')) - class Query(graphene.ObjectType): + class Query1(graphene.ObjectType): node = relay.NodeField() reporter = graphene.Field(ReporterNodeType) @@ -94,14 +116,24 @@ def test_should_node(): reporter { id, first_name, + articles { + edges { + node { + headline + } + } + } last_name, email } - aCustomNode: node(id:"UmVwb3J0ZXJOb2RlVHlwZToy") { + my_article: node(id:"QXJ0aWNsZU5vZGVUeXBlOjE=") { id ... on ReporterNodeType { first_name } + ... on ArticleNodeType { + headline + } } } ''' @@ -110,14 +142,21 @@ def test_should_node(): 'id': 'UmVwb3J0ZXJOb2RlVHlwZTox', 'first_name': 'ABA', 'last_name': 'X', - 'email': '' + 'email': '', + 'articles': { + 'edges': [{ + 'node': { + 'headline': 'Hi!' + } + }] + }, }, - 'aCustomNode': { - 'id': 'UmVwb3J0ZXJOb2RlVHlwZToy', - 'first_name': 'Cookie Monster' + 'my_article': { + 'id': 'QXJ0aWNsZU5vZGVUeXBlOjE=', + 'headline': 'Article node' } } - Schema = graphene.Schema(query=Query) + Schema = graphene.Schema(query=Query1) result = Schema.execute(query) assert not result.errors assert result.data == expected diff --git a/tests/starwars_relay/schema.py b/tests/starwars_relay/schema.py index 8731f712..4bcdb87d 100644 --- a/tests/starwars_relay/schema.py +++ b/tests/starwars_relay/schema.py @@ -1,5 +1,5 @@ import graphene -from graphene import resolve_only_args, relay +from graphene import resolve_only_args from .data import ( getFaction, @@ -8,6 +8,9 @@ from .data import ( getEmpire, ) +schema = graphene.Schema(name='Starwars Relay Schema') +relay = schema.relay + class Ship(relay.Node): '''A ship in the Star Wars saga''' name = graphene.StringField(description='The name of the ship.') @@ -31,7 +34,7 @@ class Faction(relay.Node): return Faction(getFaction(id)) -class Query(graphene.ObjectType): +class Query(schema.ObjectType): rebels = graphene.Field(Faction) empire = graphene.Field(Faction) node = relay.NodeField() @@ -45,4 +48,4 @@ class Query(graphene.ObjectType): return Faction(getEmpire()) -schema = graphene.Schema(query=Query, name='Starwars Relay Schema') +schema.query = Query