mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 09:57:41 +03:00 
			
		
		
		
	Improved fields mounting
This commit is contained in:
		
							parent
							
								
									9f655d9416
								
							
						
					
					
						commit
						25e967200b
					
				| 
						 | 
				
			
			@ -1,6 +1,8 @@
 | 
			
		|||
from collections import OrderedDict
 | 
			
		||||
import inspect
 | 
			
		||||
import copy
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
from graphql.utils.assert_valid_name import assert_valid_name
 | 
			
		||||
from graphql.type.definition import GraphQLObjectType
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +60,26 @@ class ClassTypeMeta(type):
 | 
			
		|||
        return cls
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FieldsMeta(type):
 | 
			
		||||
 | 
			
		||||
    def _build_field_map(cls, bases, local_fields):
 | 
			
		||||
        from ..utils.extract_fields import get_base_fields
 | 
			
		||||
        extended_fields = get_base_fields(cls, bases)
 | 
			
		||||
        fields = chain(extended_fields, local_fields)
 | 
			
		||||
        return OrderedDict((f.name, f) for f in fields)
 | 
			
		||||
 | 
			
		||||
    def _fields(cls, bases, attrs):
 | 
			
		||||
        from ..utils.is_graphene_type import is_graphene_type
 | 
			
		||||
        from ..utils.extract_fields import extract_fields
 | 
			
		||||
 | 
			
		||||
        inherited_types = [
 | 
			
		||||
            base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        local_fields = extract_fields(cls, attrs)
 | 
			
		||||
        return partial(cls._build_field_map, inherited_types, local_fields)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GrapheneGraphQLType(object):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        self.graphene_type = kwargs.pop('graphene_type')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,7 +2,7 @@ import six
 | 
			
		|||
 | 
			
		||||
from graphql import GraphQLInputObjectType
 | 
			
		||||
 | 
			
		||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
 | 
			
		||||
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
 | 
			
		||||
from .proxy import TypeProxy
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType):
 | 
			
		|||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InputObjectTypeMeta(ClassTypeMeta):
 | 
			
		||||
