From 311209760db9d8548133bf34e9394662678574fc Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sun, 27 Sep 2015 20:19:33 -0700 Subject: [PATCH] Refactored code allowing multiple schemas at the same time. --- graphene/__init__.py | 2 + graphene/core/fields.py | 2 +- graphene/core/options.py | 2 +- graphene/core/schema.py | 24 ++------ graphene/core/types.py | 20 ++++++- graphene/relay/__init__.py | 17 +++++- graphene/relay/connections.py | 35 ++++++++++++ graphene/relay/fields.py | 30 ++++++++++ graphene/relay/nodes.py | 35 +++++++----- graphene/relay/relay.py | 57 ++++--------------- graphene/relay/utils.py | 3 + graphene/signals.py | 1 + tests/core/test_types.py | 1 + tests/relay/test_relay.py | 1 + tests/starwars_relay/schema.py | 15 +---- tests/starwars_relay/test_connections.py | 4 +- .../test_objectidentification.py | 12 ++-- 17 files changed, 153 insertions(+), 108 deletions(-) create mode 100644 graphene/relay/connections.py create mode 100644 graphene/relay/fields.py create mode 100644 graphene/relay/utils.py diff --git a/graphene/__init__.py b/graphene/__init__.py index 4fd823ce..61996539 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -34,3 +34,5 @@ from graphene.core.fields import ( from graphene.decorators import ( resolve_only_args ) + +import graphene.relay diff --git a/graphene/core/fields.py b/graphene/core/fields.py index 430ec9f9..383138d6 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -10,7 +10,6 @@ from graphql.core.type import ( GraphQLArgument, ) from graphene.utils import cached_property -from graphene.core.types import ObjectType class Field(object): def __init__(self, field_type, resolve=None, null=True, args=None, description='', **extra_args): @@ -45,6 +44,7 @@ class Field(object): return resolve_fn(instance, args, info) def get_object_type(self): + from graphene.core.types import ObjectType field_type = self.field_type _is_class = inspect.isclass(field_type) if _is_class and issubclass(field_type, ObjectType): diff --git a/graphene/core/options.py b/graphene/core/options.py index e5eac894..5d0263f0 100644 --- a/graphene/core/options.py +++ b/graphene/core/options.py @@ -11,7 +11,7 @@ class Options(object): self.local_fields = [] self.interface = False self.proxy = False - self.schema = schema or get_global_schema() + self.schema = schema self.interfaces = [] self.parents = [] diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 4a5519af..132a4760 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -4,7 +4,7 @@ from graphql.core.type import ( ) from graphene import signals from graphene.utils import cached_property -# from graphene.relay.nodes import create_node_definitions + class Schema(object): _query = None @@ -14,27 +14,15 @@ class Schema(object): self.query = query self.name = name self._types = {} - + signals.init_schema.send(self) + def __repr__(self): - return '' % str(self.name) - - # @cachedproperty - # def node_definitions(self): - # return [object, object] - # # from graphene.relay import create_node_definitions - # # return create_node_definitions(schema=self) - - # @property - # def Node(self): - # return self.node_definitions[0] - - # @property - # def NodeField(self): - # return self.node_definitions[1] + return '' % (str(self.name), hash(self)) @property def query(self): return self._query + @query.setter def query(self, query): if not query: @@ -69,5 +57,3 @@ def object_type_created(object_type): schema = object_type._meta.schema if schema: schema.register_type(object_type) - -from graphene.env import get_global_schema diff --git a/graphene/core/types.py b/graphene/core/types.py index f020a72a..978fd41e 100644 --- a/graphene/core/types.py +++ b/graphene/core/types.py @@ -49,8 +49,8 @@ class ObjectTypeMeta(type): # Things without _meta aren't functional models, so they're # uninteresting parents. continue - if base._meta.schema != new_class._meta.schema: - raise Exception('The parent schema is not the same') + # if base._meta.schema != new_class._meta.schema: + # raise Exception('The parent schema is not the same') parent_fields = base._meta.local_fields # Check for clashes between locally declared fields and those @@ -135,3 +135,19 @@ class Interface(ObjectType): class Meta: interface = True proxy = True + +@signals.init_schema.connect +def add_types_to_schema(schema): + own_schema = schema + class _Interface(Interface): + class Meta: + schema = own_schema + proxy = True + + class _ObjectType(ObjectType): + class Meta: + schema = own_schema + proxy = True + + setattr(own_schema, 'Interface', _Interface) + setattr(own_schema, 'ObjectType', _ObjectType) diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index b7a71ce5..4e353804 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -2,4 +2,19 @@ from graphene.relay.nodes import ( create_node_definitions ) -from graphene.relay.relay import * +from graphene.relay.fields import ( + ConnectionField, +) + +import graphene.relay.connections + +from graphene.relay.relay import ( + Relay +) + +from graphene.env import get_global_schema + +schema = get_global_schema() +relay = schema.relay + +Node, NodeField = relay.Node, relay.NodeField diff --git a/graphene/relay/connections.py b/graphene/relay/connections.py new file mode 100644 index 00000000..24f2a61d --- /dev/null +++ b/graphene/relay/connections.py @@ -0,0 +1,35 @@ +import collections + +from graphql_relay.node.node import ( + globalIdField +) +from graphql_relay.connection.connection import ( + connectionDefinitions +) + +from graphene import signals + +from graphene.core.fields import NativeField +from graphene.relay.utils import get_relay +from graphene.relay.relay import Relay + + +@signals.class_prepared.connect +def object_type_created(object_type): + relay = get_relay(object_type._meta.schema) + if relay and issubclass(object_type, relay.Node): + type_name = object_type._meta.type_name + # def getId(*args, **kwargs): + # print '**GET ID', args, kwargs + # return 2 + field = NativeField(globalIdField(type_name)) + object_type.add_to_class('id', field) + assert hasattr(object_type, 'get_node'), 'get_node classmethod not found in %s Node' % type_name + + connection = connectionDefinitions(type_name, object_type._meta.type).connectionType + object_type.add_to_class('connection', connection) + + +@signals.init_schema.connect +def schema_created(schema): + setattr(schema, 'relay', Relay(schema)) diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py new file mode 100644 index 00000000..acdeb2f5 --- /dev/null +++ b/graphene/relay/fields.py @@ -0,0 +1,30 @@ +import collections + +from graphql_relay.connection.arrayconnection import ( + connectionFromArray +) +from graphql_relay.connection.connection import ( + connectionArgs +) +from graphene.core.fields import Field +from graphene.utils import cached_property +from graphene.relay.utils import get_relay + + +class ConnectionField(Field): + def __init__(self, field_type, resolve=None, description=''): + super(ConnectionField, self).__init__(field_type, resolve=resolve, + args=connectionArgs, description=description) + + def resolve(self, instance, args, info): + resolved = super(ConnectionField, self).resolve(instance, args, info) + if resolved: + assert isinstance(resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable' + return connectionFromArray(resolved, args) + + @cached_property + def type(self): + object_type = self.get_object_type() + relay = get_relay(object_type._meta.schema) + assert issubclass(object_type, relay.Node), 'Only nodes have connections.' + return object_type.connection diff --git a/graphene/relay/nodes.py b/graphene/relay/nodes.py index 81223a52..643d148f 100644 --- a/graphene/relay/nodes.py +++ b/graphene/relay/nodes.py @@ -2,27 +2,32 @@ from graphql_relay.node.node import ( nodeDefinitions, fromGlobalId ) +from graphene.env import get_global_schema +from graphene.core.types import Interface +from graphene.core.fields import Field, NativeField -def create_node_definitions(getNode=None, getNodeType=None, schema=None): - from graphene.core.types import Interface - from graphene.core.fields import Field, NativeField - if not getNode: - def getNode(globalId, *args): - from graphene.env import get_global_schema - _schema = schema or get_global_schema() - resolvedGlobalId = fromGlobalId(globalId) - _type, _id = resolvedGlobalId.type, resolvedGlobalId.id - object_type = _schema.get_type(_type) - return object_type.get_node(_id) - if not getNodeType: - def getNodeType(obj): - return obj._meta.type +def getSchemaNode(schema=None): + def getNode(globalId, *args): + _schema = schema or get_global_schema() + resolvedGlobalId = fromGlobalId(globalId) + _type, _id = resolvedGlobalId.type, resolvedGlobalId.id + object_type = schema.get_type(_type) + return object_type.get_node(_id) + return getNode + +def getNodeType(obj): + return obj._meta.type + + +def create_node_definitions(getNode=None, getNodeType=getNodeType, schema=None): + getNode = getNode or getSchemaNode(schema) _nodeDefinitions = nodeDefinitions(getNode, getNodeType) + _Interface = getattr(schema,'Interface', Interface) - class Node(Interface): + class Node(_Interface): @classmethod def get_graphql_type(cls): if cls is Node: diff --git a/graphene/relay/relay.py b/graphene/relay/relay.py index 2df13e81..2ff4eb9b 100644 --- a/graphene/relay/relay.py +++ b/graphene/relay/relay.py @@ -1,51 +1,14 @@ -import collections - -from graphene import signals -from graphene.utils import cached_property - -from graphql_relay.node.node import ( - globalIdField +from graphene.relay.nodes import ( + create_node_definitions ) -from graphql_relay.connection.arrayconnection import ( - connectionFromArray + +from graphene.relay.fields import ( + ConnectionField, ) -from graphql_relay.connection.connection import ( - connectionArgs, - connectionDefinitions -) -from graphene.relay.nodes import create_node_definitions -from graphene.core.fields import Field, NativeField - -Node, NodeField = create_node_definitions() - -class ConnectionField(Field): - def __init__(self, field_type, resolve=None, description=''): - super(ConnectionField, self).__init__(field_type, resolve=resolve, - args=connectionArgs, description=description) - - def resolve(self, instance, args, info): - resolved = super(ConnectionField, self).resolve(instance, args, info) - if resolved: - assert isinstance(resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable' - return connectionFromArray(resolved, args) - - @cached_property - def type(self): - object_type = self.get_object_type() - assert issubclass(object_type, Node), 'Only nodes have connections.' - return object_type.connection -@signals.class_prepared.connect -def object_type_created(object_type): - if issubclass(object_type, Node): - type_name = object_type._meta.type_name - # def getId(*args, **kwargs): - # print '**GET ID', args, kwargs - # return 2 - field = NativeField(globalIdField(type_name)) - object_type.add_to_class('id', field) - assert hasattr(object_type, 'get_node'), 'get_node classmethod not found in %s Node' % type_name - - connection = connectionDefinitions(type_name, object_type._meta.type).connectionType - object_type.add_to_class('connection', connection) +class Relay(object): + def __init__(self, schema): + self.schema = schema + self.Node, self.NodeField = create_node_definitions(schema=self.schema) + self.ConnectionField = ConnectionField diff --git a/graphene/relay/utils.py b/graphene/relay/utils.py new file mode 100644 index 00000000..cd23632d --- /dev/null +++ b/graphene/relay/utils.py @@ -0,0 +1,3 @@ + +def get_relay(schema): + return getattr(schema, 'relay', None) diff --git a/graphene/signals.py b/graphene/signals.py index 954d02a1..ccb9ef0f 100644 --- a/graphene/signals.py +++ b/graphene/signals.py @@ -1,5 +1,6 @@ from blinker import Signal +init_schema = Signal() class_prepared = Signal() pre_init = Signal() post_init = Signal() diff --git a/tests/core/test_types.py b/tests/core/test_types.py index 1408caa2..0a83398a 100644 --- a/tests/core/test_types.py +++ b/tests/core/test_types.py @@ -20,6 +20,7 @@ class Character(Interface): name = StringField() class Meta: type_name = 'core.Character' + class Human(Character): '''Human description''' friends = StringField() diff --git a/tests/relay/test_relay.py b/tests/relay/test_relay.py index 2bcdc836..7314c096 100644 --- a/tests/relay/test_relay.py +++ b/tests/relay/test_relay.py @@ -4,6 +4,7 @@ import graphene from graphene import relay schema = graphene.Schema() +relay = schema.relay class OtherNode(relay.Node): name = graphene.StringField() diff --git a/tests/starwars_relay/schema.py b/tests/starwars_relay/schema.py index f1a12050..fbbfc49c 100644 --- a/tests/starwars_relay/schema.py +++ b/tests/starwars_relay/schema.py @@ -8,8 +8,6 @@ from .data import ( getEmpire, ) -schema = graphene.Schema() - class Ship(relay.Node): '''A ship in the Star Wars saga''' name = graphene.StringField(description='The name of the ship.') @@ -20,8 +18,6 @@ class Ship(relay.Node): if ship: return Ship(ship) - # class Meta: - # schema = schema class Faction(relay.Node): '''A faction in the Star Wars saga''' @@ -38,9 +34,6 @@ class Faction(relay.Node): if faction: return Faction(faction) - # class Meta: - # schema = schema - class Query(graphene.ObjectType): rebels = graphene.Field(Faction) @@ -56,10 +49,4 @@ class Query(graphene.ObjectType): return Faction(getEmpire()) - # class Meta: - # schema = schema - -print '*CACA', schema._types - -schema.query = Query -Schema = schema +schema = graphene.Schema(query=Query, name='Starwars Relay Schema') diff --git a/tests/starwars_relay/test_connections.py b/tests/starwars_relay/test_connections.py index 1bac0a4d..7d385514 100644 --- a/tests/starwars_relay/test_connections.py +++ b/tests/starwars_relay/test_connections.py @@ -1,7 +1,7 @@ from pytest import raises from graphql.core import graphql -from .schema import Schema +from .schema import schema def test_correct_fetch_first_ship_rebels(): query = ''' @@ -32,6 +32,6 @@ def test_correct_fetch_first_ship_rebels(): } } } - result = Schema.execute(query) + result = schema.execute(query) assert result.errors == None assert result.data == expected diff --git a/tests/starwars_relay/test_objectidentification.py b/tests/starwars_relay/test_objectidentification.py index 85050b6f..1c4a0ba2 100644 --- a/tests/starwars_relay/test_objectidentification.py +++ b/tests/starwars_relay/test_objectidentification.py @@ -1,7 +1,7 @@ from pytest import raises from graphql.core import graphql -from .schema import Schema +from .schema import schema def test_correctly_fetches_id_name_rebels(): query = ''' @@ -18,7 +18,7 @@ def test_correctly_fetches_id_name_rebels(): 'name': 'Alliance to Restore the Republic' } } - result = Schema.execute(query) + result = schema.execute(query) assert result.errors == None assert result.data == expected @@ -39,7 +39,7 @@ def test_correctly_refetches_rebels(): 'name': 'Alliance to Restore the Republic' } } - result = Schema.execute(query) + result = schema.execute(query) assert result.errors == None assert result.data == expected @@ -58,7 +58,7 @@ def test_correctly_fetches_id_name_empire(): 'name': 'Galactic Empire' } } - result = Schema.execute(query) + result = schema.execute(query) assert result.errors == None assert result.data == expected @@ -79,7 +79,7 @@ def test_correctly_refetches_empire(): 'name': 'Galactic Empire' } } - result = Schema.execute(query) + result = schema.execute(query) assert result.errors == None assert result.data == expected @@ -100,6 +100,6 @@ def test_correctly_refetches_xwing(): 'name': 'X-Wing' } } - result = Schema.execute(query) + result = schema.execute(query) assert result.errors == None assert result.data == expected