Merge branch 'refs/heads/features/plugins-autocamelcase' into features/django-debug

This commit is contained in:
Syrus Akbary 2015-12-06 16:39:32 -08:00
commit bd35fcee6c
14 changed files with 83 additions and 133 deletions

View File

@ -11,7 +11,7 @@ from .core import (
Interface, Interface,
Mutation, Mutation,
Scalar, Scalar,
BaseType, InstanceType,
LazyType, LazyType,
Argument, Argument,
Field, Field,
@ -51,7 +51,7 @@ __all__ = [
'NonNull', 'NonNull',
'signals', 'signals',
'Schema', 'Schema',
'BaseType', 'InstanceType',
'LazyType', 'LazyType',
'ObjectType', 'ObjectType',
'InputObjectType', 'InputObjectType',

View File

@ -11,7 +11,7 @@ from .classtypes import (
) )
from .types import ( from .types import (
BaseType, InstanceType,
LazyType, LazyType,
Argument, Argument,
Field, Field,
@ -35,7 +35,7 @@ __all__ = [
'List', 'List',
'NonNull', 'NonNull',
'Schema', 'Schema',
'BaseType', 'InstanceType',
'LazyType', 'LazyType',
'ObjectType', 'ObjectType',
'InputObjectType', 'InputObjectType',

View File

@ -1,7 +1,7 @@
import copy import copy
import inspect import inspect
from functools import partial
from collections import OrderedDict from collections import OrderedDict
from functools import partial
import six import six

View File

@ -10,9 +10,9 @@ from graphql.core.utils.schema_printer import print_schema
from graphene import signals from graphene import signals
from ..plugins import CamelCase, Plugin
from .classtypes.base import ClassType from .classtypes.base import ClassType
from .types.base import BaseType from .types.base import InstanceType
from ..plugins import Plugin, CamelCase
class GraphQLSchema(_GraphQLSchema): class GraphQLSchema(_GraphQLSchema):
@ -50,27 +50,27 @@ class Schema(object):
plugin.contribute_to_schema(self) plugin.contribute_to_schema(self)
self.plugins.append(plugin) self.plugins.append(plugin)
def get_internal_type(self, objecttype): def get_default_namedtype_name(self, value):
for plugin in self.plugins: for plugin in self.plugins:
objecttype = plugin.transform_type(objecttype) if not hasattr(plugin, 'get_default_namedtype_name'):
return objecttype.internal_type(self) continue
value = plugin.get_default_namedtype_name(value)
return value
def T(self, object_type): def T(self, _type):
if not object_type: if not _type:
return return
if inspect.isclass(object_type) and issubclass( is_classtype = inspect.isclass(_type) and issubclass(_type, ClassType)
object_type, (BaseType, ClassType)) or isinstance( is_instancetype = isinstance(_type, InstanceType)
object_type, BaseType): if is_classtype or is_instancetype:
if object_type not in self._types: if _type not in self._types:
internal_type = self.get_internal_type(object_type) internal_type = _type.internal_type(self)
self._types[object_type] = internal_type self._types[_type] = internal_type
is_objecttype = inspect.isclass( if is_classtype:
object_type) and issubclass(object_type, ClassType) self.register(_type)
if is_objecttype: return self._types[_type]
self.register(object_type)
return self._types[object_type]
else: else:
return object_type return _type
@property @property
def executor(self): def executor(self):

View File

@ -1,14 +1,14 @@
from .base import BaseType, LazyType, OrderedType from .base import InstanceType, LazyType, OrderedType
from .argument import Argument, ArgumentsGroup, to_arguments from .argument import Argument, ArgumentsGroup, to_arguments
from .definitions import List, NonNull from .definitions import List, NonNull
# Compatibility import # Compatibility import
from .objecttype import Interface, ObjectType, Mutation, InputObjectType from .objecttype import Interface, ObjectType, Mutation, InputObjectType
from .scalars import String, ID, Boolean, Int, Float, Scalar from .scalars import String, ID, Boolean, Int, Float
from .field import Field, InputField from .field import Field, InputField
__all__ = [ __all__ = [
'BaseType', 'InstanceType',
'LazyType', 'LazyType',
'OrderedType', 'OrderedType',
'Argument', 'Argument',
@ -26,5 +26,4 @@ __all__ = [
'ID', 'ID',
'Boolean', 'Boolean',
'Int', 'Int',
'Float', 'Float']
'Scalar']

View File

@ -11,9 +11,7 @@ class Argument(NamedType, OrderedType):
def __init__(self, type, description=None, default=None, def __init__(self, type, description=None, default=None,
name=None, _creation_counter=None): name=None, _creation_counter=None):
super(Argument, self).__init__(_creation_counter=_creation_counter) super(Argument, self).__init__(name=name, _creation_counter=_creation_counter)
self.name = name
self.attname = None
self.type = type self.type = type
self.description = description self.description = description
self.default = default self.default = default
@ -38,18 +36,18 @@ def to_arguments(*args, **kwargs):
arguments = {} arguments = {}
iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) iter_arguments = chain(kwargs.items(), [(None, a) for a in args])
for attname, arg in iter_arguments: for default_name, arg in iter_arguments:
if isinstance(arg, Argument): if isinstance(arg, Argument):
argument = arg argument = arg
elif isinstance(arg, ArgumentType): elif isinstance(arg, ArgumentType):
argument = arg.as_argument() argument = arg.as_argument()
else: else:
raise ValueError('Unknown argument %s=%r' % (attname, arg)) raise ValueError('Unknown argument %s=%r' % (default_name, arg))
if attname: if default_name:
argument.attname = attname argument.default_name = default_name
name = argument.name or argument.attname name = argument.name or argument.default_name
assert name, 'Argument in field must have a name' assert name, 'Argument in field must have a name'
assert name not in arguments, 'Found more than one Argument with same name {}'.format(name) assert name not in arguments, 'Found more than one Argument with same name {}'.format(name)
arguments[name] = argument arguments[name] = argument

View File

@ -1,19 +1,16 @@
from collections import OrderedDict from collections import OrderedDict
from functools import total_ordering, partial from functools import partial, total_ordering
import six import six
from ...utils import to_camel_case
class InstanceType(object):
def internal_type(self, schema):
raise NotImplementedError("internal_type for type {} is not implemented".format(self.__class__.__name__))
class BaseType(object): class MountType(InstanceType):
@classmethod
def internal_type(cls, schema):
return getattr(cls, 'T', None)
class MountType(BaseType):
parent = None parent = None
def mount(self, cls): def mount(self, cls):
@ -131,20 +128,27 @@ class MountedType(FieldType, ArgumentType):
pass pass
class NamedType(BaseType): class NamedType(InstanceType):
pass def __init__(self, name=None, default_name=None, *args, **kwargs):
self.name = name
self.default_name = None
super(NamedType, self).__init__(*args, **kwargs)
class GroupNamedType(BaseType): class GroupNamedType(InstanceType):
def __init__(self, *types): def __init__(self, *types):
self.types = types self.types = types
def get_named_type(self, schema, type): def get_named_type(self, schema, type):
name = type.name or type.attname name = type.name or schema.get_default_namedtype_name(type.default_name)
return name, schema.T(type) return name, schema.T(type)
def iter_types(self, schema):
return map(partial(self.get_named_type, schema), self.types)
def internal_type(self, schema): def internal_type(self, schema):
return OrderedDict(map(partial(self.get_named_type, schema), self.types)) return OrderedDict(self.iter_types(schema))
def __len__(self): def __len__(self):
return len(self.types) return len(self.types)

View File

@ -9,7 +9,7 @@ from ..classtypes.inputobjecttype import InputObjectType
from ..classtypes.mutation import Mutation from ..classtypes.mutation import Mutation
from ..exceptions import SkipField from ..exceptions import SkipField
from .argument import ArgumentsGroup, snake_case_args from .argument import ArgumentsGroup, snake_case_args
from .base import LazyType, NamedType, MountType, OrderedType, GroupNamedType from .base import GroupNamedType, LazyType, MountType, NamedType, OrderedType
from .definitions import NonNull from .definitions import NonNull
@ -19,8 +19,7 @@ class Field(NamedType, OrderedType):
self, type, description=None, args=None, name=None, resolver=None, self, type, description=None, args=None, name=None, resolver=None,
required=False, default=None, *args_list, **kwargs): required=False, default=None, *args_list, **kwargs):
_creation_counter = kwargs.pop('_creation_counter', None) _creation_counter = kwargs.pop('_creation_counter', None)
super(Field, self).__init__(_creation_counter=_creation_counter) super(Field, self).__init__(name=name, _creation_counter=_creation_counter)
self.name = name
if isinstance(type, six.string_types): if isinstance(type, six.string_types):
type = LazyType(type) type = LazyType(type)
self.required = required self.required = required
@ -37,6 +36,7 @@ class Field(NamedType, OrderedType):
cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format(
self, cls) self, cls)
self.attname = attname self.attname = attname
self.default_name = attname
self.object_type = cls self.object_type = cls
self.mount(cls) self.mount(cls)
if isinstance(self.type, MountType): if isinstance(self.type, MountType):
@ -120,7 +120,6 @@ class InputField(NamedType, OrderedType):
def __init__(self, type, description=None, default=None, def __init__(self, type, description=None, default=None,
name=None, _creation_counter=None, required=False): name=None, _creation_counter=None, required=False):
super(InputField, self).__init__(_creation_counter=_creation_counter) super(InputField, self).__init__(_creation_counter=_creation_counter)
self.name = name
if required: if required:
type = NonNull(type) type = NonNull(type)
self.type = type self.type = type
@ -132,6 +131,7 @@ class InputField(NamedType, OrderedType):
cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format(
self, cls) self, cls)
self.attname = attname self.attname = attname
self.default_name = attname
self.object_type = cls self.object_type = cls
self.mount(cls) self.mount(cls)
if isinstance(self.type, MountType): if isinstance(self.type, MountType):
@ -145,11 +145,10 @@ class InputField(NamedType, OrderedType):
class FieldsGroupType(GroupNamedType): class FieldsGroupType(GroupNamedType):
def internal_type(self, schema):
fields = [] def iter_types(self, schema):
for field in sorted(self.types): for field in sorted(self.types):
try: try:
fields.append(self.get_named_type(schema, field)) yield self.get_named_type(schema, field)
except SkipField: except SkipField:
continue continue
return OrderedDict(fields)

