Improved base implementation

This commit is contained in:
Syrus Akbary 2016-08-13 17:37:57 -07:00
parent 0ffdd8d9ab
commit b19bca7f3b
8 changed files with 112 additions and 22 deletions

View File

@ -32,11 +32,13 @@ class InterfaceMeta(AbstractTypeMeta):
class Interface(six.with_metaclass(InterfaceMeta)):
resolve_type = None
@classmethod
def resolve_type(cls, root, args, info):
return type(root)
def __init__(self, *args, **kwargs):
raise Exception("An Interface cannot be intitialized")
# @classmethod
# def implements(cls, objecttype):
# pass
@classmethod
def implements(cls, objecttype):
pass

View File

@ -21,4 +21,4 @@ class JSONString(Scalar):
@staticmethod
def parse_value(value):
return json.dumps(value)
return json.loads(value)

View File

@ -1,3 +1,4 @@
from collections import OrderedDict
import six
from ..utils.is_base_type import is_base_type
@ -5,6 +6,7 @@ from .options import Options
from .abstracttype import AbstractTypeMeta
from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
from .interface import Interface
class ObjectTypeMeta(AbstractTypeMeta):
@ -23,10 +25,22 @@ class ObjectTypeMeta(AbstractTypeMeta):
)
attrs = merge_fields_in_attrs(bases, attrs)
options.fields = get_fields_in_type(ObjectType, attrs)
yank_fields_from_attrs(attrs, options.fields)
options.local_fields = get_fields_in_type(ObjectType, attrs)
yank_fields_from_attrs(attrs, options.local_fields)
options.interface_fields = OrderedDict()
for interface in options.interfaces:
assert issubclass(interface, Interface), (
'All interfaces of {} must be a subclass of Interface. Received "{}".'
).format(name, interface)
options.interface_fields.update(interface._meta.fields)
options.fields = OrderedDict(options.interface_fields)
options.fields.update(options.local_fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options))
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces:
interface.implements(cls)
return cls
def __str__(cls):
return cls._meta.name

View File

@ -62,6 +62,8 @@ class Schema(GraphQLSchema):
return self.get_graphql_type(self._subscription)
def get_graphql_type(self, _type):
if not _type:
return _type
if is_type(_type):
return _type
if is_graphene_type(_type):

View File

@ -71,6 +71,18 @@ def test_generate_interface_inherit_abstracttype():
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
def test_generate_interface_inherit_interface():
class MyBaseInterface(Interface):
field1 = MyScalar()
class MyInterface(MyBaseInterface):
field2 = MyScalar()
assert MyInterface._meta.name == 'MyInterface'
assert MyInterface._meta.fields.keys() == ['field1', 'field2']
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
def test_generate_interface_inherit_abstracttype_reversed():
class MyAbstractType(AbstractType):
field1 = MyScalar()

View File

@ -4,9 +4,10 @@ from ..field import Field
from ..objecttype import ObjectType
from ..unmountedtype import UnmountedType
from ..abstracttype import AbstractType
from ..interface import Interface
class MyType(object):
class MyType(Interface):
pass
@ -15,6 +16,17 @@ class Container(ObjectType):
field2 = Field(MyType)
class MyInterface(Interface):
ifield = Field(MyType)
class ContainerWithInterface(ObjectType):
class Meta:
interfaces = (MyInterface, )
field1 = Field(MyType)
field2 = Field(MyType)
class MyScalar(UnmountedType):
def get_type(self):
return MyType
@ -94,6 +106,10 @@ def test_parent_container_get_fields():
assert list(Container._meta.fields.keys()) == ['field1', 'field2']
def test_parent_container_interface_get_fields():
assert list(ContainerWithInterface._meta.fields.keys()) == ['ifield', 'field1', 'field2']
def test_objecttype_as_container_only_args():
container = Container("1", "2")
assert container.field1 == "1"

View File

