mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-10-31 16:07:27 +03:00 
			
		
		
		
	Improved base implementation
This commit is contained in:
		
							parent
							
								
									0ffdd8d9ab
								
							
						
					
					
						commit
						b19bca7f3b
					
				|  | @ -32,11 +32,13 @@ class InterfaceMeta(AbstractTypeMeta): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Interface(six.with_metaclass(InterfaceMeta)): | class Interface(six.with_metaclass(InterfaceMeta)): | ||||||
|     resolve_type = None |     @classmethod | ||||||
|  |     def resolve_type(cls, root, args, info): | ||||||
|  |         return type(root) | ||||||
| 
 | 
 | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         raise Exception("An Interface cannot be intitialized") |         raise Exception("An Interface cannot be intitialized") | ||||||
| 
 | 
 | ||||||
|     # @classmethod |     @classmethod | ||||||
|     # def implements(cls, objecttype): |     def implements(cls, objecttype): | ||||||
|     #     pass |         pass | ||||||
|  |  | ||||||
|  | @ -21,4 +21,4 @@ class JSONString(Scalar): | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def parse_value(value): |     def parse_value(value): | ||||||
|         return json.dumps(value) |         return json.loads(value) | ||||||
|  |  | ||||||
|  | @ -1,3 +1,4 @@ | ||||||
|  | from collections import OrderedDict | ||||||
| import six | import six | ||||||
| 
 | 
 | ||||||
| from ..utils.is_base_type import is_base_type | from ..utils.is_base_type import is_base_type | ||||||
|  | @ -5,6 +6,7 @@ from .options import Options | ||||||
| 
 | 
 | ||||||
| from .abstracttype import AbstractTypeMeta | from .abstracttype import AbstractTypeMeta | ||||||
| from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs | from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs | ||||||
|  | from .interface import Interface | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ObjectTypeMeta(AbstractTypeMeta): | class ObjectTypeMeta(AbstractTypeMeta): | ||||||
|  | @ -23,10 +25,22 @@ class ObjectTypeMeta(AbstractTypeMeta): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         attrs = merge_fields_in_attrs(bases, attrs) |         attrs = merge_fields_in_attrs(bases, attrs) | ||||||
|         options.fields = get_fields_in_type(ObjectType, attrs) |         options.local_fields = get_fields_in_type(ObjectType, attrs) | ||||||
|         yank_fields_from_attrs(attrs, options.fields) |         yank_fields_from_attrs(attrs, options.local_fields) | ||||||
|  |         options.interface_fields = OrderedDict() | ||||||
|  |         for interface in options.interfaces: | ||||||
|  |             assert issubclass(interface, Interface), ( | ||||||
|  |                 'All interfaces of {} must be a subclass of Interface. Received "{}".' | ||||||
|  |             ).format(name, interface) | ||||||
|  |             options.interface_fields.update(interface._meta.fields) | ||||||
|  |         options.fields = OrderedDict(options.interface_fields) | ||||||
|  |         options.fields.update(options.local_fields) | ||||||
| 
 | 
 | ||||||
|         return type.__new__(cls, name, bases, dict(attrs, _meta=options)) |         cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) | ||||||
|  |         for interface in options.interfaces: | ||||||
|  |             interface.implements(cls) | ||||||
|  | 
 | ||||||
|  |         return cls | ||||||
| 
 | 
 | ||||||
|     def __str__(cls): |     def __str__(cls): | ||||||
|         return cls._meta.name |         return cls._meta.name | ||||||
|  |  | ||||||
|  | @ -62,6 +62,8 @@ class Schema(GraphQLSchema): | ||||||
|         return self.get_graphql_type(self._subscription) |         return self.get_graphql_type(self._subscription) | ||||||
| 
 | 
 | ||||||
|     def get_graphql_type(self, _type): |     def get_graphql_type(self, _type): | ||||||
|  |         if not _type: | ||||||
|  |             return _type | ||||||
|         if is_type(_type): |         if is_type(_type): | ||||||
|             return _type |             return _type | ||||||
|         if is_graphene_type(_type): |         if is_graphene_type(_type): | ||||||
|  |  | ||||||
|  | @ -71,6 +71,18 @@ def test_generate_interface_inherit_abstracttype(): | ||||||
|     assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] |     assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_generate_interface_inherit_interface(): | ||||||
|  |     class MyBaseInterface(Interface): | ||||||
|  |         field1 = MyScalar() | ||||||
|  | 
 | ||||||
|  |     class MyInterface(MyBaseInterface): | ||||||
|  |         field2 = MyScalar() | ||||||
|  | 
 | ||||||
