Merge pull request #65 from graphql-python/features/plugins-autocamelcase

Create plugin structure
This commit is contained in:
Syrus Akbary 2015-12-09 19:52:12 -08:00
commit a161738f3d
22 changed files with 516 additions and 80 deletions

View File

@ -0,0 +1,4 @@
from .plugin import DjangoDebugPlugin
from .types import DjangoDebug
__all__ = ['DjangoDebugPlugin', 'DjangoDebug']

View File

@ -0,0 +1,77 @@
from contextlib import contextmanager
from django.db import connections
from ....core.types import Field
from ....plugins import Plugin
from .sql.tracking import unwrap_cursor, wrap_cursor
from .sql.types import DjangoDebugSQL
from .types import DjangoDebug
class WrappedRoot(object):
def __init__(self, root):
self._recorded = []
self._root = root
def record(self, **log):
self._recorded.append(DjangoDebugSQL(**log))
def debug(self):
return DjangoDebug(sql=self._recorded)
class WrapRoot(object):
@property
def _root(self):
return self._wrapped_root.root
@_root.setter
def _root(self, value):
self._wrapped_root = value
def resolve_debug(self, args, info):
return self._wrapped_root.debug()
def debug_objecttype(objecttype):
return type(
'Debug{}'.format(objecttype._meta.type_name),
(WrapRoot, objecttype),
{'debug': Field(DjangoDebug, name='__debug')})
class DjangoDebugPlugin(Plugin):
def transform_type(self, _type):
if _type == self.schema.query:
return
return _type
def enable_instrumentation(self, wrapped_root):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, wrapped_root)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
def wrap_schema(self, schema_type):
query = schema_type._query
if query:
class_type = self.schema.objecttype(schema_type._query)
assert class_type, 'The query in schema is not constructed with graphene'
_type = debug_objecttype(class_type)
schema_type._query = self.schema.T(_type)
return schema_type
@contextmanager
def context_execution(self, executor):
executor['root'] = WrappedRoot(root=executor['root'])
executor['schema'] = self.wrap_schema(executor['schema'])
self.enable_instrumentation(executor['root'])
yield executor
self.disable_instrumentation()

View File