@ -1,4 +1,5 @@
import inspect
from functools import partial
from collections import OrderedDict
from graphql.type.typemap import GraphQLTypeMap
@ -14,6 +15,8 @@ from .scalars import Scalar, String, Boolean, Int, Float, ID
from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument
from graphql.type import GraphQLEnumValue
from ..utils.str_converters import to_camel_case
def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)):
@ -22,13 +25,26 @@ def is_graphene_type(_type):
return True
def resolve_type(resolve_type_func, map, root, args, info):
_type = resolve_type_func(root, args, info)
# assert inspect.isclass(_type) and issubclass(_type, ObjectType), (
# 'Received incompatible type "{}".'.format(_type)
# )
if inspect.isclass(_type) and issubclass(_type, ObjectType):
graphql_type = map.get(_type._meta.name)
assert graphql_type and graphql_type.graphene_type == _type
return graphql_type
return _type
class TypeMap(GraphQLTypeMap):
@classmethod
def reducer(cls, map, type):
if not type:
return map
if inspect.isfunction(type):
type = type()
if is_graphene_type(type):
return cls.graphene_reducer(map, type)
return super(TypeMap, cls).reducer(map, type)
@ -112,10 +128,11 @@ class TypeMap(GraphQLTypeMap):
)
interfaces = []
for i in type._meta.interfaces:
map = cls.construct_interface(map, i)
map = cls.reducer(map, i)
interfaces.append(map[i._meta.name])
map[type._meta.name].interfaces = interfaces
map[type._meta.name]._provided_interfaces = interfaces
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
# cls.reducer(map, map[type._meta.name])
return map
@classmethod
@ -126,9 +143,10 @@ class TypeMap(GraphQLTypeMap):
name=type._meta.name,
description=type._meta.description,
fields=None,
resolve_type=type.resolve_type,
resolve_type=partial(resolve_type, type.resolve_type, map),
)
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
# cls.reducer(map, map[type._meta.name])
return map
@classmethod
@ -159,6 +177,14 @@ class TypeMap(GraphQLTypeMap):
map[type._meta.name].types = types
return map
@classmethod
def process_field_name(cls, name):
return to_camel_case(name)
@classmethod
def default_resolver(cls, attname, root, *_):
return getattr(root, attname, None)
@classmethod
def construct_fields_for_type(cls, map, type, is_input_type=False):
fields = OrderedDict()
@ -181,25 +207,42 @@ class TypeMap(GraphQLTypeMap):
description=arg.description,
default_value=arg.default_value
)
resolver = field.resolver
resolver_type = getattr(type, 'resolve_{}'.format(name), None)
if resolver_type:
resolver = resolver_type.__func__
_field = GraphQLField(
field_type,
args=args,
resolver=resolver,
resolver=field.resolver or cls.get_resolver_for_type(type, name),
deprecation_reason=field.deprecation_reason,
description=field.description
)
fields[name] = _field
processed_name = cls.process_field_name(name)
fields[processed_name] = _field
return fields
@classmethod
def get_resolver_for_type(cls, type, name):
if not issubclass(type, ObjectType):
return
resolver = getattr(type, 'resolve_{}'.format(name), None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in type._meta.interfaces:
interface_resolver = getattr(interface, 'resolve_{}'.format(name), None)
if interface_resolver:
break
resolver = interface_resolver
# Only if is not decorated with classmethod
if resolver and not getattr(resolver, '__self__', True):
return resolver.__func__
return partial(cls.default_resolver, name)
@classmethod
def get_field_type(self, map, type):
if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type))
if isinstance(type, NonNull):
return GraphQLNonNull(self.get_field_type(map, type.of_type))
if inspect.isfunction(type):
type = type()
return map.get(type._meta.name)

View File

@ -6,9 +6,10 @@ from .inputfield import InputField
def merge_fields_in_attrs(bases, attrs):
from ..types.abstracttype import AbstractType
from ..types import AbstractType, Interface
inherited_bases = (AbstractType, Interface)
for base in bases:
if base == AbstractType or not issubclass(base, AbstractType):
if base in inherited_bases or not issubclass(base, inherited_bases):
continue
for name, field in base._meta.fields.items():
if name in attrs: