In work version graphene new types

This commit is contained in:
Syrus Akbary 2015-11-10 01:29:38 -08:00
parent 9bab0d9d6f
commit 3c65deb313
18 changed files with 85 additions and 279 deletions

View File

@ -1,258 +1,38 @@
import inspect
from functools import total_ordering, wraps
import six
from graphene.core.scalars import GraphQLSkipField
from graphene.core.types import BaseObjectType, InputObjectType
from graphene.utils import ProxySnakeDict, enum_to_graphql_enum, to_camel_case
from graphql.core.type import (GraphQLArgument, GraphQLBoolean, GraphQLField,
GraphQLFloat, GraphQLID,
GraphQLInputObjectField, GraphQLInt,
GraphQLList, GraphQLNonNull, GraphQLString)
try:
from enum import Enum
except ImportError:
class Enum(object):
pass
from .types.field import Field
from .types.scalars import String, Int, Boolean, ID, Float
from .types.definitions import List, NonNull
class Empty(object):
class DeprecatedField(object):
def __init__(self, *args, **kwargs):
print("Using {} is not longer supported".format(self.__class__.__name__))
kwargs['resolver'] = kwargs.pop('resolve', None)
return super(DeprecatedField, self).__init__(*args, **kwargs)
class StringField(DeprecatedField, String):
pass
@total_ordering
class Field(object):
SKIP = GraphQLSkipField
creation_counter = 0
required = False
def __init__(self, field_type, name=None, resolve=None, required=False, args=None, description='', default=None, **extra_args):
self.field_type = field_type
self.resolve_fn = resolve
self.required = self.required or required
self.args = args or {}
self.extra_args = extra_args
self._type = None
self.name = name
self.description = description or self.__doc__
self.object_type = None
self.default = default
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
def get_default(self):
return self.default
def contribute_to_class(self, cls, name, add=True):
if not self.name:
self.name = to_camel_case(name)
self.attname = name
self.object_type = cls
if isinstance(self.field_type, Field) and not self.field_type.object_type:
self.field_type.contribute_to_class(cls, name, False)
if add:
cls._meta.add_field(self)
def resolve(self, instance, args, info):
schema = info and getattr(info.schema, 'graphene_schema', None)
resolve_fn = self.get_resolve_fn(schema)
if resolve_fn:
return resolve_fn(instance, ProxySnakeDict(args), info)
else:
return getattr(instance, self.attname, self.get_default())
def get_resolve_fn(self, schema):
object_type = self.get_object_type(schema)
if object_type and object_type._meta.is_mutation:
return object_type.mutate
elif self.resolve_fn:
return self.resolve_fn
else:
custom_resolve_fn_name = 'resolve_%s' % self.attname
if hasattr(self.object_type, custom_resolve_fn_name):
resolve_fn = getattr(self.object_type, custom_resolve_fn_name)
@wraps(resolve_fn)
def custom_resolve_fn(instance, args, info):
return resolve_fn(instance, args, info)
return custom_resolve_fn
def get_object_type(self, schema):
field_type = self.field_type
if inspect.isfunction(field_type):
field_type = field_type(self)
_is_class = inspect.isclass(field_type)
if isinstance(field_type, Field):
return field_type.get_object_type(schema)
if _is_class and issubclass(field_type, BaseObjectType):
return field_type
elif isinstance(field_type, six.string_types):
if field_type == 'self':
return self.object_type
else:
return schema.get_type(field_type)
def type_wrapper(self, field_type):
if self.required:
field_type = GraphQLNonNull(field_type)
return field_type
def internal_type(self, schema):
field_type = self.field_type
_is_class = inspect.isclass(field_type)
if isinstance(field_type, Field):
field_type = self.field_type.internal_type(schema)
elif _is_class and issubclass(field_type, Enum):
field_type = enum_to_graphql_enum(field_type)
else:
object_type = self.get_object_type(schema)
if object_type:
field_type = schema.T(object_type)
field_type = self.type_wrapper(field_type)
return field_type
def internal_field(self, schema):
if not self.object_type:
raise Exception(
'Field could not be constructed in a non graphene.ObjectType or graphene.Interface')
extra_args = self.extra_args.copy()
for arg_name, arg_value in self.extra_args.items():
if isinstance(arg_value, GraphQLArgument):
self.args[arg_name] = arg_value
del extra_args[arg_name]
if extra_args != {}:
raise TypeError("Field %s.%s initiated with invalid args: %s" % (
self.object_type,
self.attname,
','.join(extra_args.keys())
))
args = self.args
object_type = self.get_object_type(schema)
if object_type and object_type._meta.is_mutation:
assert not self.args, 'Arguments provided for mutations are defined in Input class in Mutation'
args = object_type.get_input_type().fields_as_arguments(schema)
internal_type = self.internal_type(schema)
if not internal_type:
raise Exception("Internal type for field %s is None" % self)
description = self.description
resolve_fn = self.get_resolve_fn(schema)
if resolve_fn:
description = resolve_fn.__doc__ or description
@wraps(resolve_fn)
def resolver(*args):
return self.resolve(*args)
else:
resolver = self.resolve
if issubclass(self.object_type, InputObjectType):
return GraphQLInputObjectField(
internal_type,
description=description,
)
return GraphQLField(
internal_type,
description=description,
args=args,
resolver=resolver,
)
def __str__(self):
""" Return "object_type.name". """
return '%s.%s' % (self.object_type.__name__, self.attname)
def __repr__(self):
"""
Displays the module, class and name of the field.
"""
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
name = getattr(self, 'attname', None)
if name is not None:
return '<%s: %s>' % (path, name)
return '<%s>' % path
def __eq__(self, other):
# Needed for @total_ordering
if isinstance(other, Field):
return self.creation_counter == other.creation_counter and \
self.object_type == other.object_type
return NotImplemented
def __lt__(self, other):
# This is needed because bisect does not take a comparison function.
if isinstance(other, Field):
return self.creation_counter < other.creation_counter
return NotImplemented
def __hash__(self):
return hash((self.creation_counter, self.object_type))
def __copy__(self):
# We need to avoid hitting __reduce__, so define this
# slightly weird copy construct.
obj = Empty()
obj.__class__ = self.__class__
obj.__dict__ = self.__dict__.copy()
if self.field_type == 'self':
obj.field_type = self.object_type
return obj
class IntField(DeprecatedField, Int):
pass
class LazyField(Field):
def inner_field(self, schema):
return self.get_field(schema)
def internal_type(self, schema):
return self.inner_field(schema).internal_type(schema)
def internal_field(self, schema):
return self.inner_field(schema).internal_field(schema)
class BooleanField(DeprecatedField, Boolean):
pass
class TypeField(Field):
def __init__(self, *args, **kwargs):
super(TypeField, self).__init__(self.field_type, *args, **kwargs)
class IDField(DeprecatedField, ID):
pass
class StringField(TypeField):
field_type = GraphQLString
class FloatField(DeprecatedField, Float):
pass
class IntField(TypeField):
field_type = GraphQLInt
class ListField(DeprecatedField, List):
pass
class BooleanField(TypeField):
field_type = GraphQLBoolean
class IDField(TypeField):
field_type = GraphQLID
class FloatField(TypeField):
field_type = GraphQLFloat
class ListField(Field):
def type_wrapper(self, field_type):
return GraphQLList(field_type)
class NonNullField(Field):
def type_wrapper(self, field_type):
return GraphQLNonNull(field_type)
class NonNullField(DeprecatedField, NonNull):
pass

