mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 09:57:41 +03:00 
			
		
		
		
	Improved Django integration with relations
This commit is contained in:
		
							parent
							
								
									d0285278ac
								
							
						
					
					
						commit
						ac940b9309
					
				| 
						 | 
					@ -46,7 +46,11 @@ def _(field, cls):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@convert_django_field.register(models.ManyToOneRel)
 | 
					@convert_django_field.register(models.ManyToOneRel)
 | 
				
			||||||
def _(field, cls):
 | 
					def _(field, cls):
 | 
				
			||||||
    return ListField(DjangoModelField(field.related_model))
 | 
					    schema = cls._meta.schema
 | 
				
			||||||
 | 
					    model_field = DjangoModelField(field.related_model)
 | 
				
			||||||
 | 
					    if issubclass(cls, schema.relay.Node):
 | 
				
			||||||
 | 
					        return schema.relay.ConnectionField(model_field)
 | 
				
			||||||
 | 
					    return ListField(model_field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@convert_django_field.register(models.ForeignKey)
 | 
					@convert_django_field.register(models.ForeignKey)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,7 +10,7 @@ def get_type_for_model(schema, model):
 | 
				
			||||||
    for _type in types:
 | 
					    for _type in types:
 | 
				
			||||||
        type_model = getattr(_type._meta, 'model', None)
 | 
					        type_model = getattr(_type._meta, 'model', None)
 | 
				
			||||||
        if model == type_model:
 | 
					        if model == type_model:
 | 
				
			||||||
            return _type._meta.type
 | 
					            return _type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DjangoModelField(Field):
 | 
					class DjangoModelField(Field):
 | 
				
			||||||
| 
						 | 
					@ -20,4 +20,13 @@ class DjangoModelField(Field):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @cached_property
 | 
					    @cached_property
 | 
				
			||||||
    def type(self):
 | 
					    def type(self):
 | 
				
			||||||
        return get_type_for_model(self.schema, self.model)
 | 
					        _type = self.get_object_type()
 | 
				
			||||||
 | 
					        return _type and _type._meta.type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_object_type(self):
 | 
				
			||||||
 | 
					        _type = get_type_for_model(self.schema, self.model)
 | 
				
			||||||
 | 
					        if not _type and self.object_type._meta.only_fields:
 | 
				
			||||||
 | 
					            # We will only raise the exception if the related field is specified in only_fields
 | 
				
			||||||
 | 
					            raise Exception("Field %s (%s) model not mapped in current schema" % (self, self.model._meta.object_name))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return _type
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,6 +5,7 @@ from graphene.core.options import Options
 | 
				
			||||||
 | 
					
 | 
				
			||||||
VALID_ATTRS = ('model', 'only_fields')
 | 
					VALID_ATTRS = ('model', 'only_fields')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DjangoOptions(Options):
 | 
					class DjangoOptions(Options):
 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        self.model = None
 | 
					        self.model = None
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7,6 +7,7 @@ from graphene.contrib.django.converter import convert_django_field
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphene.relay import Node
 | 
					from graphene.relay import Node
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_reverse_fields(model):
 | 
					def get_reverse_fields(model):
 | 
				
			||||||
    for name, attr in model.__dict__.items():
 | 
					    for name, attr in model.__dict__.items():
 | 
				
			||||||
        related = getattr(attr, 'related', None)
 | 
					        related = getattr(attr, 'related', None)
 | 
				
			||||||
| 
						 | 
					@ -16,12 +17,11 @@ def get_reverse_fields(model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DjangoObjectTypeMeta(ObjectTypeMeta):
 | 
					class DjangoObjectTypeMeta(ObjectTypeMeta):
 | 
				
			||||||
    options_cls = DjangoOptions
 | 
					    options_cls = DjangoOptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_extra_fields(cls):
 | 
					    def add_extra_fields(cls):
 | 
				
			||||||
        if not cls._meta.model:
 | 
					        if not cls._meta.model:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					 | 
				
			||||||
        only_fields = cls._meta.only_fields
 | 
					        only_fields = cls._meta.only_fields
 | 
				
			||||||
        # print cls._meta.model._meta._get_fields(forward=False, reverse=True, include_hidden=True)
 | 
					 | 
				
			||||||
        reverse_fields = tuple(get_reverse_fields(cls._meta.model))
 | 
					        reverse_fields = tuple(get_reverse_fields(cls._meta.model))
 | 
				
			||||||
        for field in cls._meta.model._meta.fields + reverse_fields:
 | 
					        for field in cls._meta.model._meta.fields + reverse_fields:
 | 
				
			||||||
            if only_fields and field.name not in only_fields:
 | 
					            if only_fields and field.name not in only_fields:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,6 +48,8 @@ class Field(object):
 | 
				
			||||||
    def get_object_type(self):
 | 
					    def get_object_type(self):
 | 
				
			||||||
        field_type = self.field_type
 | 
					        field_type = self.field_type
 | 
				
			||||||
        _is_class = inspect.isclass(field_type)
 | 
					        _is_class = inspect.isclass(field_type)
 | 
				
			||||||
 | 
					        if isinstance(field_type, Field):
 | 
				
			||||||
 | 
					            return field_type.get_object_type()
 | 
				
			||||||
        if _is_class and issubclass(field_type, ObjectType):
 | 
					        if _is_class and issubclass(field_type, ObjectType):
 | 
				
			||||||
            return field_type
 | 
					            return field_type
 | 
				
			||||||
        elif isinstance(field_type, basestring):
 | 
					        elif isinstance(field_type, basestring):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,18 +11,10 @@ class Options(object):
 | 
				
			||||||
        self.local_fields = []
 | 
					        self.local_fields = []
 | 
				
			||||||
        self.interface = False
 | 
					        self.interface = False
 | 
				
			||||||
        self.proxy = False
 | 
					        self.proxy = False
 | 
				
			||||||
        self.schema = schema
 | 
					        self.schema = schema or get_global_schema()
 | 
				
			||||||
        self.interfaces = []
 | 
					        self.interfaces = []
 | 
				
			||||||
        self.parents = []
 | 
					        self.parents = []
 | 
				
			||||||
        self.valid_attrs = DEFAULT_NAMES
 | 
					        self.valid_attrs = DEFAULT_NAMES
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # @property
 | 
					 | 
				
			||||||
    # def schema(self):
 | 
					 | 
				
			||||||
    #     return self._schema or get_global_schema()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # @schema.setter
 | 
					 | 
				
			||||||
    # def schema(self, schema):
 | 
					 | 
				
			||||||
    #     self._schema = schema
 | 
					 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def contribute_to_class(self, cls, name):
 | 
					    def contribute_to_class(self, cls, name):
 | 
				
			||||||
        cls._meta = self
 | 
					        cls._meta = self
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,8 +13,10 @@ from graphene.relay.relay import (
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphene.env import get_global_schema
 | 
					from graphene.env import get_global_schema
 | 
				
			||||||
 | 
					from graphene.relay.utils import setup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
schema = get_global_schema()
 | 
					schema = get_global_schema()
 | 
				
			||||||
 | 
					setup(schema)
 | 
				
			||||||
relay = schema.relay
 | 
					relay = schema.relay
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Node, NodeField = relay.Node, relay.NodeField
 | 
					Node, NodeField = relay.Node, relay.NodeField
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7,7 +7,7 @@ from graphql_relay.connection.connection import (
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphene import signals
 | 
					from graphene import signals
 | 
				
			||||||
from graphene.core.fields import NativeField
 | 
					from graphene.core.fields import NativeField
 | 
				
			||||||
from graphene.relay.utils import get_relay
 | 
					from graphene.relay.utils import get_relay, setup
 | 
				
			||||||
from graphene.relay.relay import Relay
 | 
					from graphene.relay.relay import Relay
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,4 +28,4 @@ def object_type_created(object_type):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@signals.init_schema.connect
 | 
					@signals.init_schema.connect
 | 
				
			||||||
def schema_created(schema):
 | 
					def schema_created(schema):
 | 
				
			||||||
    setattr(schema, 'relay', Relay(schema))
 | 
					    setup(schema)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,3 +1,9 @@
 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_relay(schema):
 | 
					def get_relay(schema):
 | 
				
			||||||
    return getattr(schema, 'relay', None)
 | 
					    return getattr(schema, 'relay', None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def setup(schema):
 | 
				
			||||||
 | 
						from graphene.relay.relay import Relay
 | 
				
			||||||
 | 
						if not hasattr(schema, 'relay'):
 | 
				
			||||||
 | 
							return setattr(schema, 'relay', Relay(schema))
 | 
				
			||||||
 | 
						return schema
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,12 +25,31 @@ def test_should_raise_if_model_is_invalid():
 | 
				
			||||||
    assert 'not a Django model' in str(excinfo.value)
 | 
					    assert 'not a Django model' in str(excinfo.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_should_raise_if_model_is_invalid():
 | 
				
			||||||
 | 
					    with raises(Exception) as excinfo:
 | 
				
			||||||
 | 
					        class ReporterTypeError(DjangoObjectType):
 | 
				
			||||||
 | 
					            class Meta:
 | 
				
			||||||
 | 
					                model = Reporter
 | 
				
			||||||
 | 
					                only_fields = ('articles', )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        schema = graphene.Schema(query=ReporterTypeError)
 | 
				
			||||||
 | 
					        query = '''
 | 
				
			||||||
 | 
					            query ReporterQuery {
 | 
				
			||||||
 | 
					              articles
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        result = schema.execute(query)
 | 
				
			||||||
 | 
					        assert not result.errors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert 'articles (Article) model not mapped in current schema' in str(excinfo.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_should_map_fields():
 | 
					def test_should_map_fields():
 | 
				
			||||||
    class ReporterType(DjangoObjectType):
 | 
					    class ReporterType(DjangoObjectType):
 | 
				
			||||||
        class Meta:
 | 
					        class Meta:
 | 
				
			||||||
            model = Reporter
 | 
					            model = Reporter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Query(graphene.ObjectType):
 | 
					    class Query2(graphene.ObjectType):
 | 
				
			||||||
        reporter = graphene.Field(ReporterType)
 | 
					        reporter = graphene.Field(ReporterType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def resolve_reporter(self, *args, **kwargs):
 | 
					        def resolve_reporter(self, *args, **kwargs):
 | 
				
			||||||
| 
						 | 
					@ -52,7 +71,7 @@ def test_should_map_fields():
 | 
				
			||||||
            'email': ''
 | 
					            'email': ''
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    Schema = graphene.Schema(query=Query)
 | 
					    Schema = graphene.Schema(query=Query2)
 | 
				
			||||||
    result = Schema.execute(query)
 | 
					    result = Schema.execute(query)
 | 
				
			||||||
    assert not result.errors
 | 
					    assert not result.errors
 | 
				
			||||||
    assert result.data == expected
 | 
					    assert result.data == expected
 | 
				
			||||||
| 
						 | 
					@ -74,15 +93,18 @@ def test_should_node():
 | 
				
			||||||
        def get_node(cls, id):
 | 
					        def get_node(cls, id):
 | 
				
			||||||
            return ReporterNodeType(Reporter(id=2, first_name='Cookie Monster'))
 | 
					            return ReporterNodeType(Reporter(id=2, first_name='Cookie Monster'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def resolve_articles(self, *args, **kwargs):
 | 
				
			||||||
 | 
					            return [ArticleNodeType(Article(headline='Hi!'))]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class ArticleNodeType(DjangoNode):
 | 
					    class ArticleNodeType(DjangoNode):
 | 
				
			||||||
        class Meta:
 | 
					        class Meta:
 | 
				
			||||||
            model = Article
 | 
					            model = Article
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @classmethod
 | 
					        @classmethod
 | 
				
			||||||
        def get_node(cls, id):
 | 
					        def get_node(cls, id):
 | 
				
			||||||
            return ArticleNodeType(None)
 | 
					            return ArticleNodeType(Article(id=1, headline='Article node'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Query(graphene.ObjectType):
 | 
					    class Query1(graphene.ObjectType):
 | 
				
			||||||
        node = relay.NodeField()
 | 
					        node = relay.NodeField()
 | 
				
			||||||
        reporter = graphene.Field(ReporterNodeType)
 | 
					        reporter = graphene.Field(ReporterNodeType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -94,14 +116,24 @@ def test_should_node():
 | 
				
			||||||
          reporter {
 | 
					          reporter {
 | 
				
			||||||
            id,
 | 
					            id,
 | 
				
			||||||
            first_name,
 | 
					            first_name,
 | 
				
			||||||
 | 
					            articles {
 | 
				
			||||||
 | 
					              edges {
 | 
				
			||||||
 | 
					                node {
 | 
				
			||||||
 | 
					                  headline
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
            last_name,
 | 
					            last_name,
 | 
				
			||||||
            email
 | 
					            email
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
          aCustomNode: node(id:"UmVwb3J0ZXJOb2RlVHlwZToy") {
 | 
					          my_article: node(id:"QXJ0aWNsZU5vZGVUeXBlOjE=") {
 | 
				
			||||||
            id
 | 
					            id
 | 
				
			||||||
            ... on ReporterNodeType {
 | 
					            ... on ReporterNodeType {
 | 
				
			||||||
                first_name
 | 
					                first_name
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					            ... on ArticleNodeType {
 | 
				
			||||||
 | 
					                headline
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
| 
						 | 
					@ -110,14 +142,21 @@ def test_should_node():
 | 
				
			||||||
            'id': 'UmVwb3J0ZXJOb2RlVHlwZTox',
 | 
					            'id': 'UmVwb3J0ZXJOb2RlVHlwZTox',
 | 
				
			||||||
            'first_name': 'ABA',
 | 
					            'first_name': 'ABA',
 | 
				
			||||||
            'last_name': 'X',
 | 
					            'last_name': 'X',
 | 
				
			||||||
            'email': ''
 | 
					            'email': '',
 | 
				
			||||||
 | 
					            'articles': {
 | 
				
			||||||
 | 
					              'edges': [{
 | 
				
			||||||
 | 
					                'node': {
 | 
				
			||||||
 | 
					                  'headline': 'Hi!'
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					              }]
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
        'aCustomNode': {
 | 
					        'my_article': {
 | 
				
			||||||
            'id': 'UmVwb3J0ZXJOb2RlVHlwZToy',
 | 
					            'id': 'QXJ0aWNsZU5vZGVUeXBlOjE=',
 | 
				
			||||||
            'first_name': 'Cookie Monster'
 | 
					            'headline': 'Article node'
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    Schema = graphene.Schema(query=Query)
 | 
					    Schema = graphene.Schema(query=Query1)
 | 
				
			||||||
    result = Schema.execute(query)
 | 
					    result = Schema.execute(query)
 | 
				
			||||||
    assert not result.errors
 | 
					    assert not result.errors
 | 
				
			||||||
    assert result.data == expected
 | 
					    assert result.data == expected
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,5 @@
 | 
				
			||||||
import graphene
 | 
					import graphene
 | 
				
			||||||
from graphene import resolve_only_args, relay
 | 
					from graphene import resolve_only_args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .data import (
 | 
					from .data import (
 | 
				
			||||||
    getFaction,
 | 
					    getFaction,
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,9 @@ from .data import (
 | 
				
			||||||
    getEmpire,
 | 
					    getEmpire,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					schema = graphene.Schema(name='Starwars Relay Schema')
 | 
				
			||||||
 | 
					relay = schema.relay
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Ship(relay.Node):
 | 
					class Ship(relay.Node):
 | 
				
			||||||
    '''A ship in the Star Wars saga'''
 | 
					    '''A ship in the Star Wars saga'''
 | 
				
			||||||
    name = graphene.StringField(description='The name of the ship.')
 | 
					    name = graphene.StringField(description='The name of the ship.')
 | 
				
			||||||
| 
						 | 
					@ -31,7 +34,7 @@ class Faction(relay.Node):
 | 
				
			||||||
        return Faction(getFaction(id))
 | 
					        return Faction(getFaction(id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Query(graphene.ObjectType):
 | 
					class Query(schema.ObjectType):
 | 
				
			||||||
    rebels = graphene.Field(Faction)
 | 
					    rebels = graphene.Field(Faction)
 | 
				
			||||||
    empire = graphene.Field(Faction)
 | 
					    empire = graphene.Field(Faction)
 | 
				
			||||||
    node = relay.NodeField()
 | 
					    node = relay.NodeField()
 | 
				
			||||||
| 
						 | 
					@ -45,4 +48,4 @@ class Query(graphene.ObjectType):
 | 
				
			||||||
        return Faction(getEmpire())
 | 
					        return Faction(getEmpire())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
schema = graphene.Schema(query=Query, name='Starwars Relay Schema')
 | 
					schema.query = Query
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user