Improved base implementation

This commit is contained in:
Syrus Akbary 2016-08-13 17:37:57 -07:00
parent 0ffdd8d9ab
commit b19bca7f3b
8 changed files with 112 additions and 22 deletions

View File

@ -32,11 +32,13 @@ class InterfaceMeta(AbstractTypeMeta):
class Interface(six.with_metaclass(InterfaceMeta)): class Interface(six.with_metaclass(InterfaceMeta)):
resolve_type = None @classmethod
def resolve_type(cls, root, args, info):
return type(root)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise Exception("An Interface cannot be intitialized") raise Exception("An Interface cannot be intitialized")
# @classmethod @classmethod
# def implements(cls, objecttype): def implements(cls, objecttype):
# pass pass

View File

@ -21,4 +21,4 @@ class JSONString(Scalar):
@staticmethod @staticmethod
def parse_value(value): def parse_value(value):
return json.dumps(value) return json.loads(value)

View File

@ -1,3 +1,4 @@
from collections import OrderedDict
import six import six
from ..utils.is_base_type import is_base_type from ..utils.is_base_type import is_base_type
@ -5,6 +6,7 @@ from .options import Options
from .abstracttype import AbstractTypeMeta from .abstracttype import AbstractTypeMeta
from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
from .interface import Interface
class ObjectTypeMeta(AbstractTypeMeta): class ObjectTypeMeta(AbstractTypeMeta):
@ -23,10 +25,22 @@ class ObjectTypeMeta(AbstractTypeMeta):
) )
attrs = merge_fields_in_attrs(bases, attrs) attrs = merge_fields_in_attrs(bases, attrs)
options.fields = get_fields_in_type(ObjectType, attrs) options.local_fields = get_fields_in_type(ObjectType, attrs)
yank_fields_from_attrs(attrs, options.fields) yank_fields_from_attrs(attrs, options.local_fields)
options.interface_fields = OrderedDict()
for interface in options.interfaces:
assert issubclass(interface, Interface), (
'All interfaces of {} must be a subclass of Interface. Received "{}".'
).format(name, interface)
options.interface_fields.update(interface._meta.fields)
options.fields = OrderedDict(options.interface_fields)
options.fields.update(options.local_fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces:
interface.implements(cls)
return cls
def __str__(cls): def __str__(cls):
return cls._meta.name return cls._meta.name

View File

@ -62,6 +62,8 @@ class Schema(GraphQLSchema):
return self.get_graphql_type(self._subscription) return self.get_graphql_type(self._subscription)
def get_graphql_type(self, _type): def get_graphql_type(self, _type):
if not _type:
return _type
if is_type(_type): if is_type(_type):
return _type return _type
if is_graphene_type(_type): if is_graphene_type(_type):

View File

@ -71,6 +71,18 @@ def test_generate_interface_inherit_abstracttype():
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
def test_generate_interface_inherit_interface():
class MyBaseInterface(Interface):
field1 = MyScalar()
class MyInterface(MyBaseInterface):
field2 = MyScalar()
assert MyInterface._meta.name == 'MyInterface'
assert MyInterface._meta.fields.keys() == ['field1', 'field2']
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
def test_generate_interface_inherit_abstracttype_reversed(): def test_generate_interface_inherit_abstracttype_reversed():
class MyAbstractType(AbstractType): class MyAbstractType(AbstractType):
field1 = MyScalar() field1 = MyScalar()

View File

@ -4,9 +4,10 @@ from ..field import Field
from ..objecttype import ObjectType from ..objecttype import ObjectType
from ..unmountedtype import UnmountedType from ..unmountedtype import UnmountedType
from ..abstracttype import AbstractType from ..abstracttype import AbstractType
from ..interface import Interface
class MyType(object): class MyType(Interface):
pass pass
@ -15,6 +16,17 @@ class Container(ObjectType):
field2 = Field(MyType) field2 = Field(MyType)
class MyInterface(Interface):
ifield = Field(MyType)
class ContainerWithInterface(ObjectType):
class Meta:
interfaces = (MyInterface, )
field1 = Field(MyType)
field2 = Field(MyType)
class MyScalar(UnmountedType): class MyScalar(UnmountedType):
def get_type(self): def get_type(self):
return MyType return MyType
@ -94,6 +106,10 @@ def test_parent_container_get_fields():
assert list(Container._meta.fields.keys()) == ['field1', 'field2'] assert list(Container._meta.fields.keys()) == ['field1', 'field2']
def test_parent_container_interface_get_fields():
assert list(ContainerWithInterface._meta.fields.keys()) == ['ifield', 'field1', 'field2']
def test_objecttype_as_container_only_args(): def test_objecttype_as_container_only_args():
container = Container("1", "2") container = Container("1", "2")
assert container.field1 == "1" assert container.field1 == "1"

View File

@ -1,4 +1,5 @@
import inspect import inspect
from functools import partial
from collections import OrderedDict from collections import OrderedDict
from graphql.type.typemap import GraphQLTypeMap from graphql.type.typemap import GraphQLTypeMap
@ -14,6 +15,8 @@ from .scalars import Scalar, String, Boolean, Int, Float, ID
from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument
from graphql.type import GraphQLEnumValue from graphql.type import GraphQLEnumValue
from ..utils.str_converters import to_camel_case
def is_graphene_type(_type): def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)): if isinstance(_type, (List, NonNull)):
@ -22,13 +25,26 @@ def is_graphene_type(_type):
return True return True
def resolve_type(resolve_type_func, map, root, args, info):
_type = resolve_type_func(root, args, info)
# assert inspect.isclass(_type) and issubclass(_type, ObjectType), (
# 'Received incompatible type "{}".'.format(_type)
# )
if inspect.isclass(_type) and issubclass(_type, ObjectType):
graphql_type = map.get(_type._meta.name)
assert graphql_type and graphql_type.graphene_type == _type
return graphql_type
return _type
class TypeMap(GraphQLTypeMap): class TypeMap(GraphQLTypeMap):
@classmethod @classmethod
def reducer(cls, map, type): def reducer(cls, map, type):
if not type: if not type:
return map return map
if inspect.isfunction(type):
type = type()
if is_graphene_type(type): if is_graphene_type(type):
return cls.graphene_reducer(map, type) return cls.graphene_reducer(map, type)
return super(TypeMap, cls).reducer(map, type) return super(TypeMap, cls).reducer(map, type)
@ -112,10 +128,11 @@ class TypeMap(GraphQLTypeMap):
) )
interfaces = [] interfaces = []
for i in type._meta.interfaces: for i in type._meta.interfaces:
map = cls.construct_interface(map, i) map = cls.reducer(map, i)
interfaces.append(map[i._meta.name]) interfaces.append(map[i._meta.name])
map[type._meta.name].interfaces = interfaces map[type._meta.name]._provided_interfaces = interfaces
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
# cls.reducer(map, map[type._meta.name])
return map return map
@classmethod @classmethod
@ -126,9 +143,10 @@ class TypeMap(GraphQLTypeMap):
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
fields=None, fields=None,
resolve_type=type.resolve_type, resolve_type=partial(resolve_type, type.resolve_type, map),
) )
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
# cls.reducer(map, map[type._meta.name])
return map return map
@classmethod @classmethod
@ -159,6 +177,14 @@ class TypeMap(GraphQLTypeMap):
map[type._meta.name].types = types map[type._meta.name].types = types
return map return map
@classmethod
def process_field_name(cls, name):
return to_camel_case(name)
@classmethod
def default_resolver(cls, attname, root, *_):
return getattr(root, attname, None)
@classmethod @classmethod
def construct_fields_for_type(cls, map, type, is_input_type=False): def construct_fields_for_type(cls, map, type, is_input_type=False):
fields = OrderedDict() fields = OrderedDict()
@ -181,25 +207,42 @@ class TypeMap(GraphQLTypeMap):
description=arg.description, description=arg.description,
default_value=arg.default_value default_value=arg.default_value
) )
resolver = field.resolver
resolver_type = getattr(type, 'resolve_{}'.format(name), None)
if resolver_type:
resolver = resolver_type.__func__
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolver=resolver, resolver=field.resolver or cls.get_resolver_for_type(type, name),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description description=field.description
) )
fields[name] = _field processed_name = cls.process_field_name(name)
fields[processed_name] = _field
return fields return fields
@classmethod
def get_resolver_for_type(cls, type, name):
if not issubclass(type, ObjectType):
return
resolver = getattr(type, 'resolve_{}'.format(name), None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in type._meta.interfaces:
interface_resolver = getattr(interface, 'resolve_{}'.format(name), None)
if interface_resolver:
break
resolver = interface_resolver
# Only if is not decorated with classmethod
if resolver and not getattr(resolver, '__self__', True):
return resolver.__func__
return partial(cls.default_resolver, name)
@classmethod @classmethod
def get_field_type(self, map, type): def get_field_type(self, map, type):
if isinstance(type, List): if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type)) return GraphQLList(self.get_field_type(map, type.of_type))
if isinstance(type, NonNull): if isinstance(type, NonNull):
return GraphQLNonNull(self.get_field_type(map, type.of_type)) return GraphQLNonNull(self.get_field_type(map, type.of_type))
if inspect.isfunction(type):
type = type()
return map.get(type._meta.name) return map.get(type._meta.name)

View File

@ -6,9 +6,10 @@ from .inputfield import InputField
def merge_fields_in_attrs(bases, attrs): def merge_fields_in_attrs(bases, attrs):
from ..types.abstracttype import AbstractType from ..types import AbstractType, Interface
inherited_bases = (AbstractType, Interface)
for base in bases: for base in bases:
if base == AbstractType or not issubclass(base, AbstractType): if base in inherited_bases or not issubclass(base, inherited_bases):
continue continue
for name, field in base._meta.fields.items(): for name, field in base._meta.fields.items():
if name in attrs: if name in attrs: