diff --git a/examples/starwars/tests/test_schema.py b/examples/starwars/tests/test_schema.py index b7ae49e4..e69de29b 100644 --- a/examples/starwars/tests/test_schema.py +++ b/examples/starwars/tests/test_schema.py @@ -1,9 +0,0 @@ - -from ..schema import Droid - - -def test_query_types(): - graphql_type = Droid._meta.graphql_type - fields = graphql_type.get_fields() - assert fields['friends'].parent == Droid - assert fields diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 02b3dbe5..e88d1b82 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -54,6 +54,10 @@ class ConnectionMeta(ObjectTypeMeta): ) cls.Edge = type(edge.name, (ObjectType, ), {'Meta': type('Meta', (object,), {'graphql_type': edge})}) cls._meta.graphql_type = connection + fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls) + + cls._meta.get_fields = lambda: fields + return cls diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py index f5fbec83..24a828a0 100644 --- a/graphene/relay/mutation.py +++ b/graphene/relay/mutation.py @@ -52,6 +52,8 @@ class ClientIDMutationMeta(MutationMeta): mutate_and_get_payload=cls.mutate_and_get_payload, ) options.graphql_type = field.type + options.get_fields = lambda: output_fields + cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None) return cls diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 4467746d..e0d6f3d1 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -10,6 +10,8 @@ from ..types.objecttype import ObjectType, ObjectTypeMeta, is_objecttype from ..types.options import Options from .connection import Connection +from ..utils.copy_fields import copy_fields + # We inherit from ObjectTypeMeta as we want to allow # inheriting from Node, and also ObjectType. @@ -23,16 +25,17 @@ class NodeMeta(ObjectTypeMeta): meta, ) - def __new__(cls, name, bases, attrs): - - if is_objecttype(bases): - cls = super(NodeMeta, cls).__new__(cls, name, bases, attrs) - # The interface provided by node_definitions is not an instance - # of GrapheneInterfaceType, so it will have no graphql_type, - # so will not trigger Node.implements - cls.implements(cls) - return cls + @staticmethod + def _create_objecttype(cls, name, bases, attrs): + # The interface provided by node_definitions is not an instance + # of GrapheneInterfaceType, so it will have no graphql_type, + # so will not trigger Node.implements + cls = super(NodeMeta, cls)._create_objecttype(cls, name, bases, attrs) + cls.implements(cls) + return cls + @staticmethod + def _create_interface(cls, name, bases, attrs): options = cls._get_interface_options(attrs.pop('Meta', None)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) @@ -45,6 +48,10 @@ class NodeMeta(ObjectTypeMeta): type_resolver=cls.resolve_type, ) options.graphql_type = node_interface + + fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls) + options.get_fields = lambda: fields + cls.Field = partial( Field.copy_and_extend, node_field, diff --git a/graphene/types/argument.py b/graphene/types/argument.py index a3febbb2..072b83a9 100644 --- a/graphene/types/argument.py +++ b/graphene/types/argument.py @@ -2,7 +2,7 @@ import inspect from collections import OrderedDict from itertools import chain -from graphql import GraphQLArgument +from graphql.type.definition import GraphQLArgument, GraphQLArgumentDefinition from graphql.utils.assert_valid_name import assert_valid_name from ..utils.orderedtype import OrderedType @@ -40,11 +40,15 @@ class Argument(GraphQLArgument, OrderedType): @classmethod def copy_from(cls, argument): + if isinstance (argument, (GraphQLArgumentDefinition, Argument)): + name = argument.name + else: + name = None return cls( type=argument.type, default_value=argument.default_value, description=argument.description, - name=argument.name, + name=name, _creation_counter=argument.creation_counter if isinstance(argument, Argument) else None, ) diff --git a/graphene/types/enum.py b/graphene/types/enum.py index 2fda8412..fe733577 100644 --- a/graphene/types/enum.py +++ b/graphene/types/enum.py @@ -22,7 +22,12 @@ class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType): def values_from_enum(enum): _values = OrderedDict() for name, value in enum.__members__.items(): - _values[name] = GraphQLEnumValue(name=name, value=value.value) + _values[name] = GraphQLEnumValue( + name=name, + value=value.value, + description=getattr(value, 'description', None), + deprecation_reason=getattr(value, 'deprecation_reason', None) + ) return _values diff --git a/graphene/types/field.py b/graphene/types/field.py index b9aa285a..f5316896 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -1,6 +1,7 @@ +from collections import OrderedDict import inspect -from graphql.type import GraphQLField, GraphQLInputObjectField +from graphql.type import GraphQLField, GraphQLInputObjectField, GraphQLFieldDefinition from graphql.utils.assert_valid_name import assert_valid_name from ..utils.orderedtype import OrderedType @@ -126,18 +127,23 @@ class Field(AbstractField, GraphQLField, OrderedType): _creation_counter = field.creation_counter if _creation_counter is False else None attname = attname or field.attname parent = parent or field.parent + args = to_arguments(args, field.args) else: # If is a GraphQLField type = type or field.type resolver = resolver or field.resolver - name = field.name + field_args = field.args + if isinstance(field, GraphQLFieldDefinition): + name = name or field.name + field_args = OrderedDict((a.name, a) for a in field_args) + args = to_arguments(args, field_args) _creation_counter = None attname = attname or name parent = parent new_field = cls( type=type, - args=to_arguments(args, field.args), + args=args, resolver=resolver, source=source, deprecation_reason=field.deprecation_reason, diff --git a/graphene/types/mutation.py b/graphene/types/mutation.py index ee015d76..6874ae3d 100644 --- a/graphene/types/mutation.py +++ b/graphene/types/mutation.py @@ -11,16 +11,14 @@ from .objecttype import ObjectType, ObjectTypeMeta class MutationMeta(ObjectTypeMeta): def __new__(cls, name, bases, attrs): - super_new = super(MutationMeta, cls).__new__ - - # Also ensure initialization is only performed for subclasses of Model - # (excluding Model class itself). + # Also ensure initialization is only performed for subclasses of + # Mutation if not is_base_type(bases, MutationMeta): return type.__new__(cls, name, bases, attrs) Input = attrs.pop('Input', None) - cls = super_new(cls, name, bases, attrs) + cls = cls._create_objecttype(cls, name, bases, attrs) field_args = props(Input) if Input else {} resolver = getattr(cls, 'mutate', None) assert resolver, 'All mutations must define a mutate method in it' diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index a5fdf169..19f5e2c8 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -55,48 +55,15 @@ class ObjectTypeMeta(type): def __new__(cls, name, bases, attrs): super_new = type.__new__ - # Also ensure initialization is only performed for subclasses of Model - # (excluding Model class itself). - + # Also ensure initialization is only performed for subclasses of + # ObjectType,or Interfaces if not is_base_type(bases, ObjectTypeMeta): - return super_new(cls, name, bases, attrs) + return type.__new__(cls, name, bases, attrs) if not is_objecttype(bases): return cls._create_interface(cls, name, bases, attrs) - options = Options( - attrs.pop('Meta', None), - name=None, - description=None, - graphql_type=None, - interfaces=(), - abstract=False - ) - - interfaces = tuple(options.interfaces) - fields = get_fields(ObjectType, attrs, bases, interfaces) - attrs = attrs_without_fields(attrs, fields) - cls = super_new(cls, name, bases, dict(attrs, _meta=options)) - - if not options.graphql_type: - fields = copy_fields(Field, fields, parent=cls) - base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) - options.graphql_type = GrapheneObjectType( - graphene_type=cls, - name=options.name or cls.__name__, - description=options.description or cls.__doc__, - fields=fields, - is_type_of=cls.is_type_of, - interfaces=tuple(get_interfaces(interfaces + base_interfaces)) - ) - else: - assert not fields, "Can't mount Fields in an ObjectType with a defined graphql_type" - fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls) - - for name, field in fields.items(): - setattr(cls, field.attname or name, field) - - return cls + return cls._create_objecttype(cls, name, bases, attrs) def get_interfaces(cls, bases): return (b for b in bases if issubclass(b, Interface)) @@ -133,7 +100,47 @@ class ObjectTypeMeta(type): ) else: assert not fields, "Can't mount Fields in an Interface with a defined graphql_type" - fields = copy_fields(options.graphql_type.get_fields(), parent=cls) + fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls) + + options.get_fields = lambda: fields + + for name, field in fields.items(): + setattr(cls, field.attname or name, field) + + return cls + + @staticmethod + def _create_objecttype(cls, name, bases, attrs): + options = Options( + attrs.pop('Meta', None), + name=None, + description=None, + graphql_type=None, + interfaces=(), + abstract=False + ) + + interfaces = tuple(options.interfaces) + fields = get_fields(ObjectType, attrs, bases, interfaces) + attrs = attrs_without_fields(attrs, fields) + cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) + + if not options.graphql_type: + fields = copy_fields(Field, fields, parent=cls) + base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) + options.graphql_type = GrapheneObjectType( + graphene_type=cls, + name=options.name or cls.__name__, + description=options.description or cls.__doc__, + fields=fields, + is_type_of=cls.is_type_of, + interfaces=tuple(get_interfaces(interfaces + base_interfaces)) + ) + else: + assert not fields, "Can't mount Fields in an ObjectType with a defined graphql_type" + fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls) + + options.get_fields = lambda: fields for name, field in fields.items(): setattr(cls, field.attname or name, field) @@ -146,9 +153,9 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)): def __init__(self, *args, **kwargs): # GraphQL ObjectType acting as container args_len = len(args) - fields = self._meta.graphql_type.get_fields().values() - for f in fields: - setattr(self, getattr(f, 'attname', f.name), None) + fields = self._meta.get_fields().items() + for name, f in fields: + setattr(self, getattr(f, 'attname', name), None) if args_len > len(fields): # Daft, but matches old exception sans the err msg. @@ -156,18 +163,18 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)): fields_iter = iter(fields) if not kwargs: - for val, field in zip(args, fields_iter): - attname = getattr(field, 'attname', field.name) + for val, (name, field) in zip(args, fields_iter): + attname = getattr(field, 'attname', name) setattr(self, attname, val) else: - for val, field in zip(args, fields_iter): - attname = getattr(field, 'attname', field.name) + for val, (name, field) in zip(args, fields_iter): + attname = getattr(field, 'attname', name) setattr(self, attname, val) kwargs.pop(attname, None) - for field in fields_iter: + for name, field in fields_iter: try: - attname = getattr(field, 'attname', field.name) + attname = getattr(field, 'attname', name) val = kwargs.pop(attname) setattr(self, attname, val) except KeyError: diff --git a/graphene/types/tests/test_enum.py b/graphene/types/tests/test_enum.py index 4813d902..4abd185e 100644 --- a/graphene/types/tests/test_enum.py +++ b/graphene/types/tests/test_enum.py @@ -11,6 +11,10 @@ def test_enum_construction(): GREEN = 2 BLUE = 3 + @property + def description(self): + return "Description {}".format(self.name) + assert isinstance(RGB._meta.graphql_type, GraphQLEnumType) values = RGB._meta.graphql_type.get_values() assert sorted([v.name for v in values]) == [ @@ -18,6 +22,11 @@ def test_enum_construction(): 'GREEN', 'RED' ] + assert sorted([v.description for v in values]) == [ + 'Description BLUE', + 'Description GREEN', + 'Description RED' + ] assert isinstance(RGB(name='field_name').as_field(), Field) assert isinstance(RGB(name='field_name').as_argument(), Argument) diff --git a/graphene/types/tests/test_inputobjecttype.py b/graphene/types/tests/test_inputobjecttype.py index 6455b1c0..8036b96a 100644 --- a/graphene/types/tests/test_inputobjecttype.py +++ b/graphene/types/tests/test_inputobjecttype.py @@ -1,5 +1,6 @@ from graphql import GraphQLInputObjectType, GraphQLString +from graphql.type.definition import GraphQLInputFieldDefinition from ..field import InputField from ..inputobjecttype import InputObjectType @@ -43,7 +44,7 @@ def test_generate_objecttype_with_fields(): graphql_type = MyObjectType._meta.graphql_type fields = graphql_type.get_fields() assert 'field' in fields - assert isinstance(fields['field'], InputField) + assert isinstance(fields['field'], GraphQLInputFieldDefinition) def test_generate_objecttype_with_graphene_fields(): @@ -53,4 +54,4 @@ def test_generate_objecttype_with_graphene_fields(): graphql_type = MyObjectType._meta.graphql_type fields = graphql_type.get_fields() assert 'field' in fields - assert isinstance(fields['field'], InputField) + assert isinstance(fields['field'], GraphQLInputFieldDefinition) diff --git a/graphene/types/tests/test_interface.py b/graphene/types/tests/test_interface.py index 9370de98..09136e04 100644 --- a/graphene/types/tests/test_interface.py +++ b/graphene/types/tests/test_interface.py @@ -56,7 +56,7 @@ def test_interface_inheritance(): fields = graphql_type.get_fields() assert 'field' in fields assert 'inherited' in fields - assert fields['field'] > fields['inherited'] + assert MyInterface.field > MyInheritedInterface.inherited def test_interface_instance(): diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index c3f59470..12901456 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -74,11 +74,7 @@ def test_objecttype_inheritance(): graphql_type = MyObjectType._meta.graphql_type fields = graphql_type.get_fields() - assert 'field1' in fields - assert 'field2' in fields - assert 'inherited' in fields - assert fields['field1'] > fields['inherited'] - assert fields['field2'] > fields['field1'] + assert fields.keys() == ['inherited', 'field1', 'field2'] def test_objecttype_as_container_get_fields(): @@ -195,11 +191,7 @@ def test_objecttype_graphene_interface(): graphql_type = GrapheneObjectType._meta.graphql_type assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, ) assert graphql_type.is_type_of(GrapheneObjectType(), None, None) - fields = graphql_type.get_fields() - assert 'field' in fields - assert 'extended' in fields - assert 'name' in fields - assert fields['field'] > fields['extended'] > fields['name'] + fields = graphql_type.get_fields().keys() == ['name', 'extended', 'field'] def test_objecttype_graphene_inherit_interface(): @@ -214,11 +206,8 @@ def test_objecttype_graphene_inherit_interface(): assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, ) assert graphql_type.is_type_of(GrapheneObjectType(), None, None) fields = graphql_type.get_fields() - assert 'field' in fields - assert 'extended' in fields - assert 'name' in fields + fields = graphql_type.get_fields().keys() == ['name', 'extended', 'field'] assert issubclass(GrapheneObjectType, GrapheneInterface) - assert fields['field'] > fields['extended'] > fields['name'] # def test_objecttype_graphene_interface_extended(): diff --git a/graphene/types/tests/test_scalars.py b/graphene/types/tests/test_scalars.py index 7ff232a2..c79b6910 100644 --- a/graphene/types/tests/test_scalars.py +++ b/graphene/types/tests/test_scalars.py @@ -3,8 +3,9 @@ import datetime import pytest from graphene.utils.get_graphql_type import get_graphql_type -from graphql import (GraphQLBoolean, GraphQLFloat, GraphQLInt, - GraphQLScalarType, GraphQLString, graphql) +from graphql import graphql +from graphql.type import (GraphQLBoolean, GraphQLFloat, GraphQLInt, + GraphQLScalarType, GraphQLString, GraphQLFieldDefinition) from graphql.language import ast from ..field import Field @@ -94,7 +95,7 @@ def test_scalar_in_objecttype(scalar_class, graphql_type): graphql_type = get_graphql_type(MyObjectType) fields = graphql_type.get_fields() assert list(fields.keys()) == ['before', 'field', 'after'] - assert isinstance(fields['field'], Field) + assert isinstance(fields['field'], GraphQLFieldDefinition) def test_custom_scalar_empty(): diff --git a/graphene/types/tests/test_schema.py b/graphene/types/tests/test_schema.py index cb5ed568..5f3a7edc 100644 --- a/graphene/types/tests/test_schema.py +++ b/graphene/types/tests/test_schema.py @@ -11,6 +11,12 @@ class Character(Interface): friends = List(lambda: Character) best_friend = Field(lambda: Character) + def resolve_friends(self, *args): + return [Human(name='Peter')] + + def resolve_best_friend(self, *args): + return Human(name='Best') + class Pet(ObjectType): type = String() @@ -26,12 +32,6 @@ class Human(ObjectType): def resolve_pet(self, *args): return Pet(type='Dog') - def resolve_friends(self, *args): - return [Human(name='Peter')] - - def resolve_best_friend(self, *args): - return Human(name='Best') - class RootQuery(ObjectType): character = Field(Character) diff --git a/graphene/utils/get_fields.py b/graphene/utils/get_fields.py index 29eb0397..4e18de67 100644 --- a/graphene/utils/get_fields.py +++ b/graphene/utils/get_fields.py @@ -16,10 +16,24 @@ def get_fields_from_attrs(in_type, attrs): yield attname, field -def get_fields_from_types(bases): +def get_fields_from_bases_and_types(bases, types): fields = set() for _class in bases: - for attname, field in get_graphql_type(_class).get_fields().items(): + if not is_graphene_type(_class): + continue + _fields = _class._meta.get_fields() + if callable(_fields): + _fields = _fields() + + for default_attname, field in _fields.items(): + attname = getattr(field, 'attname', default_attname) + if attname in fields: + continue + fields.add(attname) + yield attname, field + + for grapqhl_type in types: + for attname, field in get_graphql_type(grapqhl_type).get_fields().items(): if attname in fields: continue fields.add(attname) @@ -29,11 +43,7 @@ def get_fields_from_types(bases): def get_fields(in_type, attrs, bases, graphql_types=()): fields = [] - graphene_bases = tuple( - base._meta.graphql_type for base in bases if is_graphene_type(base) - ) + graphql_types - - extended_fields = list(get_fields_from_types(graphene_bases)) + extended_fields = list(get_fields_from_bases_and_types(bases, graphql_types)) local_fields = list(get_fields_from_attrs(in_type, attrs)) # We asume the extended fields are already sorted, so we only # have to sort the local fields, that are get from attrs diff --git a/graphene/utils/tests/test_get_fields.py b/graphene/utils/tests/test_get_fields.py index 0c188e75..65aae160 100644 --- a/graphene/utils/tests/test_get_fields.py +++ b/graphene/utils/tests/test_get_fields.py @@ -4,7 +4,7 @@ from graphql import (GraphQLField, GraphQLFloat, GraphQLInt, GraphQLInterfaceType, GraphQLString) from ...types import Argument, Field, ObjectType, String -from ..get_fields import get_fields_from_attrs, get_fields_from_types +from ..get_fields import get_fields_from_attrs, get_fields_from_bases_and_types def test_get_fields_from_attrs(): @@ -31,8 +31,8 @@ def test_get_fields_from_types(): ('extra', GraphQLField(GraphQLFloat)) ])) - bases = (int_base, float_base) - base_fields = OrderedDict(get_fields_from_types(bases)) + _types = (int_base, float_base) + base_fields = OrderedDict(get_fields_from_bases_and_types((), _types)) assert [f for f in base_fields.keys()] == ['int', 'num', 'extra', 'float'] assert [f.type for f in base_fields.values()] == [ GraphQLInt,