View File

@ -1,3 +1,4 @@
import inspect
from collections import OrderedDict
from graphene import signals
@ -6,6 +7,7 @@ from graphql.core.execution.middlewares.sync import \
SynchronousExecutionMiddleware
from graphql.core.type import GraphQLSchema as _GraphQLSchema
from graphql.core.utils.introspection_query import introspection_query
from graphene.core.types.base import BaseType
class GraphQLSchema(_GraphQLSchema):
@ -34,13 +36,15 @@ class Schema(object):
def T(self, object_type):
if not object_type:
return
if object_type not in self._types:
internal_type = object_type.internal_type(self)
self._types[object_type] = internal_type
name = getattr(internal_type, 'name', None)
if name:
self._types_names[name] = object_type
return self._types[object_type]
# if inspect.isclass(object_type) and issubclass(object_type, BaseType):
if True:
if object_type not in self._types:
internal_type = object_type.internal_type(self)
self._types[object_type] = internal_type
name = getattr(internal_type, 'name', None)
if name:
self._types_names[name] = object_type
return self._types[object_type]
@property
def query(self):

View File

@ -0,0 +1 @@
from .objecttype import ObjectTypeMeta, BaseObjectType, ObjectType, Interface, Mutation, InputObjectType

View File

@ -31,7 +31,7 @@ def to_arguments(*args, **kwargs):
elif isinstance(arg, ArgumentType):
argument = arg.as_argument()
else:
raise ValueError('Unknown argument value type %r' % arg)
raise ValueError('Unknown argument %s=%r' % (name, arg))
if name:
argument.name = to_camel_case(name)

