mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-26 11:33:55 +03:00
Improved base implementation
This commit is contained in:
parent
0ffdd8d9ab
commit
b19bca7f3b
|
@ -32,11 +32,13 @@ class InterfaceMeta(AbstractTypeMeta):
|
||||||
|
|
||||||
|
|
||||||
class Interface(six.with_metaclass(InterfaceMeta)):
|
class Interface(six.with_metaclass(InterfaceMeta)):
|
||||||
resolve_type = None
|
@classmethod
|
||||||
|
def resolve_type(cls, root, args, info):
|
||||||
|
return type(root)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
raise Exception("An Interface cannot be intitialized")
|
raise Exception("An Interface cannot be intitialized")
|
||||||
|
|
||||||
# @classmethod
|
@classmethod
|
||||||
# def implements(cls, objecttype):
|
def implements(cls, objecttype):
|
||||||
# pass
|
pass
|
||||||
|
|
|
@ -21,4 +21,4 @@ class JSONString(Scalar):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_value(value):
|
def parse_value(value):
|
||||||
return json.dumps(value)
|
return json.loads(value)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections import OrderedDict
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from ..utils.is_base_type import is_base_type
|
from ..utils.is_base_type import is_base_type
|
||||||
|
@ -5,6 +6,7 @@ from .options import Options
|
||||||
|
|
||||||
from .abstracttype import AbstractTypeMeta
|
from .abstracttype import AbstractTypeMeta
|
||||||
from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
|
from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
|
||||||
|
from .interface import Interface
|
||||||
|
|
||||||
|
|
||||||
class ObjectTypeMeta(AbstractTypeMeta):
|
class ObjectTypeMeta(AbstractTypeMeta):
|
||||||
|
@ -23,10 +25,22 @@ class ObjectTypeMeta(AbstractTypeMeta):
|
||||||
)
|
)
|
||||||
|
|
||||||
attrs = merge_fields_in_attrs(bases, attrs)
|
attrs = merge_fields_in_attrs(bases, attrs)
|
||||||
options.fields = get_fields_in_type(ObjectType, attrs)
|
options.local_fields = get_fields_in_type(ObjectType, attrs)
|
||||||
yank_fields_from_attrs(attrs, options.fields)
|
yank_fields_from_attrs(attrs, options.local_fields)
|
||||||
|
options.interface_fields = OrderedDict()
|
||||||
|
for interface in options.interfaces:
|
||||||
|
assert issubclass(interface, Interface), (
|
||||||
|
'All interfaces of {} must be a subclass of Interface. Received "{}".'
|
||||||
|
).format(name, interface)
|
||||||
|
options.interface_fields.update(interface._meta.fields)
|
||||||
|
options.fields = OrderedDict(options.interface_fields)
|
||||||
|
options.fields.update(options.local_fields)
|
||||||
|
|
||||||
return type.__new__(cls, name, bases, dict(attrs, _meta=options))
|
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
|
||||||
|
for interface in options.interfaces:
|
||||||
|
interface.implements(cls)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
def __str__(cls):
|
def __str__(cls):
|
||||||
return cls._meta.name
|
return cls._meta.name
|
||||||
|
|
|
@ -62,6 +62,8 @@ class Schema(GraphQLSchema):
|
||||||
return self.get_graphql_type(self._subscription)
|
return self.get_graphql_type(self._subscription)
|
||||||
|
|
||||||
def get_graphql_type(self, _type):
|
def get_graphql_type(self, _type):
|
||||||
|
if not _type:
|
||||||
|
return _type
|
||||||
if is_type(_type):
|
if is_type(_type):
|
||||||
return _type
|
return _type
|
||||||
if is_graphene_type(_type):
|
if is_graphene_type(_type):
|
||||||
|
|
|
@ -71,6 +71,18 @@ def test_generate_interface_inherit_abstracttype():
|
||||||
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
|
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_interface_inherit_interface():
|
||||||
|
class MyBaseInterface(Interface):
|
||||||
|
field1 = MyScalar()
|
||||||
|
|
||||||
|
class MyInterface(MyBaseInterface):
|
||||||
|
field2 = MyScalar()
|
||||||
|
|
||||||
|
assert MyInterface._meta.name == 'MyInterface'
|
||||||
|
assert MyInterface._meta.fields.keys() == ['field1', 'field2']
|
||||||
|
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
|
||||||
|
|
||||||
|
|
||||||
def test_generate_interface_inherit_abstracttype_reversed():
|
def test_generate_interface_inherit_abstracttype_reversed():
|
||||||
class MyAbstractType(AbstractType):
|
class MyAbstractType(AbstractType):
|
||||||
field1 = MyScalar()
|
field1 = MyScalar()
|
||||||
|
|
|
@ -4,9 +4,10 @@ from ..field import Field
|
||||||
from ..objecttype import ObjectType
|
from ..objecttype import ObjectType
|
||||||
from ..unmountedtype import UnmountedType
|
from ..unmountedtype import UnmountedType
|
||||||
from ..abstracttype import AbstractType
|
from ..abstracttype import AbstractType
|
||||||
|
from ..interface import Interface
|
||||||
|
|
||||||
|
|
||||||
class MyType(object):
|
class MyType(Interface):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,6 +16,17 @@ class Container(ObjectType):
|
||||||
field2 = Field(MyType)
|
field2 = Field(MyType)
|
||||||
|
|
||||||
|
|
||||||
|
class MyInterface(Interface):
|
||||||
|
ifield = Field(MyType)
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerWithInterface(ObjectType):
|
||||||
|
class Meta:
|
||||||
|
interfaces = (MyInterface, )
|
||||||
|
field1 = Field(MyType)
|
||||||
|
field2 = Field(MyType)
|
||||||
|
|
||||||
|
|
||||||
class MyScalar(UnmountedType):
|
class MyScalar(UnmountedType):
|
||||||
def get_type(self):
|
def get_type(self):
|
||||||
return MyType
|
return MyType
|
||||||
|
@ -94,6 +106,10 @@ def test_parent_container_get_fields():
|
||||||
assert list(Container._meta.fields.keys()) == ['field1', 'field2']
|
assert list(Container._meta.fields.keys()) == ['field1', 'field2']
|
||||||
|
|
||||||
|
|
||||||
|
def test_parent_container_interface_get_fields():
|
||||||
|
assert list(ContainerWithInterface._meta.fields.keys()) == ['ifield', 'field1', 'field2']
|
||||||
|
|
||||||
|
|
||||||
def test_objecttype_as_container_only_args():
|
def test_objecttype_as_container_only_args():
|
||||||
container = Container("1", "2")
|
container = Container("1", "2")
|
||||||
assert container.field1 == "1"
|
assert container.field1 == "1"
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
from functools import partial
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from graphql.type.typemap import GraphQLTypeMap
|
from graphql.type.typemap import GraphQLTypeMap
|
||||||
|
@ -14,6 +15,8 @@ from .scalars import Scalar, String, Boolean, Int, Float, ID
|
||||||
from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument
|
from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument
|
||||||
from graphql.type import GraphQLEnumValue
|
from graphql.type import GraphQLEnumValue
|
||||||
|
|
||||||
|
from ..utils.str_converters import to_camel_case
|
||||||
|
|
||||||
|
|
||||||
def is_graphene_type(_type):
|
def is_graphene_type(_type):
|
||||||
if isinstance(_type, (List, NonNull)):
|
if isinstance(_type, (List, NonNull)):
|
||||||
|
@ -22,13 +25,26 @@ def is_graphene_type(_type):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_type(resolve_type_func, map, root, args, info):
|
||||||
|
_type = resolve_type_func(root, args, info)
|
||||||
|
# assert inspect.isclass(_type) and issubclass(_type, ObjectType), (
|
||||||
|
# 'Received incompatible type "{}".'.format(_type)
|
||||||
|
# )
|
||||||
|
if inspect.isclass(_type) and issubclass(_type, ObjectType):
|
||||||
|
graphql_type = map.get(_type._meta.name)
|
||||||
|
assert graphql_type and graphql_type.graphene_type == _type
|
||||||
|
return graphql_type
|
||||||
|
return _type
|
||||||
|
|
||||||
|
|
||||||
class TypeMap(GraphQLTypeMap):
|
class TypeMap(GraphQLTypeMap):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reducer(cls, map, type):
|
def reducer(cls, map, type):
|
||||||
if not type:
|
if not type:
|
||||||
return map
|
return map
|
||||||
|
if inspect.isfunction(type):
|
||||||
|
type = type()
|
||||||
if is_graphene_type(type):
|
if is_graphene_type(type):
|
||||||
return cls.graphene_reducer(map, type)
|
return cls.graphene_reducer(map, type)
|
||||||
return super(TypeMap, cls).reducer(map, type)
|
return super(TypeMap, cls).reducer(map, type)
|
||||||
|
@ -112,10 +128,11 @@ class TypeMap(GraphQLTypeMap):
|
||||||
)
|
)
|
||||||
interfaces = []
|
interfaces = []
|
||||||
for i in type._meta.interfaces:
|
for i in type._meta.interfaces:
|
||||||
map = cls.construct_interface(map, i)
|
map = cls.reducer(map, i)
|
||||||
interfaces.append(map[i._meta.name])
|
interfaces.append(map[i._meta.name])
|
||||||
map[type._meta.name].interfaces = interfaces
|
map[type._meta.name]._provided_interfaces = interfaces
|
||||||
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
|
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
|
||||||
|
# cls.reducer(map, map[type._meta.name])
|
||||||
return map
|
return map
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -126,9 +143,10 @@ class TypeMap(GraphQLTypeMap):
|
||||||
name=type._meta.name,
|
name=type._meta.name,
|
||||||
description=type._meta.description,
|
description=type._meta.description,
|
||||||
fields=None,
|
fields=None,
|
||||||
resolve_type=type.resolve_type,
|
resolve_type=partial(resolve_type, type.resolve_type, map),
|
||||||
)
|
)
|
||||||
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
|
map[type._meta.name]._fields = cls.construct_fields_for_type(map, type)
|
||||||
|
# cls.reducer(map, map[type._meta.name])
|
||||||
return map
|
return map
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -159,6 +177,14 @@ class TypeMap(GraphQLTypeMap):
|
||||||
map[type._meta.name].types = types
|
map[type._meta.name].types = types
|
||||||
return map
|
return map
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def process_field_name(cls, name):
|
||||||
|
return to_camel_case(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_resolver(cls, attname, root, *_):
|
||||||
|
return getattr(root, attname, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def construct_fields_for_type(cls, map, type, is_input_type=False):
|
def construct_fields_for_type(cls, map, type, is_input_type=False):
|
||||||
fields = OrderedDict()
|
fields = OrderedDict()
|
||||||
|
@ -181,25 +207,42 @@ class TypeMap(GraphQLTypeMap):
|
||||||
description=arg.description,
|
description=arg.description,
|
||||||
default_value=arg.default_value
|
default_value=arg.default_value
|
||||||
)
|
)
|
||||||
resolver = field.resolver
|
|
||||||
resolver_type = getattr(type, 'resolve_{}'.format(name), None)
|
|
||||||
if resolver_type:
|
|
||||||
resolver = resolver_type.__func__
|
|
||||||
|
|
||||||
_field = GraphQLField(
|
_field = GraphQLField(
|
||||||
field_type,
|
field_type,
|
||||||
args=args,
|
args=args,
|
||||||
resolver=resolver,
|
resolver=field.resolver or cls.get_resolver_for_type(type, name),
|
||||||
deprecation_reason=field.deprecation_reason,
|
deprecation_reason=field.deprecation_reason,
|
||||||
description=field.description
|
description=field.description
|
||||||
)
|
)
|
||||||
fields[name] = _field
|
processed_name = cls.process_field_name(name)
|
||||||
|
fields[processed_name] = _field
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_resolver_for_type(cls, type, name):
|
||||||
|
if not issubclass(type, ObjectType):
|
||||||
|
return
|
||||||
|
resolver = getattr(type, 'resolve_{}'.format(name), None)
|
||||||
|
if not resolver:
|
||||||
|
# If we don't find the resolver in the ObjectType class, then try to
|
||||||
|
# find it in each of the interfaces
|
||||||
|
interface_resolver = None
|
||||||
|
for interface in type._meta.interfaces:
|
||||||
|
interface_resolver = getattr(interface, 'resolve_{}'.format(name), None)
|
||||||
|
if interface_resolver:
|
||||||
|
break
|
||||||
|
resolver = interface_resolver
|
||||||
|
# Only if is not decorated with classmethod
|
||||||
|
if resolver and not getattr(resolver, '__self__', True):
|
||||||
|
return resolver.__func__
|
||||||
|
return partial(cls.default_resolver, name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_field_type(self, map, type):
|
def get_field_type(self, map, type):
|
||||||
if isinstance(type, List):
|
if isinstance(type, List):
|
||||||
return GraphQLList(self.get_field_type(map, type.of_type))
|
return GraphQLList(self.get_field_type(map, type.of_type))
|
||||||
if isinstance(type, NonNull):
|
if isinstance(type, NonNull):
|
||||||
return GraphQLNonNull(self.get_field_type(map, type.of_type))
|
return GraphQLNonNull(self.get_field_type(map, type.of_type))
|
||||||
|
if inspect.isfunction(type):
|
||||||
|
type = type()
|
||||||
return map.get(type._meta.name)
|
return map.get(type._meta.name)
|
||||||
|
|
|
@ -6,9 +6,10 @@ from .inputfield import InputField
|
||||||
|
|
||||||
|
|
||||||
def merge_fields_in_attrs(bases, attrs):
|
def merge_fields_in_attrs(bases, attrs):
|
||||||
from ..types.abstracttype import AbstractType
|
from ..types import AbstractType, Interface
|
||||||
|
inherited_bases = (AbstractType, Interface)
|
||||||
for base in bases:
|
for base in bases:
|
||||||
if base == AbstractType or not issubclass(base, AbstractType):
|
if base in inherited_bases or not issubclass(base, inherited_bases):
|
||||||
continue
|
continue
|
||||||
for name, field in base._meta.fields.items():
|
for name, field in base._meta.fields.items():
|
||||||
if name in attrs:
|
if name in attrs:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user