Improved Django integration with relations

This commit is contained in:
Syrus Akbary 2015-09-28 23:29:10 -07:00
parent d0285278ac
commit ac940b9309
11 changed files with 88 additions and 30 deletions

View File

@ -46,7 +46,11 @@ def _(field, cls):
@convert_django_field.register(models.ManyToOneRel) @convert_django_field.register(models.ManyToOneRel)
def _(field, cls): def _(field, cls):
return ListField(DjangoModelField(field.related_model)) schema = cls._meta.schema
model_field = DjangoModelField(field.related_model)
if issubclass(cls, schema.relay.Node):
return schema.relay.ConnectionField(model_field)
return ListField(model_field)
@convert_django_field.register(models.ForeignKey) @convert_django_field.register(models.ForeignKey)

View File

@ -10,7 +10,7 @@ def get_type_for_model(schema, model):
for _type in types: for _type in types:
type_model = getattr(_type._meta, 'model', None) type_model = getattr(_type._meta, 'model', None)
if model == type_model: if model == type_model:
return _type._meta.type return _type
class DjangoModelField(Field): class DjangoModelField(Field):
@ -20,4 +20,13 @@ class DjangoModelField(Field):
@cached_property @cached_property
def type(self): def type(self):
return get_type_for_model(self.schema, self.model) _type = self.get_object_type()
return _type and _type._meta.type
def get_object_type(self):
_type = get_type_for_model(self.schema, self.model)
if not _type and self.object_type._meta.only_fields:
# We will only raise the exception if the related field is specified in only_fields
raise Exception("Field %s (%s) model not mapped in current schema" % (self, self.model._meta.object_name))
return _type

View File

@ -5,6 +5,7 @@ from graphene.core.options import Options
VALID_ATTRS = ('model', 'only_fields') VALID_ATTRS = ('model', 'only_fields')
class DjangoOptions(Options): class DjangoOptions(Options):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.model = None self.model = None

View File

@ -7,6 +7,7 @@ from graphene.contrib.django.converter import convert_django_field
from graphene.relay import Node from graphene.relay import Node
def get_reverse_fields(model): def get_reverse_fields(model):
for name, attr in model.__dict__.items(): for name, attr in model.__dict__.items():
related = getattr(attr, 'related', None) related = getattr(attr, 'related', None)
@ -16,12 +17,11 @@ def get_reverse_fields(model):
class DjangoObjectTypeMeta(ObjectTypeMeta): class DjangoObjectTypeMeta(ObjectTypeMeta):
options_cls = DjangoOptions options_cls = DjangoOptions
def add_extra_fields(cls): def add_extra_fields(cls):
if not cls._meta.model: if not cls._meta.model:
return return
only_fields = cls._meta.only_fields only_fields = cls._meta.only_fields
# print cls._meta.model._meta._get_fields(forward=False, reverse=True, include_hidden=True)
reverse_fields = tuple(get_reverse_fields(cls._meta.model)) reverse_fields = tuple(get_reverse_fields(cls._meta.model))
for field in cls._meta.model._meta.fields + reverse_fields: for field in cls._meta.model._meta.fields + reverse_fields:
if only_fields and field.name not in only_fields: if only_fields and field.name not in only_fields:

View File

@ -48,6 +48,8 @@ class Field(object):
def get_object_type(self): def get_object_type(self):
field_type = self.field_type field_type = self.field_type
_is_class = inspect.isclass(field_type) _is_class = inspect.isclass(field_type)
if isinstance(field_type, Field):
return field_type.get_object_type()
if _is_class and issubclass(field_type, ObjectType): if _is_class and issubclass(field_type, ObjectType):
return field_type return field_type
elif isinstance(field_type, basestring): elif isinstance(field_type, basestring):

View File

@ -11,19 +11,11 @@ class Options(object):
self.local_fields = [] self.local_fields = []
self.interface = False self.interface = False
self.proxy = False self.proxy = False
self.schema = schema self.schema = schema or get_global_schema()
self.interfaces = [] self.interfaces = []
self.parents = [] self.parents = []
self.valid_attrs = DEFAULT_NAMES self.valid_attrs = DEFAULT_NAMES
# @property
# def schema(self):
# return self._schema or get_global_schema()
# @schema.setter
# def schema(self, schema):
# self._schema = schema
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
cls._meta = self cls._meta = self
self.parent = cls self.parent = cls

View File

@ -13,8 +13,10 @@ from graphene.relay.relay import (
) )
from graphene.env import get_global_schema from graphene.env import get_global_schema
from graphene.relay.utils import setup
schema = get_global_schema() schema = get_global_schema()
setup(schema)
relay = schema.relay relay = schema.relay
Node, NodeField = relay.Node, relay.NodeField Node, NodeField = relay.Node, relay.NodeField

View File