View File

@ -1,9 +1,14 @@
from functools import total_ordering
from ..types import BaseObjectType, InputObjectType
class BaseType(object):
@classmethod
def internal_type(cls, schema):
return getattr(cls, 'T', None)
@total_ordering
class OrderedType(object):
class OrderedType(BaseType):
creation_counter = 0
def __init__(self, _creation_counter=None):
@ -38,10 +43,6 @@ class MirroredType(OrderedType):
self.args = args
self.kwargs = kwargs
@classmethod
def internal_type(cls, schema):
return getattr(cls, 'T', None)
class ArgumentType(MirroredType):
def as_argument(self):
@ -51,6 +52,7 @@ class ArgumentType(MirroredType):
class FieldType(MirroredType):
def contribute_to_class(self, cls, name):
from ..types import BaseObjectType, InputObjectType
if issubclass(cls, InputObjectType):
inputfield = self.as_inputfield()
return inputfield.contribute_to_class(cls, name)
@ -60,11 +62,11 @@ class FieldType(MirroredType):
def as_field(self):
from .field import Field
return Field(self.__class__, _creation_counter=self.creation_counter, *self.args, **self.kwargs)
return Field(self, _creation_counter=self.creation_counter, *self.args, **self.kwargs)
def as_inputfield(self):
from .field import InputField
return InputField(self.__class__, _creation_counter=self.creation_counter, *self.args, **self.kwargs)
return InputField(self, _creation_counter=self.creation_counter, *self.args, **self.kwargs)
class MountedType(FieldType, ArgumentType):

View File

@ -9,6 +9,10 @@ from ...utils import to_camel_case
from ..types import BaseObjectType, InputObjectType
class Empty(object):
pass
class Field(OrderedType):
def __init__(self, type, description=None, args=None, name=None, resolver=None, *args_list, **kwargs):
_creation_counter = kwargs.pop('_creation_counter', None)
@ -18,6 +22,7 @@ class Field(OrderedType):
self.description = description
args = OrderedDict(args or {}, **kwargs)
self.arguments = to_arguments(*args_list, **args)
self.object_type = None
self.resolver = resolver
def contribute_to_class(self, cls, attname):
@ -32,7 +37,7 @@ class Field(OrderedType):
@property
def resolver(self):
return self._resolver
return self._resolver or self.get_resolver_fn()
@resolver.setter
def resolver(self, value):
@ -51,8 +56,6 @@ class Field(OrderedType):
def internal_type(self, schema):
resolver = self.resolver
description = self.description
if not resolver:
resolver = self.get_resolver_fn()
if not description and resolver:
description = resolver.__doc__
@ -65,6 +68,26 @@ class Field(OrderedType):
return OrderedDict([(arg.name, schema.T(arg)) for arg in self.arguments])
def __copy__(self):
obj = Empty()
obj.__class__ = self.__class__
obj.__dict__ = self.__dict__.copy()
obj.object_type = None
return obj
def __repr__(self):
"""
Displays the module, class and name of the field.
"""
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
name = getattr(self, 'attname', None)
if name is not None:
return '<%s: %s>' % (path, name)
return '<%s>' % path
def __hash__(self):
return hash((self.creation_counter, self.object_type))
class InputField(OrderedType):
def __init__(self, type, description=None, default=None, name=None, _creation_counter=None):

