mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-10-31 07:57:26 +03:00 
			
		
		
		
	Improved base implementation
This commit is contained in:
		
							parent
							
								
									0ffdd8d9ab
								
							
						
					
					
						commit
						b19bca7f3b
					
				|  | @ -32,11 +32,13 @@ class InterfaceMeta(AbstractTypeMeta): | |||
| 
 | ||||
| 
 | ||||
| class Interface(six.with_metaclass(InterfaceMeta)): | ||||
|     resolve_type = None | ||||
|     @classmethod | ||||
|     def resolve_type(cls, root, args, info): | ||||
|         return type(root) | ||||
| 
 | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         raise Exception("An Interface cannot be intitialized") | ||||
| 
 | ||||
|     # @classmethod | ||||
|     # def implements(cls, objecttype): | ||||
|     #     pass | ||||
|     @classmethod | ||||
|     def implements(cls, objecttype): | ||||
|         pass | ||||
|  |  | |||
|  | @ -21,4 +21,4 @@ class JSONString(Scalar): | |||
| 
 | ||||
|     @staticmethod | ||||
|     def parse_value(value): | ||||
|         return json.dumps(value) | ||||
|         return json.loads(value) | ||||
|  |  | |||
|  | @ -1,3 +1,4 @@ | |||
| from collections import OrderedDict | ||||
| import six | ||||
| 
 | ||||
| from ..utils.is_base_type import is_base_type | ||||
|  | @ -5,6 +6,7 @@ from .options import Options | |||
| 
 | ||||
| from .abstracttype import AbstractTypeMeta | ||||
| from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs | ||||
| from .interface import Interface | ||||
| 
 | ||||
| 
 | ||||
| class ObjectTypeMeta(AbstractTypeMeta): | ||||
|  | @ -23,10 +25,22 @@ class ObjectTypeMeta(AbstractTypeMeta): | |||
|         ) | ||||
| 
 | ||||
|         attrs = merge_fields_in_attrs(bases, attrs) | ||||
|         options.fields = get_fields_in_type(ObjectType, attrs) | ||||
|         yank_fields_from_attrs(attrs, options.fields) | ||||
|         options.local_fields = get_fields_in_type(ObjectType, attrs) | ||||
|         yank_fields_from_attrs(attrs, options.local_fields) | ||||
|         options.interface_fields = OrderedDict() | ||||
|         for interface in options.interfaces: | ||||
|             assert issubclass(interface, Interface), ( | ||||
|                 'All interfaces of {} must be a subclass of Interface. Received "{}".' | ||||
|             ).format(name, interface) | ||||
|             options.interface_fields.update(interface._meta.fields) | ||||
|         options.fields = OrderedDict(options.interface_fields) | ||||
|         options.fields.update(options.local_fields) | ||||
| 
 | ||||
|         return type.__new__(cls, name, bases, dict(attrs, _meta=options)) | ||||
|         cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) | ||||
|         for interface in options.interfaces: | ||||
|             interface.implements(cls) | ||||
| 
 | ||||
|         return cls | ||||
| 
 | ||||
|     def __str__(cls): | ||||
|         return cls._meta.name | ||||
|  |  | |||
|  | @ -62,6 +62,8 @@ class Schema(GraphQLSchema): | |||
|         return self.get_graphql_type(self._subscription) | ||||
| 
 | ||||
|     def get_graphql_type(self, _type): | ||||
|         if not _type: | ||||
|             return _type | ||||
|         if is_type(_type): | ||||
|             return _type | ||||
|         if is_graphene_type(_type): | ||||
|  |  | |||
|  | @ -71,6 +71,18 @@ def test_generate_interface_inherit_abstracttype(): | |||
|     assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] | ||||
| 
 | ||||
| 
 | ||||
| def test_generate_interface_inherit_interface(): | ||||
|     class MyBaseInterface(Interface): | ||||
|         field1 = MyScalar() | ||||
| 
 | ||||
|     class MyInterface(MyBaseInterface): | ||||
|         field2 = MyScalar() | ||||
| 
 | ||||
|     assert MyInterface._meta.name == 'MyInterface' | ||||
|     assert MyInterface._meta.fields.keys() == ['field1', 'field2'] | ||||
|     assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] | ||||
| 
 | ||||