|  |     assert MyInterface._meta.name == 'MyInterface' | ||||||
|  |     assert MyInterface._meta.fields.keys() == ['field1', 'field2'] | ||||||
|  |     assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def test_generate_interface_inherit_abstracttype_reversed(): | def test_generate_interface_inherit_abstracttype_reversed(): | ||||||
|     class MyAbstractType(AbstractType): |     class MyAbstractType(AbstractType): | ||||||
|         field1 = MyScalar() |         field1 = MyScalar() | ||||||
|  |  | ||||||
|  | @ -4,9 +4,10 @@ from ..field import Field | ||||||
| from ..objecttype import ObjectType | from ..objecttype import ObjectType | ||||||
| from ..unmountedtype import UnmountedType | from ..unmountedtype import UnmountedType | ||||||
| from ..abstracttype import AbstractType | from ..abstracttype import AbstractType | ||||||
|  | from ..interface import Interface | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MyType(object): | class MyType(Interface): | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -15,6 +16,17 @@ class Container(ObjectType): | ||||||
|     field2 = Field(MyType) |     field2 = Field(MyType) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class MyInterface(Interface): | ||||||
|  |     ifield = Field(MyType) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ContainerWithInterface(ObjectType): | ||||||
|  |     class Meta: | ||||||
|  |         interfaces = (MyInterface, ) | ||||||
|  |     field1 = Field(MyType) | ||||||
|  |     field2 = Field(MyType) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class MyScalar(UnmountedType): | class MyScalar(UnmountedType): | ||||||
|     def get_type(self): |     def get_type(self): | ||||||
|         return MyType |         return MyType | ||||||
|  | @ -94,6 +106,10 @@ def test_parent_container_get_fields(): | ||||||
|     assert list(Container._meta.fields.keys()) == ['field1', 'field2'] |     assert list(Container._meta.fields.keys()) == ['field1', 'field2'] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_parent_container_interface_get_fields(): | ||||||
|  |     assert list(ContainerWithInterface._meta.fields.keys()) == ['ifield', 'field1', 'field2'] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def test_objecttype_as_container_only_args(): | def test_objecttype_as_container_only_args(): | ||||||
|     container = Container("1", "2") |     container = Container("1", "2") | ||||||
|     assert container.field1 == "1" |     assert container.field1 == "1" | ||||||
|  |  | ||||||
|  | @ -1,4 +1,5 @@ | ||||||
| import inspect | import inspect | ||||||
|  | from functools import partial | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| 
 | 
 | ||||||
| from graphql.type.typemap import GraphQLTypeMap | from graphql.type.typemap import GraphQLTypeMap | ||||||
|  | @ -14,6 +15,8 @@ from .scalars import Scalar, String, Boolean, Int, Float, ID | ||||||
| from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument | from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument | ||||||
| from graphql.type import GraphQLEnumValue | from graphql.type import GraphQLEnumValue | ||||||
| 
 | 
 | ||||||
|  | from ..utils.str_converters import to_camel_case | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def is_graphene_type(_type): | def is_graphene_type(_type): | ||||||
|     if isinstance(_type, (List, NonNull)): |     if isinstance(_type, (List, NonNull)): | ||||||
|  | @ -22,13 +25,26 @@ def is_graphene_type(_type): | ||||||
|         return True |         return True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def resolve_type(resolve_type_func, map, root, args, info): | ||||||
|  |     _type = resolve_type_func(root, args, info) | ||||||
|  |     # assert inspect.isclass(_type) and issubclass(_type, ObjectType), ( | ||||||
|  |     #     'Received incompatible type "{}".'.format(_type) | ||||||
|  |     # ) | ||||||
|  |     if inspect.isclass(_type) and issubclass(_type, ObjectType): | ||||||
|  |         graphql_type = map.get(_type._meta.name) | ||||||
|  |         assert graphql_type and graphql_type.graphene_type == _type | ||||||
|  |         return graphql_type | ||||||
|  |     return _type | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class TypeMap(GraphQLTypeMap): | class TypeMap(GraphQLTypeMap): | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def reducer(cls, map, type): |     def reducer(cls, map, type): | ||||||
|         if not type: |         if not type: | ||||||
|             return map |             return map | ||||||
| 
 |         if inspect.isfunction(type): | ||||||
