From fae376cbb08978a80fb231998bb744b687bad26b Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sun, 6 Dec 2015 01:01:23 -0800 Subject: [PATCH] Moved arguments to a named group --- .../core/classtypes/tests/test_mutation.py | 2 +- graphene/core/schema.py | 5 +++- graphene/core/types/argument.py | 24 +++------------ graphene/core/types/base.py | 30 ++++++++++++++++++- graphene/core/types/field.py | 6 ++-- graphene/core/types/tests/test_field.py | 3 +- graphene/relay/tests/test_mutations.py | 3 +- 7 files changed, 44 insertions(+), 29 deletions(-) diff --git a/graphene/core/classtypes/tests/test_mutation.py b/graphene/core/classtypes/tests/test_mutation.py index ac32585e..85dd2368 100644 --- a/graphene/core/classtypes/tests/test_mutation.py +++ b/graphene/core/classtypes/tests/test_mutation.py @@ -24,4 +24,4 @@ def test_mutation(): assert list(object_type.get_fields().keys()) == ['name'] assert MyMutation._meta.fields_map['name'].object_type == MyMutation assert isinstance(MyMutation.arguments, ArgumentsGroup) - assert 'argName' in MyMutation.arguments + assert 'argName' in schema.T(MyMutation.arguments) diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 1b0ce8f9..2eea83ea 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -38,6 +38,9 @@ class Schema(object): def __repr__(self): return '' % (str(self.name), hash(self)) + def get_internal_type(self, objecttype): + return objecttype.internal_type(self) + def T(self, object_type): if not object_type: return @@ -45,7 +48,7 @@ class Schema(object): object_type, (BaseType, ClassType)) or isinstance( object_type, BaseType): if object_type not in self._types: - internal_type = object_type.internal_type(self) + internal_type = self.get_internal_type(object_type) self._types[object_type] = internal_type is_objecttype = inspect.isclass( object_type) and issubclass(object_type, ClassType) diff --git a/graphene/core/types/argument.py b/graphene/core/types/argument.py index 0892c446..0ef7686a 100644 --- a/graphene/core/types/argument.py +++ b/graphene/core/types/argument.py @@ -5,10 +5,10 @@ from itertools import chain from graphql.core.type import GraphQLArgument from ...utils import ProxySnakeDict, to_camel_case -from .base import ArgumentType, BaseType, OrderedType +from .base import ArgumentType, GroupNamedType, NamedType, OrderedType -class Argument(OrderedType): +class Argument(NamedType, OrderedType): def __init__(self, type, description=None, default=None, name=None, _creation_counter=None): @@ -27,27 +27,11 @@ class Argument(OrderedType): return self.name -class ArgumentsGroup(BaseType): +class ArgumentsGroup(GroupNamedType): def __init__(self, *args, **kwargs): arguments = to_arguments(*args, **kwargs) - self.arguments = OrderedDict([(arg.name, arg) for arg in arguments]) - - def internal_type(self, schema): - return OrderedDict([(arg.name, schema.T(arg)) - for arg in self.arguments.values()]) - - def __len__(self): - return len(self.arguments) - - def __iter__(self): - return iter(self.arguments) - - def __contains__(self, *args): - return self.arguments.__contains__(*args) - - def __getitem__(self, *args): - return self.arguments.__getitem__(*args) + super(ArgumentsGroup, self).__init__(*arguments) def to_arguments(*args, **kwargs): diff --git a/graphene/core/types/base.py b/graphene/core/types/base.py index 2b4078e4..920963b6 100644 --- a/graphene/core/types/base.py +++ b/graphene/core/types/base.py @@ -1,4 +1,5 @@ -from functools import total_ordering +from collections import OrderedDict +from functools import total_ordering, partial import six @@ -126,3 +127,30 @@ class FieldType(MirroredType): class MountedType(FieldType, ArgumentType): pass + + +class NamedType(BaseType): + pass + + +class GroupNamedType(BaseType): + def __init__(self, *types): + self.types = types + + def get_named_type(self, schema, type): + return type.name or type.attname, schema.T(type) + + def internal_type(self, schema): + return OrderedDict(map(partial(self.get_named_type, schema), self.types)) + + def __len__(self): + return len(self.types) + + def __iter__(self): + return iter(self.types) + + def __contains__(self, *args): + return self.types.__contains__(*args) + + def __getitem__(self, *args): + return self.types.__getitem__(*args) diff --git a/graphene/core/types/field.py b/graphene/core/types/field.py index c3fa712f..3be90c74 100644 --- a/graphene/core/types/field.py +++ b/graphene/core/types/field.py @@ -9,11 +9,11 @@ from ..classtypes.base import FieldsClassType from ..classtypes.inputobjecttype import InputObjectType from ..classtypes.mutation import Mutation from .argument import ArgumentsGroup, snake_case_args -from .base import LazyType, MountType, OrderedType +from .base import LazyType, NamedType, MountType, OrderedType from .definitions import NonNull -class Field(OrderedType): +class Field(NamedType, OrderedType): def __init__( self, type, description=None, args=None, name=None, resolver=None, @@ -117,7 +117,7 @@ class Field(OrderedType): return hash((self.creation_counter, self.object_type)) -class InputField(OrderedType): +class InputField(NamedType, OrderedType): def __init__(self, type, description=None, default=None, name=None, _creation_counter=None, required=False): diff --git a/graphene/core/types/tests/test_field.py b/graphene/core/types/tests/test_field.py index 8253ed20..bb0bcf2c 100644 --- a/graphene/core/types/tests/test_field.py +++ b/graphene/core/types/tests/test_field.py @@ -98,9 +98,10 @@ def test_field_string_reference(): def test_field_custom_arguments(): field = Field(None, name='my_customName', p=String()) + schema = Schema() args = field.arguments - assert 'p' in args + assert 'p' in schema.T(args) def test_inputfield_internal_type(): diff --git a/graphene/relay/tests/test_mutations.py b/graphene/relay/tests/test_mutations.py index 4356a1ec..02287725 100644 --- a/graphene/relay/tests/test_mutations.py +++ b/graphene/relay/tests/test_mutations.py @@ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation) def test_mutation_arguments(): assert ChangeNumber.arguments - assert list(ChangeNumber.arguments) == ['input'] - assert 'input' in ChangeNumber.arguments + assert 'input' in schema.T(ChangeNumber.arguments) inner_type = ChangeNumber.input_type client_mutation_id_field = inner_type._meta.fields_map[ 'client_mutation_id']