| 
 | ||||
| def test_generate_interface_inherit_abstracttype_reversed(): | ||||
|     class MyAbstractType(AbstractType): | ||||
|         field1 = MyScalar() | ||||
|  |  | |||
|  | @ -4,9 +4,10 @@ from ..field import Field | |||
| from ..objecttype import ObjectType | ||||
| from ..unmountedtype import UnmountedType | ||||
| from ..abstracttype import AbstractType | ||||
| from ..interface import Interface | ||||
| 
 | ||||
| 
 | ||||
| class MyType(object): | ||||
| class MyType(Interface): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
|  | @ -15,6 +16,17 @@ class Container(ObjectType): | |||
|     field2 = Field(MyType) | ||||
| 
 | ||||
| 
 | ||||
| class MyInterface(Interface): | ||||
|     ifield = Field(MyType) | ||||
| 
 | ||||
| 
 | ||||
| class ContainerWithInterface(ObjectType): | ||||
|     class Meta: | ||||
|         interfaces = (MyInterface, ) | ||||
|     field1 = Field(MyType) | ||||
|     field2 = Field(MyType) | ||||
| 
 | ||||
| 
 | ||||
| class MyScalar(UnmountedType): | ||||
|     def get_type(self): | ||||
|         return MyType | ||||
|  | @ -94,6 +106,10 @@ def test_parent_container_get_fields(): | |||
|     assert list(Container._meta.fields.keys()) == ['field1', 'field2'] | ||||
| 
 | ||||
| 
 | ||||
| def test_parent_container_interface_get_fields(): | ||||
|     assert list(ContainerWithInterface._meta.fields.keys()) == ['ifield', 'field1', 'field2'] | ||||
| 
 | ||||
| 
 | ||||
| def test_objecttype_as_container_only_args(): | ||||
|     container = Container("1", "2") | ||||
|     assert container.field1 == "1" | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| import inspect | ||||
| from functools import partial | ||||
| from collections import OrderedDict | ||||
| 
 | ||||
| from graphql.type.typemap import GraphQLTypeMap | ||||
|  | @ -14,6 +15,8 @@ from .scalars import Scalar, String, Boolean, Int, Float, ID | |||
| from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument | ||||
| from graphql.type import GraphQLEnumValue | ||||
| 
 | ||||
| from ..utils.str_converters import to_camel_case | ||||
| 
 | ||||
| 
 | ||||
| def is_graphene_type(_type): | ||||
|     if isinstance(_type, (List, NonNull)): | ||||
|  | @ -22,13 +25,26 @@ def is_graphene_type(_type): | |||
|         return True | ||||
| 
 | ||||
| 
 | ||||
| def resolve_type(resolve_type_func, map, root, args, info): | ||||
|     _type = resolve_type_func(root, args, info) | ||||
|     # assert inspect.isclass(_type) and issubclass(_type, ObjectType), ( | ||||
|     #     'Received incompatible type "{}".'.format(_type) | ||||
|     # ) | ||||
|     if inspect.isclass(_type) and issubclass(_type, ObjectType): | ||||
|         graphql_type = map.get(_type._meta.name) | ||||
|         assert graphql_type and graphql_type.graphene_type == _type | ||||
|         return graphql_type | ||||
|     return _type | ||||
| 
 | ||||
| 
 | ||||
| class TypeMap(GraphQLTypeMap): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def reducer(cls, map, type): | ||||
|         if not type: | ||||
|             return map | ||||
| 
 | ||||
|         if inspect.isfunction(type): | ||||
|             type = type() | ||||
|         if is_graphene_type(type): | ||||
|             return cls.graphene_reducer(map, type) | ||||
|         return super(TypeMap, cls).reducer(map, type) | ||||
|  | @ -112,10 +128,11 @@ class TypeMap(GraphQLTypeMap): | |||
|         ) | ||||
|         interfaces = [] | ||||
|         for i in type._meta.interfaces: | ||||
|             map = cls.construct_interface(map, i) | ||||
|             map = cls.reducer(map, i) | ||||
|             interfaces.append(map[i._meta.name]) | ||||
|         map[type._meta.name].interfaces = interfaces | ||||
|         map[type._meta.name]._provided_interfaces = interfaces | ||||
|         map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) | ||||
|         # cls.reducer(map, map[type._meta.name]) | ||||
|         return map | ||||
| 
 | ||||