View File

@ -7,6 +7,7 @@ import six
from graphene import signals
from graphene.core.options import Options
from graphene.core.types.base import BaseType
from graphql.core.type import (GraphQLArgument, GraphQLInputObjectType,
GraphQLInterfaceType, GraphQLObjectType)
@ -125,7 +126,7 @@ class ObjectTypeMeta(type):
setattr(cls, name, value)
class BaseObjectType(object):
class BaseObjectType(BaseType):
def __new__(cls, *args, **kwargs):
if cls._meta.is_interface:
@ -185,7 +186,7 @@ class BaseObjectType(object):
@classmethod
def internal_type(cls, schema):
fields = lambda: OrderedDict([(f.name, f.internal_field(schema))
fields = lambda: OrderedDict([(f.name, schema.T(f))
for f in cls._meta.fields])
if cls._meta.is_interface:
return GraphQLInterfaceType(
@ -222,7 +223,7 @@ class InputObjectType(ObjectType):
@classmethod
def internal_type(cls, schema):
fields = lambda: OrderedDict([(f.name, f.internal_field(schema))
fields = lambda: OrderedDict([(f.name, schema.T(f))
for f in cls._meta.fields])
return GraphQLInputObjectType(
cls._meta.type_name,

View File

@ -21,7 +21,7 @@ def test_orderedtype_different():
assert b > a
@patch('graphene.core.ntypes.field.Field')
@patch('graphene.core.types.field.Field')
def test_type_as_field_called(Field):
resolver = lambda x: x
a = MountedType(2, description='A', resolver=resolver)
@ -29,7 +29,7 @@ def test_type_as_field_called(Field):
Field.assert_called_with(MountedType, 2, _creation_counter=a.creation_counter, description='A', resolver=resolver)
@patch('graphene.core.ntypes.argument.Argument')
@patch('graphene.core.types.argument.Argument')
def test_type_as_argument_called(Argument):
a = MountedType(2, description='A')
a.as_argument()

View File

@ -10,9 +10,7 @@ from graphql.core.type import (GraphQLBoolean, GraphQLField, GraphQLID,
GraphQLInt, GraphQLNonNull, GraphQLString)
class ObjectType(object):
_meta = Options()
class ot(ObjectType):
def resolve_customdoc(self, *args, **kwargs):
'''Resolver documentation'''
return None
@ -20,22 +18,20 @@ class ObjectType(object):
def __str__(self):
return "ObjectType"
ot = ObjectType
schema = Schema()
def test_field_no_contributed_raises_error():
f = Field(GraphQLString)
with raises(Exception) as excinfo:
f.internal_field(schema)
schema.T(f)
def test_field_type():
f = Field(GraphQLString)
f.contribute_to_class(ot, 'field_name')
assert isinstance(f.internal_field(schema), GraphQLField)
assert f.internal_type(schema) == GraphQLString
assert isinstance(schema.T(f), GraphQLField)
assert schema.T(f).type == GraphQLString
def test_field_name_automatic_camelcase():

View File

@ -4,8 +4,7 @@ from graphene.core.fields import Field, ListField, StringField
from graphene.core.schema import Schema
from graphene.core.types import Interface, ObjectType
from graphql.core import graphql
from graphql.core.type import (GraphQLInterfaceType, GraphQLObjectType,
GraphQLSchema)
from graphql.core.type import GraphQLSchema
class Character(Interface):
@ -38,8 +37,8 @@ Human_type = schema.T(Human)
def test_type():
assert Human._meta.fields_map['name'].resolve(
Human(object()), None, None) == 'Peter'
assert Human._meta.fields_map['name'].resolver(
Human(object()), {}, None) == 'Peter'
def test_query():