View File

@ -1,41 +1,30 @@
from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID,
GraphQLInt, GraphQLScalarType, GraphQLString) GraphQLInt, GraphQLString)
from .base import MountedType from .base import MountedType
class String(MountedType): class ScalarType(MountedType):
T = GraphQLString
def internal_type(self, schema):
return self._internal_type
class Int(MountedType): class String(ScalarType):
T = GraphQLInt _internal_type = GraphQLString
class Boolean(MountedType): class Int(ScalarType):
T = GraphQLBoolean _internal_type = GraphQLInt
class ID(MountedType): class Boolean(ScalarType):
T = GraphQLID _internal_type = GraphQLBoolean
class Float(MountedType): class ID(ScalarType):
T = GraphQLFloat _internal_type = GraphQLID
class Scalar(MountedType): class Float(ScalarType):
_internal_type = GraphQLFloat
@classmethod
def internal_type(cls, schema):
serialize = getattr(cls, 'serialize')
parse_literal = getattr(cls, 'parse_literal')
parse_value = getattr(cls, 'parse_value')
return GraphQLScalarType(
name=cls.__name__,
description=cls.__doc__,
serialize=serialize,
parse_value=parse_value,
parse_literal=parse_literal
)

View File

@ -27,7 +27,7 @@ def test_to_arguments():
other_kwarg=String(), other_kwarg=String(),
) )
assert [a.name or a.attname for a in arguments] == [ assert [a.name or a.default_name for a in arguments] == [
'myArg', 'otherArg', 'my_kwarg', 'other_kwarg'] 'myArg', 'otherArg', 'my_kwarg', 'other_kwarg']