@ -7,7 +7,7 @@ from graphql_relay.connection.connection import (
from graphene import signals from graphene import signals
from graphene.core.fields import NativeField from graphene.core.fields import NativeField
from graphene.relay.utils import get_relay from graphene.relay.utils import get_relay, setup
from graphene.relay.relay import Relay from graphene.relay.relay import Relay
@ -28,4 +28,4 @@ def object_type_created(object_type):
@signals.init_schema.connect @signals.init_schema.connect
def schema_created(schema): def schema_created(schema):
setattr(schema, 'relay', Relay(schema)) setup(schema)

View File

@ -1,3 +1,9 @@
def get_relay(schema): def get_relay(schema):
return getattr(schema, 'relay', None) return getattr(schema, 'relay', None)
def setup(schema):
from graphene.relay.relay import Relay
if not hasattr(schema, 'relay'):
return setattr(schema, 'relay', Relay(schema))
return schema

View File

@ -25,12 +25,31 @@ def test_should_raise_if_model_is_invalid():
assert 'not a Django model' in str(excinfo.value) assert 'not a Django model' in str(excinfo.value)
def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo:
class ReporterTypeError(DjangoObjectType):
class Meta:
model = Reporter
only_fields = ('articles', )
schema = graphene.Schema(query=ReporterTypeError)
query = '''
query ReporterQuery {
articles
}
'''
result = schema.execute(query)
assert not result.errors
assert 'articles (Article) model not mapped in current schema' in str(excinfo.value)
def test_should_map_fields(): def test_should_map_fields():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
class Query(graphene.ObjectType): class Query2(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
@ -52,7 +71,7 @@ def test_should_map_fields():
'email': '' 'email': ''
} }
} }
Schema = graphene.Schema(query=Query) Schema = graphene.Schema(query=Query2)
result = Schema.execute(query) result = Schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -74,15 +93,18 @@ def test_should_node():
def get_node(cls, id): def get_node(cls, id):
return ReporterNodeType(Reporter(id=2, first_name='Cookie Monster')) return ReporterNodeType(Reporter(id=2, first_name='Cookie Monster'))
def resolve_articles(self, *args, **kwargs):
return [ArticleNodeType(Article(headline='Hi!'))]
class ArticleNodeType(DjangoNode): class ArticleNodeType(DjangoNode):
class Meta: class Meta:
model = Article model = Article
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id):
return ArticleNodeType(None) return ArticleNodeType(Article(id=1, headline='Article node'))
class Query(graphene.ObjectType): class Query1(graphene.ObjectType):
node = relay.NodeField() node = relay.NodeField()
reporter = graphene.Field(ReporterNodeType) reporter = graphene.Field(ReporterNodeType)
@ -94,14 +116,24 @@ def test_should_node():
reporter { reporter {
id, id,
first_name, first_name,
articles {
edges {
node {
headline
}
}
}
last_name, last_name,
email email
} }
aCustomNode: node(id:"UmVwb3J0ZXJOb2RlVHlwZToy") { my_article: node(id:"QXJ0aWNsZU5vZGVUeXBlOjE=") {
id id
... on ReporterNodeType { ... on ReporterNodeType {
first_name first_name
} }
... on ArticleNodeType {
headline
}
} }
} }
''' '''
@ -110,14 +142,21 @@ def test_should_node():
'id': 'UmVwb3J0ZXJOb2RlVHlwZTox', 'id': 'UmVwb3J0ZXJOb2RlVHlwZTox',
'first_name': 'ABA', 'first_name': 'ABA',
'last_name': 'X', 'last_name': 'X',
'email': '' 'email': '',
'articles': {
'edges': [{
'node': {
'headline': 'Hi!'
}
}]
}, },
'aCustomNode': { },
'id': 'UmVwb3J0ZXJOb2RlVHlwZToy', 'my_article': {
'first_name': 'Cookie Monster' 'id': 'QXJ0aWNsZU5vZGVUeXBlOjE=',
'headline': 'Article node'
} }
} }
Schema = graphene.Schema(query=Query) Schema = graphene.Schema(query=Query1)
result = Schema.execute(query) result = Schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected

View File

@ -1,5 +1,5 @@
import graphene import graphene
from graphene import resolve_only_args, relay from graphene import resolve_only_args
from .data import ( from .data import (
getFaction, getFaction,
@ -8,6 +8,9 @@ from .data import (
getEmpire, getEmpire,
) )
schema = graphene.Schema(name='Starwars Relay Schema')
relay = schema.relay
class Ship(relay.Node): class Ship(relay.Node):
'''A ship in the Star Wars saga''' '''A ship in the Star Wars saga'''
name = graphene.StringField(description='The name of the ship.') name = graphene.StringField(description='The name of the ship.')
@ -31,7 +34,7 @@ class Faction(relay.Node):
return Faction(getFaction(id)) return Faction(getFaction(id))
class Query(graphene.ObjectType): class Query(schema.ObjectType):
rebels = graphene.Field(Faction) rebels = graphene.Field(Faction)
empire = graphene.Field(Faction) empire = graphene.Field(Faction)
node = relay.NodeField() node = relay.NodeField()
@ -45,4 +48,4 @@ class Query(graphene.ObjectType):
return Faction(getEmpire()) return Faction(getEmpire())
schema = graphene.Schema(query=Query, name='Starwars Relay Schema') schema.query = Query