@ -0,0 +1,165 @@
# Code obtained from django-debug-toolbar sql panel tracking
from __future__ import absolute_import, unicode_literals
import json
from threading import local
from time import time
from django.utils import six
from django.utils.encoding import force_text
class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@property
def Wrapper(self):
if self.enabled:
return NormalCursorWrapper
return ExceptionCursorWrapper
def recording(self, v):
self.enabled = v
state = ThreadLocalState()
recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, '_djdt_cursor'):
connection._djdt_cursor = connection.cursor
def cursor():
return state.Wrapper(connection._djdt_cursor(), connection, panel)
connection.cursor = cursor
return cursor
def unwrap_cursor(connection):
if hasattr(connection, '_djdt_cursor'):
del connection._djdt_cursor
del connection.cursor
class ExceptionCursorWrapper(object):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
"""
def __init__(self, cursor, db, logger):
pass
def __getattr__(self, attr):
raise SQLQueryTriggered()
class NormalCursorWrapper(object):
"""
Wraps a cursor and logs queries.
"""
def __init__(self, cursor, db, logger):
self.cursor = cursor
# Instance of a BaseDatabaseWrapper subclass
self.db = db
# logger must implement a ``record`` method
self.logger = logger
def _quote_expr(self, element):
if isinstance(element, six.string_types):
return "'%s'" % force_text(element).replace("'", "''")
else:
return repr(element)
def _quote_params(self, params):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
return '(encoded string)'
def _record(self, method, sql, params):
start_time = time()
try:
return method(sql, params)
finally:
stop_time = time()
duration = (stop_time - start_time)
_params = ''
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default')
conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown')
params = {
'vendor': vendor,
'alias': alias,
'sql': self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'start_time': start_time,
'stop_time': stop_time,
'is_slow': duration > 10,
'is_select': sql.lower().strip().startswith('select'),
}
if vendor == 'postgresql':
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = 'unknown'
params.update({
'trans_id': self.logger.get_transaction_id(alias),
'trans_status': conn.get_transaction_status(),
'iso_level': iso_level,
'encoding': conn.encoding,
})
# We keep `sql` to maintain backwards compatibility
self.logger.record(**params)
def callproc(self, procname, params=()):
return self._record(self.cursor.callproc, procname, params)
def execute(self, sql, params=()):
return self._record(self.cursor.execute, sql, params)
def executemany(self, sql, param_list):
return self._record(self.cursor.executemany, sql, param_list)
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()

View File

@ -0,0 +1,19 @@
from .....core import Float, ObjectType, String
class DjangoDebugSQL(ObjectType):
vendor = String()
alias = String()
sql = String()
duration = Float()
raw_sql = String()
params = String()
start_time = Float()
stop_time = Float()
is_slow = String()
is_select = String()
trans_id = String()
trans_status = String()
iso_level = String()
encoding = String()

View File

@ -0,0 +1,70 @@
import pytest
import graphene
from graphene.contrib.django import DjangoObjectType
from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_well():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
query = '''
query ReporterQuery {
reporter {
lastName
}
allReporters {
lastName
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'reporter': {
'lastName': 'ABA',
},
'allReporters': [{
'lastName': 'ABA',
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}, {
'rawSql': str(Reporter.objects.all().query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -0,0 +1,7 @@
from ....core.classtypes.objecttype import ObjectType
from ....core.types import Field
from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType):
sql = Field(DjangoDebugSQL.List())

View File

@ -5,7 +5,6 @@ from functools import partial
import six import six
from ..exceptions import SkipField
from .options import Options from .options import Options
@ -82,13 +81,18 @@ class FieldsOptions(Options):
def fields_map(self): def fields_map(self):
return OrderedDict([(f.attname, f) for f in self.fields]) return OrderedDict([(f.attname, f) for f in self.fields])
@property
def fields_group_type(self):
from ..types.field import FieldsGroupType
return FieldsGroupType(*self.local_fields)
class FieldsClassTypeMeta(ClassTypeMeta): class FieldsClassTypeMeta(ClassTypeMeta):
options_class = FieldsOptions options_class = FieldsOptions
def extend_fields(cls, bases): def extend_fields(cls, bases):
new_fields = cls._meta.local_fields new_fields = cls._meta.local_fields
field_names = {f.name: f for f in new_fields} field_names = {f.attname: f for f in new_fields}
for base in bases: for base in bases:
if not isinstance(base, FieldsClassTypeMeta): if not isinstance(base, FieldsClassTypeMeta):
@ -96,17 +100,17 @@ class FieldsClassTypeMeta(ClassTypeMeta):
parent_fields = base._meta.local_fields parent_fields = base._meta.local_fields
for field in parent_fields: for field in parent_fields:
if field.name in field_names and field.type.__class__ != field_names[ if field.attname in field_names and field.type.__class__ != field_names[
field.name].type.__class__: field.attname].type.__class__:
raise Exception( raise Exception(
'Local field %r in class %r (%r) clashes ' 'Local field %r in class %r (%r) clashes '
'with field with similar name from ' 'with field with similar name from '
'Interface %s (%r)' % ( 'Interface %s (%r)' % (
field.name, field.attname,
cls.__name__, cls.__name__,
field.__class__, field.__class__,
base.__name__, base.__name__,
field_names[field.name].__class__) field_names[field.attname].__class__)
) )
new_field = copy.copy(field) new_field = copy.copy(field)
cls.add_to_class(field.attname, new_field) cls.add_to_class(field.attname, new_field)
@ -124,11 +128,4 @@ class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)):
@classmethod @classmethod
def fields_internal_types(cls, schema): def fields_internal_types(cls, schema):
fields = [] return schema.T(cls._meta.fields_group_type)
for field in cls._meta.fields:
try:
fields.append((field.name, schema.T(field)))
except SkipField:
continue
return OrderedDict(fields)

View File

@ -24,4 +24,4 @@ def test_mutation():
assert list(object_type.get_fields().keys()) == ['name'] assert list(object_type.get_fields().keys()) == ['name']
assert MyMutation._meta.fields_map['name'].object_type == MyMutation assert MyMutation._meta.fields_map['name'].object_type == MyMutation
assert isinstance(MyMutation.arguments, ArgumentsGroup) assert isinstance(MyMutation.arguments, ArgumentsGroup)
assert 'argName' in MyMutation.arguments assert 'argName' in schema.T(MyMutation.arguments)

View File

@ -10,6 +10,7 @@ from graphql.core.utils.schema_printer import print_schema
from graphene import signals from graphene import signals
from ..plugins import CamelCase, PluginManager
from .classtypes.base import ClassType from .classtypes.base import ClassType
from .types.base import InstanceType from .types.base import InstanceType
@ -25,7 +26,7 @@ class Schema(object):
_executor = None _executor = None
def __init__(self, query=None, mutation=None, subscription=None, def __init__(self, query=None, mutation=None, subscription=None,
name='Schema', executor=None): name='Schema', executor=None, plugins=None, auto_camelcase=True):
self._types_names = {} self._types_names = {}
self._types = {} self._types = {}
self.mutation = mutation self.mutation = mutation
@ -33,11 +34,20 @@ class Schema(object):
self.subscription = subscription self.subscription = subscription
self.name = name self.name = name
self.executor = executor self.executor = executor
plugins = plugins or []
if auto_camelcase:
plugins.append(CamelCase())
self.plugins = PluginManager(self, plugins)
signals.init_schema.send(self) signals.init_schema.send(self)
def __repr__(self): def __repr__(self):
return '<Schema: %s (%s)>' % (str(self.name), hash(self)) return '<Schema: %s (%s)>' % (str(self.name), hash(self))
def __getattr__(self, name):
if name in self.plugins:
return getattr(self.plugins, name)
return super(Schema, self).__getattr__(name)
def T(self, _type): def T(self, _type):
if not _type: if not _type:
return return
@ -108,17 +118,10 @@ class Schema(object):
def types(self): def types(self):
return self._types_names return self._types_names
def execute(self, request='', root=None, vars=None, def execute(self, request='', root=None, args=None, **kwargs):
operation_name=None, **kwargs): kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema)
root = root or object() with self.plugins.context_execution(**kwargs) as execute_kwargs:
return self.executor.execute( return self.executor.execute(**execute_kwargs)
self.schema,
request,
root=root,
args=vars,
operation_name=operation_name,
**kwargs
)
def introspect(self): def introspect(self):
return self.execute(introspection_query).data return self.execute(introspection_query).data

View File

@ -34,10 +34,11 @@ def test_field_type():
assert schema.T(f).type == GraphQLString assert schema.T(f).type == GraphQLString
def test_field_name_automatic_camelcase(): def test_field_name():
f = Field(GraphQLString) f = Field(GraphQLString)
f.contribute_to_class(MyOt, 'field_name') f.contribute_to_class(MyOt, 'field_name')
assert f.name == 'fieldName' assert f.name is None
assert f.attname == 'field_name'
def test_field_name_use_name_if_exists(): def test_field_name_use_name_if_exists():

View File

@ -1,19 +1,17 @@
from collections import OrderedDict
from functools import wraps from functools import wraps
from itertools import chain from itertools import chain
from graphql.core.type import GraphQLArgument from graphql.core.type import GraphQLArgument
from ...utils import ProxySnakeDict, to_camel_case from ...utils import ProxySnakeDict
from .base import ArgumentType, InstanceType, OrderedType from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
class Argument(OrderedType): class Argument(NamedType, OrderedType):
def __init__(self, type, description=None, default=None, def __init__(self, type, description=None, default=None,
name=None, _creation_counter=None): name=None, _creation_counter=None):
super(Argument, self).__init__(_creation_counter=_creation_counter) super(Argument, self).__init__(name=name, _creation_counter=_creation_counter)
self.name = name
self.type = type self.type = type
self.description = description self.description = description
self.default = default self.default = default
@ -27,47 +25,32 @@ class Argument(OrderedType):
return self.name return self.name
class ArgumentsGroup(InstanceType): class ArgumentsGroup(GroupNamedType):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
arguments = to_arguments(*args, **kwargs) arguments = to_arguments(*args, **kwargs)
self.arguments = OrderedDict([(arg.name, arg) for arg in arguments]) super(ArgumentsGroup, self).__init__(*arguments)
def internal_type(self, schema):
return OrderedDict([(arg.name, schema.T(arg))
for arg in self.arguments.values()])
def __len__(self):
return len(self.arguments)
def __iter__(self):
return iter(self.arguments)
def __contains__(self, *args):
return self.arguments.__contains__(*args)
def __getitem__(self, *args):
return self.arguments.__getitem__(*args)
def to_arguments(*args, **kwargs): def to_arguments(*args, **kwargs):
arguments = {} arguments = {}
iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) iter_arguments = chain(kwargs.items(), [(None, a) for a in args])
for name, arg in iter_arguments: for default_name, arg in iter_arguments:
if isinstance(arg, Argument): if isinstance(arg, Argument):
argument = arg argument = arg
elif isinstance(arg, ArgumentType): elif isinstance(arg, ArgumentType):
argument = arg.as_argument() argument = arg.as_argument()
else: else:
raise ValueError('Unknown argument %s=%r' % (name, arg)) raise ValueError('Unknown argument %s=%r' % (default_name, arg))
if name: if default_name:
argument.name = to_camel_case(name) argument.default_name = default_name
assert argument.name, 'Argument in field must have a name'
assert argument.name not in arguments, 'Found more than one Argument with same name {}'.format( name = argument.name or argument.default_name
argument.name) assert name, 'Argument in field must have a name'
arguments[argument.name] = argument assert name not in arguments, 'Found more than one Argument with same name {}'.format(name)
arguments[name] = argument
return sorted(arguments.values()) return sorted(arguments.values())

View File

@ -1,4 +1,5 @@
from functools import total_ordering from collections import OrderedDict
from functools import partial, total_ordering
import six import six
@ -125,3 +126,39 @@ class FieldType(MirroredType):
class MountedType(FieldType, ArgumentType): class MountedType(FieldType, ArgumentType):
pass pass
class NamedType(InstanceType):
def __init__(self, name=None, default_name=None, *args, **kwargs):
self.name = name
self.default_name = None
super(NamedType, self).__init__(*args, **kwargs)
class GroupNamedType(InstanceType):
def __init__(self, *types):
self.types = types
def get_named_type(self, schema, type):
name = type.name or schema.get_default_namedtype_name(type.default_name)
return name, schema.T(type)
def iter_types(self, schema):
return map(partial(self.get_named_type, schema), self.types)
def internal_type(self, schema):
return OrderedDict(self.iter_types(schema))
def __len__(self):
return len(self.types)
def __iter__(self):
return iter(self.types)
def __contains__(self, *args):
return self.types.__contains__(*args)
def __getitem__(self, *args):
return self.types.__getitem__(*args)

View File

@ -4,23 +4,22 @@ from functools import wraps
import six import six
from graphql.core.type import GraphQLField, GraphQLInputObjectField from graphql.core.type import GraphQLField, GraphQLInputObjectField
from ...utils import to_camel_case
from ..classtypes.base import FieldsClassType from ..classtypes.base import FieldsClassType
from ..classtypes.inputobjecttype import InputObjectType from ..classtypes.inputobjecttype import InputObjectType
from ..classtypes.mutation import Mutation from ..classtypes.mutation import Mutation
from ..exceptions import SkipField
from .argument import ArgumentsGroup, snake_case_args from .argument import ArgumentsGroup, snake_case_args
from .base import LazyType, MountType, OrderedType from .base import GroupNamedType, LazyType, MountType, NamedType, OrderedType
from .definitions import NonNull from .definitions import NonNull
class Field(OrderedType): class Field(NamedType, OrderedType):
def __init__( def __init__(
self, type, description=None, args=None, name=None, resolver=None, self, type, description=None, args=None, name=None, resolver=None,
required=False, default=None, *args_list, **kwargs): required=False, default=None, *args_list, **kwargs):
_creation_counter = kwargs.pop('_creation_counter', None) _creation_counter = kwargs.pop('_creation_counter', None)
super(Field, self).__init__(_creation_counter=_creation_counter) super(Field, self).__init__(name=name, _creation_counter=_creation_counter)
self.name = name
if isinstance(type, six.string_types): if isinstance(type, six.string_types):
type = LazyType(type) type = LazyType(type)
self.required = required self.required = required
@ -36,9 +35,8 @@ class Field(OrderedType):
assert issubclass( assert issubclass(
cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format(
self, cls) self, cls)
if not self.name:
self.name = to_camel_case(attname)
self.attname = attname self.attname = attname
self.default_name = attname
self.object_type = cls self.object_type = cls
self.mount(cls) self.mount(cls)
if isinstance(self.type, MountType): if isinstance(self.type, MountType):
@ -117,12 +115,11 @@ class Field(OrderedType):
return hash((self.creation_counter, self.object_type)) return hash((self.creation_counter, self.object_type))
class InputField(OrderedType): class InputField(NamedType, OrderedType):
def __init__(self, type, description=None, default=None, def __init__(self, type, description=None, default=None,
name=None, _creation_counter=None, required=False): name=None, _creation_counter=None, required=False):
super(InputField, self).__init__(_creation_counter=_creation_counter) super(InputField, self).__init__(_creation_counter=_creation_counter)
self.name = name
if required: if required:
type = NonNull(type) type = NonNull(type)
self.type = type self.type = type
@ -133,9 +130,8 @@ class InputField(OrderedType):
assert issubclass( assert issubclass(
cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format(
self, cls) self, cls)
if not self.name:
self.name = to_camel_case(attname)
self.attname = attname self.attname = attname
self.default_name = attname
self.object_type = cls self.object_type = cls
self.mount(cls) self.mount(cls)
if isinstance(self.type, MountType): if isinstance(self.type, MountType):
@ -146,3 +142,13 @@ class InputField(OrderedType):
return GraphQLInputObjectField( return GraphQLInputObjectField(
schema.T(self.type), schema.T(self.type),
default_value=self.default, description=self.description) default_value=self.default, description=self.description)
class FieldsGroupType(GroupNamedType):
def iter_types(self, schema):
for field in sorted(self.types):
try:
yield self.get_named_type(schema, field)
except SkipField:
continue

View File

@ -5,6 +5,7 @@ from .base import MountedType
class ScalarType(MountedType): class ScalarType(MountedType):
def internal_type(self, schema): def internal_type(self, schema):
return self._internal_type return self._internal_type

View File

@ -27,8 +27,8 @@ def test_to_arguments():
other_kwarg=String(), other_kwarg=String(),
) )
assert [a.name for a in arguments] == [ assert [a.name or a.default_name for a in arguments] == [
'myArg', 'otherArg', 'myKwarg', 'otherKwarg'] 'myArg', 'otherArg', 'my_kwarg', 'other_kwarg']
def test_to_arguments_no_name(): def test_to_arguments_no_name():

View File

@ -20,7 +20,7 @@ def test_field_internal_type():
schema = Schema(query=Query) schema = Schema(query=Query)
type = schema.T(field) type = schema.T(field)
assert field.name == 'myField' assert field.name is None
assert field.attname == 'my_field' assert field.attname == 'my_field'
assert isinstance(type, GraphQLField) assert isinstance(type, GraphQLField)
assert type.description == 'My argument' assert type.description == 'My argument'
@ -98,9 +98,10 @@ def test_field_string_reference():
def test_field_custom_arguments(): def test_field_custom_arguments():
field = Field(None, name='my_customName', p=String()) field = Field(None, name='my_customName', p=String())
schema = Schema()
args = field.arguments args = field.arguments
assert 'p' in args assert 'p' in schema.T(args)
def test_inputfield_internal_type(): def test_inputfield_internal_type():
@ -115,7 +116,7 @@ def test_inputfield_internal_type():
schema = Schema(query=MyObjectType) schema = Schema(query=MyObjectType)
type = schema.T(field) type = schema.T(field)
assert field.name == 'myField' assert field.name is None
assert field.attname == 'my_field' assert field.attname == 'my_field'
assert isinstance(type, GraphQLInputObjectField) assert isinstance(type, GraphQLInputObjectField)
assert type.description == 'My input field' assert type.description == 'My input field'

View File

@ -0,0 +1,6 @@
from .base import Plugin, PluginManager
from .camel_case import CamelCase
__all__ = [
'Plugin', 'PluginManager', 'CamelCase'
]

53
graphene/plugins/base.py Normal file
View File

@ -0,0 +1,53 @@
from contextlib import contextmanager
from functools import partial, reduce
class Plugin(object):
def contribute_to_schema(self, schema):
self.schema = schema
def apply_function(a, b):
return b(a)
class PluginManager(object):
PLUGIN_FUNCTIONS = ('get_default_namedtype_name', )
def __init__(self, schema, plugins=[]):
self.schema = schema
self.plugins = []
for plugin in plugins:
self.add_plugin(plugin)
def add_plugin(self, plugin):
if hasattr(plugin, 'contribute_to_schema'):
plugin.contribute_to_schema(self.schema)
self.plugins.append(plugin)
def get_plugin_functions(self, function):
for plugin in self.plugins:
if not hasattr(plugin, function):
continue
yield getattr(plugin, function)
def __getattr__(self, name):
functions = self.get_plugin_functions(name)
return partial(reduce, apply_function, functions)
def __contains__(self, name):
return name in self.PLUGIN_FUNCTIONS
@contextmanager
def context_execution(self, **executor):
contexts = []
functions = self.get_plugin_functions('context_execution')
for f in functions:
context = f(executor)
executor = context.__enter__()
contexts.append((context, executor))
yield executor
for context, value in contexts[::-1]:
context.__exit__(None, None, None)

View File

@ -0,0 +1,7 @@
from ..utils import to_camel_case
class CamelCase(object):
def get_default_namedtype_name(self, value):
return to_camel_case(value)

View File

@ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation)
def test_mutation_arguments(): def test_mutation_arguments():
assert ChangeNumber.arguments assert ChangeNumber.arguments
assert list(ChangeNumber.arguments) == ['input'] assert 'input' in schema.T(ChangeNumber.arguments)
assert 'input' in ChangeNumber.arguments
inner_type = ChangeNumber.input_type inner_type = ChangeNumber.input_type
client_mutation_id_field = inner_type._meta.fields_map[ client_mutation_id_field = inner_type._meta.fields_map[
'client_mutation_id'] 'client_mutation_id']