From ab72393e668f12255b41b16defb7a62c7f8d1ca2 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 4 Jun 2016 15:22:10 -0700 Subject: [PATCH] Added InputField, InputObjectType. Improved Field implementation --- examples/complex_example.py | 18 ++- graphene/__init__.py | 4 +- graphene/types/__init__.py | 5 +- graphene/types/field.py | 113 +++++++++++++------ graphene/types/inputobjecttype.py | 43 +++++++ graphene/types/objecttype.py | 1 - graphene/types/proxy.py | 25 +++- graphene/types/tests/test_field.py | 10 +- graphene/types/tests/test_inputobjecttype.py | 58 ++++++++++ graphene/types/tests/test_interface.py | 2 +- graphene/types/tests/test_objecttype.py | 2 +- graphene/utils/is_graphene_type.py | 2 + 12 files changed, 238 insertions(+), 45 deletions(-) create mode 100644 graphene/types/inputobjecttype.py create mode 100644 graphene/types/tests/test_inputobjecttype.py diff --git a/examples/complex_example.py b/examples/complex_example.py index dd801fff..79f52f0b 100644 --- a/examples/complex_example.py +++ b/examples/complex_example.py @@ -13,7 +13,7 @@ class Address(graphene.ObjectType): class Query(graphene.ObjectType): address = graphene.Field(Address, geo=graphene.Argument(GeoInput)) - def resolve_address(self, args, info): + def resolve_address(self, args, context, info): geo = args.get('geo') return Address(latlng="({},{})".format(geo.get('lat'), geo.get('lng'))) @@ -27,5 +27,17 @@ query = ''' } ''' -result = schema.execute(query) -print(result.data['address']['latlng']) + +def test_query(): + result = schema.execute(query) + assert not result.errors + assert result.data == { + 'address': { + 'latlng': "(32.2,12.0)", + } + } + + +if __name__ == '__main__': + result = schema.execute(query) + print(result.data['address']['latlng']) diff --git a/graphene/__init__.py b/graphene/__init__.py index e27c3979..f86859f4 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -1,8 +1,10 @@ from .types import ( ObjectType, + InputObjectType, Interface, implements, Field, + InputField, Schema, Scalar, String, ID, Int, Float, Boolean, @@ -12,4 +14,4 @@ from .types import ( ) from .utils.resolve_only_args import resolve_only_args -__all__ = ['ObjectType', 'Interface', 'implements', 'Field', 'Schema', 'Scalar', 'String', 'ID', 'Int', 'Float', 'Enum', 'Boolean', 'List','NonNull', 'Argument','resolve_only_args'] +__all__ = ['ObjectType', 'InputObjectType', 'Interface', 'implements', 'Field', 'InputField', 'Schema', 'Scalar', 'String', 'ID', 'Int', 'Float', 'Enum', 'Boolean', 'List','NonNull', 'Argument','resolve_only_args'] diff --git a/graphene/types/__init__.py b/graphene/types/__init__.py index 9bf5b326..d1ea4105 100644 --- a/graphene/types/__init__.py +++ b/graphene/types/__init__.py @@ -1,10 +1,11 @@ from .objecttype import ObjectType, implements +from .inputobjecttype import InputObjectType from .interface import Interface from .scalars import Scalar, String, ID, Int, Float, Boolean from .schema import Schema from .structures import List, NonNull from .enum import Enum -from .field import Field +from .field import Field, InputField from .argument import Argument -__all__ = ['ObjectType', 'Interface', 'implements', 'Enum', 'Field', 'Schema', 'Scalar', 'String', 'ID', 'Int', 'Float', 'Boolean', 'List', 'NonNull', 'Argument'] +__all__ = ['ObjectType', 'InputObjectType', 'Interface', 'implements', 'Enum', 'Field', 'InputField', 'Schema', 'Scalar', 'String', 'ID', 'Int', 'Float', 'Boolean', 'List', 'NonNull', 'Argument'] diff --git a/graphene/types/field.py b/graphene/types/field.py index 37e564f2..b761ab2d 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -1,44 +1,17 @@ import inspect -from graphql import GraphQLField +from graphql.type import GraphQLField, GraphQLInputObjectField from graphql.utils.assert_valid_name import assert_valid_name from .objecttype import ObjectType +from .inputobjecttype import InputObjectType from .interface import Interface from ..utils.orderedtype import OrderedType from ..utils.str_converters import to_camel_case from .argument import to_arguments -class Field(GraphQLField, OrderedType): - __slots__ = ('_name', '_type', '_args', '_resolver', 'deprecation_reason', 'description', 'source', 'attname', 'parent', 'creation_counter') - - def __init__(self, type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, _creation_counter=None, **extra_args): - self.name = name - self.attname = None - self.parent = None - self.type = type - self.args = to_arguments(args, extra_args) - assert not (source and resolver), ('You cannot have a source ' - 'and a resolver at the same time') - - self.resolver = resolver - self.source = source - self.deprecation_reason = deprecation_reason - self.description = description - OrderedType.__init__(self, _creation_counter=_creation_counter) - - def contribute_to_class(self, cls, attname): - assert issubclass(cls, (ObjectType, Interface)), 'Field {} cannot be mounted in {}'.format( - self, - cls - ) - self.attname = attname - self.parent = cls - add_field = getattr(cls._meta.graphql_type, "add_field", None) - assert add_field, "Field {} cannot be mounted in {}".format(self, cls) - add_field(self) - +class AbstractField(object): @property def name(self): return self._name or to_camel_case(self.attname) @@ -52,14 +25,54 @@ class Field(GraphQLField, OrderedType): @property def type(self): from ..utils.get_graphql_type import get_graphql_type + from .structures import NonNull if inspect.isfunction(self._type): - return get_graphql_type(self._type()) - return get_graphql_type(self._type) + _type = self._type() + else: + _type = self._type + + if self.required: + return NonNull(_type) + return get_graphql_type(_type) @type.setter def type(self, type): self._type = type + +class Field(AbstractField, GraphQLField, OrderedType): + __slots__ = ('_name', '_type', '_args', '_resolver', 'deprecation_reason', 'description', 'source', 'attname', 'parent', 'creation_counter', 'required') + + def __init__(self, type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=None, **extra_args): + self.name = name + self.attname = None + self.parent = None + self.type = type + self.args = to_arguments(args, extra_args) + assert not (source and resolver), ('You cannot have a source ' + 'and a resolver at the same time') + + self.resolver = resolver + self.source = source + self.required = required + self.deprecation_reason = deprecation_reason + self.description = description + OrderedType.__init__(self, _creation_counter=_creation_counter) + + def mount_error_message(self, where): + return 'Field "{}" can only be mounted in ObjectType or Interface, received {}.'.format( + self, + where.__name__ + ) + + def contribute_to_class(self, cls, attname): + assert issubclass(cls, (ObjectType, Interface)), self.mount_error_message(cls) + self.attname = attname + self.parent = cls + add_field = getattr(cls._meta.graphql_type, "add_field", None) + assert add_field, self.mount_error_message(cls) + add_field(self) + @property def resolver(self): def default_resolver(root, args, context, info): @@ -86,6 +99,7 @@ class Field(GraphQLField, OrderedType): source=self.source, deprecation_reason=self.deprecation_reason, name=self._name, + required=self.required, description=self.description, _creation_counter=self.creation_counter, ) @@ -97,3 +111,38 @@ class Field(GraphQLField, OrderedType): if not self.parent: return 'Not bounded field' return "{}.{}".format(self.parent._meta.graphql_type, self.attname) + + +class InputField(AbstractField, GraphQLInputObjectField, OrderedType): + __slots__ = ('_name', '_type', 'default_value', 'description', 'required') + + def __init__(self, type, default_value=None, description=None, name=None, required=False, _creation_counter=None): + self.name = name + self.type = type + self.default_value = default_value + self.description = description + self.required = required + OrderedType.__init__(self, _creation_counter=_creation_counter) + + def mount_error_message(self, where): + return 'InputField {} can only be mounted in InputObjectType classes, received {}.'.format( + self, + where.__name__ + ) + + def contribute_to_class(self, cls, attname): + assert issubclass(cls, (InputObjectType)), self.mount_error_message(cls) + self.attname = attname + self.parent = cls + add_field = getattr(cls._meta.graphql_type, "add_field", None) + assert add_field, self.mount_error_message(cls) + add_field(self) + + def __copy__(self): + return InputField( + type=self._type, + name=self._name, + required=self.required, + default_value=self.default_value, + description=self.description, + ) diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py new file mode 100644 index 00000000..93de386d --- /dev/null +++ b/graphene/types/inputobjecttype.py @@ -0,0 +1,43 @@ +import six + +from graphql import GraphQLInputObjectType + +from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap + + +class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType): + __slots__ = ('graphene_type', '_name', '_description', '_fields', '_field_map') + + +class InputObjectTypeMeta(ClassTypeMeta): + + def get_options(cls, meta): + options = cls.options_class( + meta, + name=None, + description=None, + graphql_type=None, + ) + options.valid_attrs = ['graphql_type', 'name', 'description', 'abstract'] + return options + + def construct_graphql_type(cls, bases): + if not cls._meta.graphql_type and not cls._meta.abstract: + from ..utils.get_graphql_type import get_graphql_type + from ..utils.is_graphene_type import is_graphene_type + + inherited_types = [ + base._meta.graphql_type for base in bases if is_graphene_type(base) + ] + + cls._meta.graphql_type = GrapheneInputObjectType( + graphene_type=cls, + name=cls._meta.name or cls.__name__, + description=cls._meta.description, + fields=FieldMap(cls, bases=filter(None, inherited_types)), + ) + + +class InputObjectType(six.with_metaclass(InputObjectTypeMeta)): + class Meta: + abstract = True diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 8d938429..18b43b1f 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -39,7 +39,6 @@ class GrapheneObjectType(GrapheneFieldsType, GraphQLObjectType): self._provided_interfaces.append(graphql_type) - class ObjectTypeMeta(ClassTypeMeta): def get_options(cls, meta): diff --git a/graphene/types/proxy.py b/graphene/types/proxy.py index 754b42bc..f6da7d1a 100644 --- a/graphene/types/proxy.py +++ b/graphene/types/proxy.py @@ -1,6 +1,11 @@ -from .field import Field +from .field import Field, InputField from .argument import Argument + +from .objecttype import ObjectType +from .interface import Interface +from .inputobjecttype import InputObjectType + from ..utils.orderedtype import OrderedType @@ -21,6 +26,14 @@ class TypeProxy(OrderedType): **self.kwargs ) + def as_inputfield(self): + return InputField( + self.get_type(), + *self.args, + _creation_counter=self.creation_counter, + **self.kwargs + ) + def as_argument(self): return Argument( self.get_type(), @@ -30,5 +43,11 @@ class TypeProxy(OrderedType): ) def contribute_to_class(self, cls, attname): - field = self.as_field() - return field.contribute_to_class(cls, attname) + if issubclass(cls, (ObjectType, Interface)): + inner = self.as_field() + elif issubclass(cls, (InputObjectType)): + inner = self.as_inputfield() + else: + raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls)) + + return inner.contribute_to_class(cls, attname) diff --git a/graphene/types/tests/test_field.py b/graphene/types/tests/test_field.py index 3d890ec3..ec9168f7 100644 --- a/graphene/types/tests/test_field.py +++ b/graphene/types/tests/test_field.py @@ -1,7 +1,7 @@ import pytest import copy -from graphql import GraphQLString, GraphQLField, GraphQLInt +from graphql import GraphQLString, GraphQLField, GraphQLInt, GraphQLNonNull from ..field import Field from ..argument import Argument @@ -14,6 +14,14 @@ def test_field(): assert isinstance(field, GraphQLField) assert field.name == "name" assert field.description == "description" + assert field.type == GraphQLString + + +def test_field_required(): + field = Field(GraphQLString, required=True) + assert isinstance(field, GraphQLField) + assert isinstance(field.type, GraphQLNonNull) + assert field.type.of_type == GraphQLString def test_field_wrong_name(): diff --git a/graphene/types/tests/test_inputobjecttype.py b/graphene/types/tests/test_inputobjecttype.py new file mode 100644 index 00000000..f6546b2b --- /dev/null +++ b/graphene/types/tests/test_inputobjecttype.py @@ -0,0 +1,58 @@ +import pytest + +from graphql import GraphQLObjectType, GraphQLField, GraphQLString, GraphQLInputObjectType + +from ..objecttype import ObjectType +from ..inputobjecttype import InputObjectType +from ..field import Field, InputField +from ..scalars import String + + +def test_generate_inputobjecttype(): + class MyObjectType(InputObjectType): + '''Documentation''' + pass + + graphql_type = MyObjectType._meta.graphql_type + assert isinstance(graphql_type, GraphQLInputObjectType) + assert graphql_type.name == "MyObjectType" + assert graphql_type.description == "Documentation" + + +def test_generate_inputobjecttype_with_meta(): + class MyObjectType(InputObjectType): + class Meta: + name = 'MyOtherObjectType' + description = 'Documentation' + + graphql_type = MyObjectType._meta.graphql_type + assert isinstance(graphql_type, GraphQLInputObjectType) + assert graphql_type.name == "MyOtherObjectType" + assert graphql_type.description == "Documentation" + + +def test_empty_inputobjecttype_has_meta(): + class MyObjectType(InputObjectType): + pass + + assert MyObjectType._meta + + +def test_generate_objecttype_with_fields(): + class MyObjectType(InputObjectType): + field = InputField(GraphQLString) + + graphql_type = MyObjectType._meta.graphql_type + fields = graphql_type.get_fields() + assert 'field' in fields + assert isinstance(fields['field'], InputField) + + +def test_generate_objecttype_with_graphene_fields(): + class MyObjectType(InputObjectType): + field = String() + + graphql_type = MyObjectType._meta.graphql_type + fields = graphql_type.get_fields() + assert 'field' in fields + assert isinstance(fields['field'], InputField) diff --git a/graphene/types/tests/test_interface.py b/graphene/types/tests/test_interface.py index a1f42660..620d3998 100644 --- a/graphene/types/tests/test_interface.py +++ b/graphene/types/tests/test_interface.py @@ -81,4 +81,4 @@ def test_interface_add_fields_in_reused_graphql_type(): class Meta: graphql_type = MyGraphQLType - assert "cannot be mounted in" in str(excinfo.value) + assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneInterface.""" == str(excinfo.value) diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index 8eb218ba..374546b5 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -129,7 +129,7 @@ def test_objecttype_add_fields_in_reused_graphql_type(): class Meta: graphql_type = MyGraphQLType - assert "cannot be mounted in" in str(excinfo.value) + assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneObjectType.""" == str(excinfo.value) def test_objecttype_graphql_interface(): diff --git a/graphene/utils/is_graphene_type.py b/graphene/utils/is_graphene_type.py index 87c79835..145f08c5 100644 --- a/graphene/utils/is_graphene_type.py +++ b/graphene/utils/is_graphene_type.py @@ -1,5 +1,6 @@ import inspect from ..types.objecttype import ObjectType +from ..types.inputobjecttype import InputObjectType from ..types.interface import Interface from ..types.scalars import Scalar from ..types.enum import Enum @@ -10,6 +11,7 @@ def is_graphene_type(_type): return issubclass(_type, ( Interface, ObjectType, + InputObjectType, Scalar, Enum ))