View File

@ -13,7 +13,7 @@ from ..scalars import String
def test_field_internal_type(): def test_field_internal_type():
resolver = lambda *args: 'RESOLVED' resolver = lambda *args: 'RESOLVED'
field = Field(String, description='My argument', resolver=resolver) field = Field(String(), description='My argument', resolver=resolver)
class Query(ObjectType): class Query(ObjectType):
my_field = field my_field = field

View File

@ -1,9 +1,9 @@
from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID,
GraphQLInt, GraphQLScalarType, GraphQLString) GraphQLInt, GraphQLString)
from graphene.core.schema import Schema from graphene.core.schema import Schema
from ..scalars import ID, Boolean, Float, Int, Scalar, String from ..scalars import ID, Boolean, Float, Int, String
schema = Schema() schema = Schema()
@ -26,29 +26,3 @@ def test_id_scalar():
def test_float_scalar(): def test_float_scalar():
assert schema.T(Float()) == GraphQLFloat assert schema.T(Float()) == GraphQLFloat
def test_custom_scalar():
import datetime
from graphql.core.language import ast
class DateTimeScalar(Scalar):
'''DateTimeScalar Documentation'''
@staticmethod
def serialize(dt):
return dt.isoformat()
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
return datetime.datetime.strptime(
node.value, "%Y-%m-%dT%H:%M:%S.%f")
@staticmethod
def parse_value(value):
return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f")
scalar_type = schema.T(DateTimeScalar)
assert isinstance(scalar_type, GraphQLScalarType)
assert scalar_type.name == 'DateTimeScalar'
assert scalar_type.description == 'DateTimeScalar Documentation'

View File

@ -1,4 +1,5 @@
class Plugin(object): class Plugin(object):
def contribute_to_schema(self, schema): def contribute_to_schema(self, schema):
self.schema = schema self.schema = schema

View File

@ -1,22 +1,8 @@
from ..utils import to_camel_case
from .base import Plugin from .base import Plugin
from ..core.types.base import GroupNamedType
from ..utils import memoize, to_camel_case
def camelcase_named_type(schema, type):
name = type.name or to_camel_case(type.attname)
return name, schema.T(type)
class CamelCase(Plugin): class CamelCase(Plugin):
@memoize
def transform_group(self, _type):
new_type = _type.__class__(*_type.types)
setattr(new_type, 'get_named_type', camelcase_named_type)
return new_type
def transform_type(self, _type): def get_default_namedtype_name(self, value):
if isinstance(_type, GroupNamedType): return to_camel_case(value)
return self.transform_group(_type)
return _type