class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
 | 
			
		||||
 | 
			
		||||
    def get_options(cls, meta):
 | 
			
		||||
        return cls.options_class(
 | 
			
		||||
| 
						 | 
				
			
			@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    def construct_graphql_type(cls, bases):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def construct(cls, bases, attrs):
 | 
			
		||||
        if not cls._meta.graphql_type and not cls._meta.abstract:
 | 
			
		||||
            from ..utils.get_graphql_type import get_graphql_type
 | 
			
		||||
            from ..utils.is_graphene_type import is_graphene_type
 | 
			
		||||
 | 
			
		||||
            inherited_types = [
 | 
			
		||||
                base._meta.graphql_type for base in bases if is_graphene_type(base)
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
            cls._meta.graphql_type = GrapheneInputObjectType(
 | 
			
		||||
                graphene_type=cls,
 | 
			
		||||
                name=cls._meta.name or cls.__name__,
 | 
			
		||||
                description=cls._meta.description or cls.__doc__,
 | 
			
		||||
                fields=FieldMap(cls, bases=filter(None, inherited_types)),
 | 
			
		||||
                fields=cls._fields(bases, attrs),
 | 
			
		||||
            )
 | 
			
		||||
        return super(InputObjectTypeMeta, cls).construct(bases, attrs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,17 +1,14 @@
 | 
			
		|||
from itertools import chain
 | 
			
		||||
from functools import partial
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from graphql import GraphQLInterfaceType
 | 
			
		||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
 | 
			
		||||
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InterfaceTypeMeta(ClassTypeMeta):
 | 
			
		||||
class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
 | 
			
		||||
 | 
			
		||||
    def get_options(cls, meta):
 | 
			
		||||
        return cls.options_class(
 | 
			
		||||
| 
						 | 
				
			
			@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta):
 | 
			
		|||
    def construct_graphql_type(cls, bases):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def _build_field_map(cls, local_fields, bases):
 | 
			
		||||
        from ..utils.extract_fields import get_base_fields
 | 
			
		||||
        extended_fields = get_base_fields(bases)
 | 
			
		||||
        fields = chain(extended_fields, local_fields)
 | 
			
		||||
        return OrderedDict((f.name, f) for f in fields)
 | 
			
		||||
 | 
			
		||||
    def construct(cls, bases, attrs):
 | 
			
		||||
        if not cls._meta.graphql_type and not cls._meta.abstract:
 | 
			
		||||
            from ..utils.is_graphene_type import is_graphene_type
 | 
			
		||||
            from ..utils.extract_fields import extract_fields
 | 
			
		||||
 | 
			
		||||
            inherited_types = [
 | 
			
		||||
                base._meta.graphql_type for base in bases if is_graphene_type(base)
 | 
			
		||||
            ]
 | 
			
		||||
            inherited_types = filter(None, inherited_types)
 | 
			
		||||
 | 
			
		||||
            local_fields = list(extract_fields(attrs))
 | 
			
		||||
 | 
			
		||||
            cls._meta.graphql_type = GrapheneInterfaceType(
 | 
			
		||||
                graphene_type=cls,
 | 
			
		||||
                name=cls._meta.name or cls.__name__,
 | 
			
		||||
                description=cls._meta.description or cls.__doc__,
 | 
			
		||||
                fields=partial(cls._build_field_map, local_fields, inherited_types),
 | 
			
		||||
                fields=cls._fields(bases, attrs),
 | 
			
		||||
            )
 | 
			
		||||
        return super(InterfaceTypeMeta, cls).construct(bases, attrs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -64,3 +64,17 @@ class TypeProxy(OrderedType):
 | 
			
		|||
            raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
 | 
			
		||||
 | 
			
		||||
        return inner.contribute_to_class(cls, attname)
 | 
			
		||||
 | 
			
		||||
    def as_mounted(self, cls):
 | 
			
		||||
        from .inputobjecttype import InputObjectType
 | 
			
		||||
        from .objecttype import ObjectType
 | 
			
		||||
        from .interface import Interface
 | 
			
		||||
 | 
			
		||||
        if issubclass(cls, (ObjectType, Interface)):
 | 
			
		||||
            inner = self.as_field()
 | 
			
		||||
        elif issubclass(cls, (InputObjectType)):
 | 
			
		||||
            inner = self.as_inputfield()
 | 
			
		||||
        else:
 | 
			
		||||
            raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
 | 
			
		||||
 | 
			
		||||
        return inner
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,20 +1,19 @@
 | 
			
		|||
import copy
 | 
			
		||||
from .get_graphql_type import get_graphql_type
 | 
			
		||||
 | 
			
		||||
from ..types.field import Field
 | 
			
		||||
from ..types.field import Field, InputField
 | 
			
		||||
from ..types.proxy import TypeProxy
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_fields(attrs):
 | 
			
		||||
def extract_fields(cls, attrs):
 | 
			
		||||
    fields = set()
 | 
			
		||||
    _fields = list()
 | 
			
		||||
    for attname, value in list(attrs.items()):
 | 
			
		||||
        is_field = isinstance(value, Field)
 | 
			
		||||
        is_field = isinstance(value, (Field, InputField))
 | 
			
		||||
        is_field_proxy = isinstance(value, TypeProxy)
 | 
			
		||||
        if not (is_field or is_field_proxy):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        field = value.as_field() if is_field_proxy else copy.copy(value)
 | 
			
		||||
        field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
 | 
			
		||||
        field.attname = attname
 | 
			
		||||
        fields.add(attname)
 | 
			
		||||
        del attrs[attname]
 | 
			
		||||
| 
						 | 
				
			
			@ -23,13 +22,12 @@ def extract_fields(attrs):
 | 
			
		|||
    return sorted(_fields)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_base_fields(bases):
 | 
			
		||||
def get_base_fields(cls, bases):
 | 
			
		||||
    fields = set()
 | 
			
		||||
    for _class in bases:
 | 
			
		||||
        for attname, field in get_graphql_type(_class).get_fields().items():
 | 
			
		||||
            if attname in fields:
 | 
			
		||||
                continue
 | 
			
		||||
            field = copy.copy(field)
 | 
			
		||||
            field.name = attname
 | 
			
		||||
            fields.add(attname)
 | 
			
		||||
            yield field
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,7 +2,7 @@ from collections import OrderedDict
 | 
			
		|||
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
 | 
			
		||||
from ..extract_fields import extract_fields, get_base_fields
 | 
			
		||||
 | 
			
		||||
from ...types import Field, String, Argument
 | 
			
		||||
from ...types import Field, String, Argument, ObjectType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_extract_fields_attrs():
 | 
			
		||||
| 
						 | 
				
			
			@ -13,7 +13,7 @@ def test_extract_fields_attrs():
 | 
			
		|||
        'argument': Argument(String),
 | 
			
		||||
        'graphql_field': GraphQLField(GraphQLString)
 | 
			
		||||
    }
 | 
			
		||||
    extracted_fields = list(extract_fields(attrs))
 | 
			
		||||
    extracted_fields = list(extract_fields(ObjectType, attrs))
 | 
			
		||||
    assert [f.name for f in extracted_fields] == ['fieldString', 'string']
 | 
			
		||||
    assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -31,7 +31,7 @@ def test_extract_fields():
 | 
			
		|||
    ]))
 | 
			
		||||
 | 
			
		||||
    bases = (int_base, float_base)
 | 
			
		||||
    base_fields = list(get_base_fields(bases))
 | 
			
		||||
    base_fields = list(get_base_fields(ObjectType, bases))
 | 
			
		||||
    assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
 | 
			
		||||
    assert [f.type for f in base_fields] == [
 | 
			
		||||
        GraphQLInt,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user