|     @classmethod | ||||
|  | @ -126,9 +143,10 @@ class TypeMap(GraphQLTypeMap): | |||
|             name=type._meta.name, | ||||
|             description=type._meta.description, | ||||
|             fields=None, | ||||
|             resolve_type=type.resolve_type, | ||||
|             resolve_type=partial(resolve_type, type.resolve_type, map), | ||||
|         ) | ||||
|         map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) | ||||
|         # cls.reducer(map, map[type._meta.name]) | ||||
|         return map | ||||
| 
 | ||||
|     @classmethod | ||||
|  | @ -159,6 +177,14 @@ class TypeMap(GraphQLTypeMap): | |||
|         map[type._meta.name].types = types | ||||
|         return map | ||||
| 
 | ||||
|     @classmethod | ||||
|     def process_field_name(cls, name): | ||||
|         return to_camel_case(name) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def default_resolver(cls, attname, root, *_): | ||||
|         return getattr(root, attname, None) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def construct_fields_for_type(cls, map, type, is_input_type=False): | ||||
|         fields = OrderedDict() | ||||
|  | @ -181,25 +207,42 @@ class TypeMap(GraphQLTypeMap): | |||
|                         description=arg.description, | ||||
|                         default_value=arg.default_value | ||||
|                     ) | ||||
|                 resolver = field.resolver | ||||
|                 resolver_type = getattr(type, 'resolve_{}'.format(name), None) | ||||
|                 if resolver_type: | ||||
|                     resolver = resolver_type.__func__ | ||||
| 
 | ||||
|                 _field = GraphQLField( | ||||
|                     field_type, | ||||
|                     args=args, | ||||
|                     resolver=resolver, | ||||
|                     resolver=field.resolver or cls.get_resolver_for_type(type, name), | ||||
|                     deprecation_reason=field.deprecation_reason, | ||||
|                     description=field.description | ||||
|                 ) | ||||
|             fields[name] = _field | ||||
|             processed_name = cls.process_field_name(name) | ||||
|             fields[processed_name] = _field | ||||
|         return fields | ||||
| 
 | ||||
|     @classmethod | ||||
|     def get_resolver_for_type(cls, type, name): | ||||
|         if not issubclass(type, ObjectType): | ||||
|             return | ||||
|         resolver = getattr(type, 'resolve_{}'.format(name), None) | ||||
|         if not resolver: | ||||
|             # If we don't find the resolver in the ObjectType class, then try to | ||||
|             # find it in each of the interfaces | ||||
|             interface_resolver = None | ||||
|             for interface in type._meta.interfaces: | ||||
|                 interface_resolver = getattr(interface, 'resolve_{}'.format(name), None) | ||||
|                 if interface_resolver: | ||||
|                     break | ||||
|             resolver = interface_resolver | ||||
|         # Only if is not decorated with classmethod | ||||
|         if resolver and not getattr(resolver, '__self__', True): | ||||
|             return resolver.__func__ | ||||
|         return partial(cls.default_resolver, name) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def get_field_type(self, map, type): | ||||
|         if isinstance(type, List): | ||||
|             return GraphQLList(self.get_field_type(map, type.of_type)) | ||||
|         if isinstance(type, NonNull): | ||||
|             return GraphQLNonNull(self.get_field_type(map, type.of_type)) | ||||
|         if inspect.isfunction(type): | ||||
|             type = type() | ||||
|         return map.get(type._meta.name) | ||||
|  |  | |||
|  | @ -6,9 +6,10 @@ from .inputfield import InputField | |||
| 
 | ||||
| 
 | ||||
| def merge_fields_in_attrs(bases, attrs): | ||||
|     from ..types.abstracttype import AbstractType | ||||
|     from ..types import AbstractType, Interface | ||||
|     inherited_bases = (AbstractType, Interface) | ||||
|     for base in bases: | ||||
|         if base == AbstractType or not issubclass(base, AbstractType): | ||||
|         if base in inherited_bases or not issubclass(base, inherited_bases): | ||||
|             continue | ||||
|         for name, field in base._meta.fields.items(): | ||||
|             if name in attrs: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user