|  |             type = type() | ||||||
|         if is_graphene_type(type): |         if is_graphene_type(type): | ||||||
|             return cls.graphene_reducer(map, type) |             return cls.graphene_reducer(map, type) | ||||||
|         return super(TypeMap, cls).reducer(map, type) |         return super(TypeMap, cls).reducer(map, type) | ||||||
|  | @ -112,10 +128,11 @@ class TypeMap(GraphQLTypeMap): | ||||||
|         ) |         ) | ||||||
|         interfaces = [] |         interfaces = [] | ||||||
|         for i in type._meta.interfaces: |         for i in type._meta.interfaces: | ||||||
|             map = cls.construct_interface(map, i) |             map = cls.reducer(map, i) | ||||||
|             interfaces.append(map[i._meta.name]) |             interfaces.append(map[i._meta.name]) | ||||||
|         map[type._meta.name].interfaces = interfaces |         map[type._meta.name]._provided_interfaces = interfaces | ||||||
|         map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) |         map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) | ||||||
|  |         # cls.reducer(map, map[type._meta.name]) | ||||||
|         return map |         return map | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|  | @ -126,9 +143,10 @@ class TypeMap(GraphQLTypeMap): | ||||||
|             name=type._meta.name, |             name=type._meta.name, | ||||||
|             description=type._meta.description, |             description=type._meta.description, | ||||||
|             fields=None, |             fields=None, | ||||||
|             resolve_type=type.resolve_type, |             resolve_type=partial(resolve_type, type.resolve_type, map), | ||||||
|         ) |         ) | ||||||
|         map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) |         map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) | ||||||
|  |         # cls.reducer(map, map[type._meta.name]) | ||||||
|         return map |         return map | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|  | @ -159,6 +177,14 @@ class TypeMap(GraphQLTypeMap): | ||||||
|         map[type._meta.name].types = types |         map[type._meta.name].types = types | ||||||
|         return map |         return map | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def process_field_name(cls, name): | ||||||
|  |         return to_camel_case(name) | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def default_resolver(cls, attname, root, *_): | ||||||
|  |         return getattr(root, attname, None) | ||||||
|  | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def construct_fields_for_type(cls, map, type, is_input_type=False): |     def construct_fields_for_type(cls, map, type, is_input_type=False): | ||||||
|         fields = OrderedDict() |         fields = OrderedDict() | ||||||
|  | @ -181,25 +207,42 @@ class TypeMap(GraphQLTypeMap): | ||||||
|                         description=arg.description, |                         description=arg.description, | ||||||
|                         default_value=arg.default_value |                         default_value=arg.default_value | ||||||
|                     ) |                     ) | ||||||
|                 resolver = field.resolver |  | ||||||
|                 resolver_type = getattr(type, 'resolve_{}'.format(name), None) |  | ||||||
|                 if resolver_type: |  | ||||||
|                     resolver = resolver_type.__func__ |  | ||||||
| 
 |  | ||||||
|                 _field = GraphQLField( |                 _field = GraphQLField( | ||||||
|                     field_type, |                     field_type, | ||||||
|                     args=args, |                     args=args, | ||||||
|                     resolver=resolver, |                     resolver=field.resolver or cls.get_resolver_for_type(type, name), | ||||||
|                     deprecation_reason=field.deprecation_reason, |                     deprecation_reason=field.deprecation_reason, | ||||||
|                     description=field.description |                     description=field.description | ||||||
|                 ) |                 ) | ||||||
|             fields[name] = _field |             processed_name = cls.process_field_name(name) | ||||||
|  |             fields[processed_name] = _field | ||||||
|         return fields |         return fields | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def get_resolver_for_type(cls, type, name): | ||||||
|  |         if not issubclass(type, ObjectType): | ||||||
|  |             return | ||||||
|  |         resolver = getattr(type, 'resolve_{}'.format(name), None) | ||||||
|  |         if not resolver: | ||||||
|  |             # If we don't find the resolver in the ObjectType class, then try to | ||||||
|  |             # find it in each of the interfaces | ||||||
|  |             interface_resolver = None | ||||||
|  |             for interface in type._meta.interfaces: | ||||||
|  |                 interface_resolver = getattr(interface, 'resolve_{}'.format(name), None) | ||||||
|  |                 if interface_resolver: | ||||||
|  |                     break | ||||||
|  |             resolver = interface_resolver | ||||||
|  |         # Only if is not decorated with classmethod | ||||||
|  |         if resolver and not getattr(resolver, '__self__', True): | ||||||
|  |             return resolver.__func__ | ||||||
|  |         return partial(cls.default_resolver, name) | ||||||
|  | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_field_type(self, map, type): |     def get_field_type(self, map, type): | ||||||
|         if isinstance(type, List): |         if isinstance(type, List): | ||||||
|             return GraphQLList(self.get_field_type(map, type.of_type)) |             return GraphQLList(self.get_field_type(map, type.of_type)) | ||||||
|         if isinstance(type, NonNull): |         if isinstance(type, NonNull): | ||||||
|             return GraphQLNonNull(self.get_field_type(map, type.of_type)) |             return GraphQLNonNull(self.get_field_type(map, type.of_type)) | ||||||
|  |         if inspect.isfunction(type): | ||||||
|  |             type = type() | ||||||
|         return map.get(type._meta.name) |         return map.get(type._meta.name) | ||||||
|  |  | ||||||
|  | @ -6,9 +6,10 @@ from .inputfield import InputField | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def merge_fields_in_attrs(bases, attrs): | def merge_fields_in_attrs(bases, attrs): | ||||||
|     from ..types.abstracttype import AbstractType |     from ..types import AbstractType, Interface | ||||||
|  |     inherited_bases = (AbstractType, Interface) | ||||||
|     for base in bases: |     for base in bases: | ||||||
|         if base == AbstractType or not issubclass(base, AbstractType): |         if base in inherited_bases or not issubclass(base, inherited_bases): | ||||||
|             continue |             continue | ||||||
|         for name, field in base._meta.fields.items(): |         for name, field in base._meta.fields.items(): | ||||||
|             if name in attrs: |             if name in attrs: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user