mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 01:47:45 +03:00 
			
		
		
		
	Improved fields mounting
This commit is contained in:
		
							parent
							
								
									9f655d9416
								
							
						
					
					
						commit
						25e967200b
					
				| 
						 | 
					@ -1,6 +1,8 @@
 | 
				
			||||||
from collections import OrderedDict
 | 
					from collections import OrderedDict
 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
 | 
					from itertools import chain
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphql.utils.assert_valid_name import assert_valid_name
 | 
					from graphql.utils.assert_valid_name import assert_valid_name
 | 
				
			||||||
from graphql.type.definition import GraphQLObjectType
 | 
					from graphql.type.definition import GraphQLObjectType
 | 
				
			||||||
| 
						 | 
					@ -58,6 +60,26 @@ class ClassTypeMeta(type):
 | 
				
			||||||
        return cls
 | 
					        return cls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FieldsMeta(type):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _build_field_map(cls, bases, local_fields):
 | 
				
			||||||
 | 
					        from ..utils.extract_fields import get_base_fields
 | 
				
			||||||
 | 
					        extended_fields = get_base_fields(cls, bases)
 | 
				
			||||||
 | 
					        fields = chain(extended_fields, local_fields)
 | 
				
			||||||
 | 
					        return OrderedDict((f.name, f) for f in fields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _fields(cls, bases, attrs):
 | 
				
			||||||
 | 
					        from ..utils.is_graphene_type import is_graphene_type
 | 
				
			||||||
 | 
					        from ..utils.extract_fields import extract_fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        inherited_types = [
 | 
				
			||||||
 | 
					            base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        local_fields = extract_fields(cls, attrs)
 | 
				
			||||||
 | 
					        return partial(cls._build_field_map, inherited_types, local_fields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GrapheneGraphQLType(object):
 | 
					class GrapheneGraphQLType(object):
 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        self.graphene_type = kwargs.pop('graphene_type')
 | 
					        self.graphene_type = kwargs.pop('graphene_type')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@ import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphql import GraphQLInputObjectType
 | 
					from graphql import GraphQLInputObjectType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
 | 
					from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
 | 
				
			||||||
from .proxy import TypeProxy
 | 
					from .proxy import TypeProxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType):
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InputObjectTypeMeta(ClassTypeMeta):
 | 
					class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_options(cls, meta):
 | 
					    def get_options(cls, meta):
 | 
				
			||||||
        return cls.options_class(
 | 
					        return cls.options_class(
 | 
				
			||||||
| 
						 | 
					@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def construct_graphql_type(cls, bases):
 | 
					    def construct_graphql_type(cls, bases):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def construct(cls, bases, attrs):
 | 
				
			||||||
        if not cls._meta.graphql_type and not cls._meta.abstract:
 | 
					        if not cls._meta.graphql_type and not cls._meta.abstract:
 | 
				
			||||||
            from ..utils.get_graphql_type import get_graphql_type
 | 
					 | 
				
			||||||
            from ..utils.is_graphene_type import is_graphene_type
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            inherited_types = [
 | 
					 | 
				
			||||||
                base._meta.graphql_type for base in bases if is_graphene_type(base)
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            cls._meta.graphql_type = GrapheneInputObjectType(
 | 
					            cls._meta.graphql_type = GrapheneInputObjectType(
 | 
				
			||||||
                graphene_type=cls,
 | 
					                graphene_type=cls,
 | 
				
			||||||
                name=cls._meta.name or cls.__name__,
 | 
					                name=cls._meta.name or cls.__name__,
 | 
				
			||||||
                description=cls._meta.description or cls.__doc__,
 | 
					                description=cls._meta.description or cls.__doc__,
 | 
				
			||||||
                fields=FieldMap(cls, bases=filter(None, inherited_types)),
 | 
					                fields=cls._fields(bases, attrs),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        return super(InputObjectTypeMeta, cls).construct(bases, attrs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):
 | 
					class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,17 +1,14 @@
 | 
				
			||||||
from itertools import chain
 | 
					 | 
				
			||||||
from functools import partial
 | 
					 | 
				
			||||||
from collections import OrderedDict
 | 
					 | 
				
			||||||
import six
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphql import GraphQLInterfaceType
 | 
					from graphql import GraphQLInterfaceType
 | 
				
			||||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
 | 
					from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
 | 
					class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InterfaceTypeMeta(ClassTypeMeta):
 | 
					class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_options(cls, meta):
 | 
					    def get_options(cls, meta):
 | 
				
			||||||
        return cls.options_class(
 | 
					        return cls.options_class(
 | 
				
			||||||
| 
						 | 
					@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta):
 | 
				
			||||||
    def construct_graphql_type(cls, bases):
 | 
					    def construct_graphql_type(cls, bases):
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _build_field_map(cls, local_fields, bases):
 | 
					 | 
				
			||||||
        from ..utils.extract_fields import get_base_fields
 | 
					 | 
				
			||||||
        extended_fields = get_base_fields(bases)
 | 
					 | 
				
			||||||
        fields = chain(extended_fields, local_fields)
 | 
					 | 
				
			||||||
        return OrderedDict((f.name, f) for f in fields)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def construct(cls, bases, attrs):
 | 
					    def construct(cls, bases, attrs):
 | 
				
			||||||
        if not cls._meta.graphql_type and not cls._meta.abstract:
 | 
					        if not cls._meta.graphql_type and not cls._meta.abstract:
 | 
				
			||||||
            from ..utils.is_graphene_type import is_graphene_type
 | 
					 | 
				
			||||||
            from ..utils.extract_fields import extract_fields
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            inherited_types = [
 | 
					 | 
				
			||||||
                base._meta.graphql_type for base in bases if is_graphene_type(base)
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            inherited_types = filter(None, inherited_types)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            local_fields = list(extract_fields(attrs))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cls._meta.graphql_type = GrapheneInterfaceType(
 | 
					            cls._meta.graphql_type = GrapheneInterfaceType(
 | 
				
			||||||
                graphene_type=cls,
 | 
					                graphene_type=cls,
 | 
				
			||||||
                name=cls._meta.name or cls.__name__,
 | 
					                name=cls._meta.name or cls.__name__,
 | 
				
			||||||
                description=cls._meta.description or cls.__doc__,
 | 
					                description=cls._meta.description or cls.__doc__,
 | 
				
			||||||
                fields=partial(cls._build_field_map, local_fields, inherited_types),
 | 
					                fields=cls._fields(bases, attrs),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        return super(InterfaceTypeMeta, cls).construct(bases, attrs)
 | 
					        return super(InterfaceTypeMeta, cls).construct(bases, attrs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -64,3 +64,17 @@ class TypeProxy(OrderedType):
 | 
				
			||||||
            raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
 | 
					            raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return inner.contribute_to_class(cls, attname)
 | 
					        return inner.contribute_to_class(cls, attname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def as_mounted(self, cls):
 | 
				
			||||||
 | 
					        from .inputobjecttype import InputObjectType
 | 
				
			||||||
 | 
					        from .objecttype import ObjectType
 | 
				
			||||||
 | 
					        from .interface import Interface
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if issubclass(cls, (ObjectType, Interface)):
 | 
				
			||||||
 | 
					            inner = self.as_field()
 | 
				
			||||||
 | 
					        elif issubclass(cls, (InputObjectType)):
 | 
				
			||||||
 | 
					            inner = self.as_inputfield()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return inner
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,20 +1,19 @@
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
from .get_graphql_type import get_graphql_type
 | 
					from .get_graphql_type import get_graphql_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..types.field import Field
 | 
					from ..types.field import Field, InputField
 | 
				
			||||||
from ..types.proxy import TypeProxy
 | 
					from ..types.proxy import TypeProxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def extract_fields(attrs):
 | 
					def extract_fields(cls, attrs):
 | 
				
			||||||
    fields = set()
 | 
					    fields = set()
 | 
				
			||||||
    _fields = list()
 | 
					    _fields = list()
 | 
				
			||||||
    for attname, value in list(attrs.items()):
 | 
					    for attname, value in list(attrs.items()):
 | 
				
			||||||
        is_field = isinstance(value, Field)
 | 
					        is_field = isinstance(value, (Field, InputField))
 | 
				
			||||||
        is_field_proxy = isinstance(value, TypeProxy)
 | 
					        is_field_proxy = isinstance(value, TypeProxy)
 | 
				
			||||||
        if not (is_field or is_field_proxy):
 | 
					        if not (is_field or is_field_proxy):
 | 
				
			||||||
            continue
 | 
					            continue
 | 
				
			||||||
 | 
					        field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
 | 
				
			||||||
        field = value.as_field() if is_field_proxy else copy.copy(value)
 | 
					 | 
				
			||||||
        field.attname = attname
 | 
					        field.attname = attname
 | 
				
			||||||
        fields.add(attname)
 | 
					        fields.add(attname)
 | 
				
			||||||
        del attrs[attname]
 | 
					        del attrs[attname]
 | 
				
			||||||
| 
						 | 
					@ -23,13 +22,12 @@ def extract_fields(attrs):
 | 
				
			||||||
    return sorted(_fields)
 | 
					    return sorted(_fields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_base_fields(bases):
 | 
					def get_base_fields(cls, bases):
 | 
				
			||||||
    fields = set()
 | 
					    fields = set()
 | 
				
			||||||
    for _class in bases:
 | 
					    for _class in bases:
 | 
				
			||||||
        for attname, field in get_graphql_type(_class).get_fields().items():
 | 
					        for attname, field in get_graphql_type(_class).get_fields().items():
 | 
				
			||||||
            if attname in fields:
 | 
					            if attname in fields:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            field = copy.copy(field)
 | 
					            field = copy.copy(field)
 | 
				
			||||||
            field.name = attname
 | 
					 | 
				
			||||||
            fields.add(attname)
 | 
					            fields.add(attname)
 | 
				
			||||||
            yield field
 | 
					            yield field
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@ from collections import OrderedDict
 | 
				
			||||||
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
 | 
					from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
 | 
				
			||||||
from ..extract_fields import extract_fields, get_base_fields
 | 
					from ..extract_fields import extract_fields, get_base_fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ...types import Field, String, Argument
 | 
					from ...types import Field, String, Argument, ObjectType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_extract_fields_attrs():
 | 
					def test_extract_fields_attrs():
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@ def test_extract_fields_attrs():
 | 
				
			||||||
        'argument': Argument(String),
 | 
					        'argument': Argument(String),
 | 
				
			||||||
        'graphql_field': GraphQLField(GraphQLString)
 | 
					        'graphql_field': GraphQLField(GraphQLString)
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    extracted_fields = list(extract_fields(attrs))
 | 
					    extracted_fields = list(extract_fields(ObjectType, attrs))
 | 
				
			||||||
    assert [f.name for f in extracted_fields] == ['fieldString', 'string']
 | 
					    assert [f.name for f in extracted_fields] == ['fieldString', 'string']
 | 
				
			||||||
    assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
 | 
					    assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,7 +31,7 @@ def test_extract_fields():
 | 
				
			||||||
    ]))
 | 
					    ]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bases = (int_base, float_base)
 | 
					    bases = (int_base, float_base)
 | 
				
			||||||
    base_fields = list(get_base_fields(bases))
 | 
					    base_fields = list(get_base_fields(ObjectType, bases))
 | 
				
			||||||
    assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
 | 
					    assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
 | 
				
			||||||
    assert [f.type for f in base_fields] == [
 | 
					    assert [f.type for f in base_fields] == [
 | 
				
			||||||
        GraphQLInt,
 | 
					        GraphQLInt,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user