mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-26 19:43:56 +03:00
Merge pull request #65 from graphql-python/features/plugins-autocamelcase
Create plugin structure
This commit is contained in:
commit
a161738f3d
4
graphene/contrib/django/debug/__init__.py
Normal file
4
graphene/contrib/django/debug/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
from .plugin import DjangoDebugPlugin
|
||||||
|
from .types import DjangoDebug
|
||||||
|
|
||||||
|
__all__ = ['DjangoDebugPlugin', 'DjangoDebug']
|
77
graphene/contrib/django/debug/plugin.py
Normal file
77
graphene/contrib/django/debug/plugin.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from django.db import connections
|
||||||
|
|
||||||
|
from ....core.types import Field
|
||||||
|
from ....plugins import Plugin
|
||||||
|
from .sql.tracking import unwrap_cursor, wrap_cursor
|
||||||
|
from .sql.types import DjangoDebugSQL
|
||||||
|
from .types import DjangoDebug
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedRoot(object):
|
||||||
|
|
||||||
|
def __init__(self, root):
|
||||||
|
self._recorded = []
|
||||||
|
self._root = root
|
||||||
|
|
||||||
|
def record(self, **log):
|
||||||
|
self._recorded.append(DjangoDebugSQL(**log))
|
||||||
|
|
||||||
|
def debug(self):
|
||||||
|
return DjangoDebug(sql=self._recorded)
|
||||||
|
|
||||||
|
|
||||||
|
class WrapRoot(object):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _root(self):
|
||||||
|
return self._wrapped_root.root
|
||||||
|
|
||||||
|
@_root.setter
|
||||||
|
def _root(self, value):
|
||||||
|
self._wrapped_root = value
|
||||||
|
|
||||||
|
def resolve_debug(self, args, info):
|
||||||
|
return self._wrapped_root.debug()
|
||||||
|
|
||||||
|
|
||||||
|
def debug_objecttype(objecttype):
|
||||||
|
return type(
|
||||||
|
'Debug{}'.format(objecttype._meta.type_name),
|
||||||
|
(WrapRoot, objecttype),
|
||||||
|
{'debug': Field(DjangoDebug, name='__debug')})
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebugPlugin(Plugin):
|
||||||
|
|
||||||
|
def transform_type(self, _type):
|
||||||
|
if _type == self.schema.query:
|
||||||
|
return
|
||||||
|
return _type
|
||||||
|
|
||||||
|
def enable_instrumentation(self, wrapped_root):
|
||||||
|
# This is thread-safe because database connections are thread-local.
|
||||||
|
for connection in connections.all():
|
||||||
|
wrap_cursor(connection, wrapped_root)
|
||||||
|
|
||||||
|
def disable_instrumentation(self):
|
||||||
|
for connection in connections.all():
|
||||||
|
unwrap_cursor(connection)
|
||||||
|
|
||||||
|
def wrap_schema(self, schema_type):
|
||||||
|
query = schema_type._query
|
||||||
|
if query:
|
||||||
|
class_type = self.schema.objecttype(schema_type._query)
|
||||||
|
assert class_type, 'The query in schema is not constructed with graphene'
|
||||||
|
_type = debug_objecttype(class_type)
|
||||||
|
schema_type._query = self.schema.T(_type)
|
||||||
|
return schema_type
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def context_execution(self, executor):
|
||||||
|
executor['root'] = WrappedRoot(root=executor['root'])
|
||||||
|
executor['schema'] = self.wrap_schema(executor['schema'])
|
||||||
|
self.enable_instrumentation(executor['root'])
|
||||||
|
yield executor
|
||||||
|
self.disable_instrumentation()
|
0
graphene/contrib/django/debug/sql/__init__.py
Normal file
0
graphene/contrib/django/debug/sql/__init__.py
Normal file
165
graphene/contrib/django/debug/sql/tracking.py
Normal file
165
graphene/contrib/django/debug/sql/tracking.py
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
# Code obtained from django-debug-toolbar sql panel tracking
|
||||||
|
from __future__ import absolute_import, unicode_literals
|
||||||
|
|
||||||
|
import json
|
||||||
|
from threading import local
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
from django.utils import six
|
||||||
|
from django.utils.encoding import force_text
|
||||||
|
|
||||||
|
|
||||||
|
class SQLQueryTriggered(Exception):
|
||||||
|
"""Thrown when template panel triggers a query"""
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadLocalState(local):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.enabled = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def Wrapper(self):
|
||||||
|
if self.enabled:
|
||||||
|
return NormalCursorWrapper
|
||||||
|
return ExceptionCursorWrapper
|
||||||
|
|
||||||
|
def recording(self, v):
|
||||||
|
self.enabled = v
|
||||||
|
|
||||||
|
|
||||||
|
state = ThreadLocalState()
|
||||||
|
recording = state.recording # export function
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_cursor(connection, panel):
|
||||||
|
if not hasattr(connection, '_djdt_cursor'):
|
||||||
|
connection._djdt_cursor = connection.cursor
|
||||||
|
|
||||||
|
def cursor():
|
||||||
|
return state.Wrapper(connection._djdt_cursor(), connection, panel)
|
||||||
|
|
||||||
|
connection.cursor = cursor
|
||||||
|
return cursor
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_cursor(connection):
|
||||||
|
if hasattr(connection, '_djdt_cursor'):
|
||||||
|
del connection._djdt_cursor
|
||||||
|
del connection.cursor
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionCursorWrapper(object):
|
||||||
|
"""
|
||||||
|
Wraps a cursor and raises an exception on any operation.
|
||||||
|
Used in Templates panel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cursor, db, logger):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
raise SQLQueryTriggered()
|
||||||
|
|
||||||
|
|
||||||
|
class NormalCursorWrapper(object):
|
||||||
|
"""
|
||||||
|
Wraps a cursor and logs queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cursor, db, logger):
|
||||||
|
self.cursor = cursor
|
||||||
|
# Instance of a BaseDatabaseWrapper subclass
|
||||||
|
self.db = db
|
||||||
|
# logger must implement a ``record`` method
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def _quote_expr(self, element):
|
||||||
|
if isinstance(element, six.string_types):
|
||||||
|
return "'%s'" % force_text(element).replace("'", "''")
|
||||||
|
else:
|
||||||
|
return repr(element)
|
||||||
|
|
||||||
|
def _quote_params(self, params):
|
||||||
|
if not params:
|
||||||
|
return params
|
||||||
|
if isinstance(params, dict):
|
||||||
|
return dict((key, self._quote_expr(value))
|
||||||
|
for key, value in params.items())
|
||||||
|
return list(map(self._quote_expr, params))
|
||||||
|
|
||||||
|
def _decode(self, param):
|
||||||
|
try:
|
||||||
|
return force_text(param, strings_only=True)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return '(encoded string)'
|
||||||
|
|
||||||
|
def _record(self, method, sql, params):
|
||||||
|
start_time = time()
|
||||||
|
try:
|
||||||
|
return method(sql, params)
|
||||||
|
finally:
|
||||||
|
stop_time = time()
|
||||||
|
duration = (stop_time - start_time)
|
||||||
|
_params = ''
|
||||||
|
try:
|
||||||
|
_params = json.dumps(list(map(self._decode, params)))
|
||||||
|
except Exception:
|
||||||
|
pass # object not JSON serializable
|
||||||
|
|
||||||
|
alias = getattr(self.db, 'alias', 'default')
|
||||||
|
conn = self.db.connection
|
||||||
|
vendor = getattr(conn, 'vendor', 'unknown')
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'vendor': vendor,
|
||||||
|
'alias': alias,
|
||||||
|
'sql': self.db.ops.last_executed_query(
|
||||||
|
self.cursor, sql, self._quote_params(params)),
|
||||||
|
'duration': duration,
|
||||||
|
'raw_sql': sql,
|
||||||
|
'params': _params,
|
||||||
|
'start_time': start_time,
|
||||||
|
'stop_time': stop_time,
|
||||||
|
'is_slow': duration > 10,
|
||||||
|
'is_select': sql.lower().strip().startswith('select'),
|
||||||
|
}
|
||||||
|
|
||||||
|
if vendor == 'postgresql':
|
||||||
|
# If an erroneous query was ran on the connection, it might
|
||||||
|
# be in a state where checking isolation_level raises an
|
||||||
|
# exception.
|
||||||
|
try:
|
||||||
|
iso_level = conn.isolation_level
|
||||||
|
except conn.InternalError:
|
||||||
|
iso_level = 'unknown'
|
||||||
|
params.update({
|
||||||
|
'trans_id': self.logger.get_transaction_id(alias),
|
||||||
|
'trans_status': conn.get_transaction_status(),
|
||||||
|
'iso_level': iso_level,
|
||||||
|
'encoding': conn.encoding,
|
||||||
|
})
|
||||||
|
|
||||||
|
# We keep `sql` to maintain backwards compatibility
|
||||||
|
self.logger.record(**params)
|
||||||
|
|
||||||
|
def callproc(self, procname, params=()):
|
||||||
|
return self._record(self.cursor.callproc, procname, params)
|
||||||
|
|
||||||
|
def execute(self, sql, params=()):
|
||||||
|
return self._record(self.cursor.execute, sql, params)
|
||||||
|
|
||||||
|
def executemany(self, sql, param_list):
|
||||||
|
return self._record(self.cursor.executemany, sql, param_list)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
return getattr(self.cursor, attr)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.cursor)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
self.close()
|
19
graphene/contrib/django/debug/sql/types.py
Normal file
19
graphene/contrib/django/debug/sql/types.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from .....core import Float, ObjectType, String
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebugSQL(ObjectType):
|
||||||
|
vendor = String()
|
||||||
|
alias = String()
|
||||||
|
sql = String()
|
||||||
|
duration = Float()
|
||||||
|
raw_sql = String()
|
||||||
|
params = String()
|
||||||
|
start_time = Float()
|
||||||
|
stop_time = Float()
|
||||||
|
is_slow = String()
|
||||||
|
is_select = String()
|
||||||
|
|
||||||
|
trans_id = String()
|
||||||
|
trans_status = String()
|
||||||
|
iso_level = String()
|
||||||
|
encoding = String()
|
0
graphene/contrib/django/debug/tests/__init__.py
Normal file
0
graphene/contrib/django/debug/tests/__init__.py
Normal file
70
graphene/contrib/django/debug/tests/test_query.py
Normal file
70
graphene/contrib/django/debug/tests/test_query.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import graphene
|
||||||
|
from graphene.contrib.django import DjangoObjectType
|
||||||
|
|
||||||
|
from ...tests.models import Reporter
|
||||||
|
from ..plugin import DjangoDebugPlugin
|
||||||
|
|
||||||
|
# from examples.starwars_django.models import Character
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_query_well():
|
||||||
|
r1 = Reporter(last_name='ABA')
|
||||||
|
r1.save()
|
||||||
|
r2 = Reporter(last_name='Griffin')
|
||||||
|
r2.save()
|
||||||
|
|
||||||
|
class ReporterType(DjangoObjectType):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Reporter
|
||||||
|
|
||||||
|
class Query(graphene.ObjectType):
|
||||||
|
reporter = graphene.Field(ReporterType)
|
||||||
|
all_reporters = ReporterType.List()
|
||||||
|
|
||||||
|
def resolve_all_reporters(self, *args, **kwargs):
|
||||||
|
return Reporter.objects.all()
|
||||||
|
|
||||||
|
def resolve_reporter(self, *args, **kwargs):
|
||||||
|
return Reporter.objects.first()
|
||||||
|
|
||||||
|
query = '''
|
||||||
|
query ReporterQuery {
|
||||||
|
reporter {
|
||||||
|
lastName
|
||||||
|
}
|
||||||
|
allReporters {
|
||||||
|
lastName
|
||||||
|
}
|
||||||
|
__debug {
|
||||||
|
sql {
|
||||||
|
rawSql
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
expected = {
|
||||||
|
'reporter': {
|
||||||
|
'lastName': 'ABA',
|
||||||
|
},
|
||||||
|
'allReporters': [{
|
||||||
|
'lastName': 'ABA',
|
||||||
|
}, {
|
||||||
|
'lastName': 'Griffin',
|
||||||
|
}],
|
||||||
|
'__debug': {
|
||||||
|
'sql': [{
|
||||||
|
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
|
||||||
|
}, {
|
||||||
|
'rawSql': str(Reporter.objects.all().query)
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == expected
|
7
graphene/contrib/django/debug/types.py
Normal file
7
graphene/contrib/django/debug/types.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from ....core.classtypes.objecttype import ObjectType
|
||||||
|
from ....core.types import Field
|
||||||
|
from .sql.types import DjangoDebugSQL
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebug(ObjectType):
|
||||||
|
sql = Field(DjangoDebugSQL.List())
|
|
@ -5,7 +5,6 @@ from functools import partial
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from ..exceptions import SkipField
|
|
||||||
from .options import Options
|
from .options import Options
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,13 +81,18 @@ class FieldsOptions(Options):
|
||||||
def fields_map(self):
|
def fields_map(self):
|
||||||
return OrderedDict([(f.attname, f) for f in self.fields])
|
return OrderedDict([(f.attname, f) for f in self.fields])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fields_group_type(self):
|
||||||
|
from ..types.field import FieldsGroupType
|
||||||
|
return FieldsGroupType(*self.local_fields)
|
||||||
|
|
||||||
|
|
||||||
class FieldsClassTypeMeta(ClassTypeMeta):
|
class FieldsClassTypeMeta(ClassTypeMeta):
|
||||||
options_class = FieldsOptions
|
options_class = FieldsOptions
|
||||||
|
|
||||||
def extend_fields(cls, bases):
|
def extend_fields(cls, bases):
|
||||||
new_fields = cls._meta.local_fields
|
new_fields = cls._meta.local_fields
|
||||||
field_names = {f.name: f for f in new_fields}
|
field_names = {f.attname: f for f in new_fields}
|
||||||
|
|
||||||
for base in bases:
|
for base in bases:
|
||||||
if not isinstance(base, FieldsClassTypeMeta):
|
if not isinstance(base, FieldsClassTypeMeta):
|
||||||
|
@ -96,17 +100,17 @@ class FieldsClassTypeMeta(ClassTypeMeta):
|
||||||
|
|
||||||
parent_fields = base._meta.local_fields
|
parent_fields = base._meta.local_fields
|
||||||
for field in parent_fields:
|
for field in parent_fields:
|
||||||
if field.name in field_names and field.type.__class__ != field_names[
|
if field.attname in field_names and field.type.__class__ != field_names[
|
||||||
field.name].type.__class__:
|
field.attname].type.__class__:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Local field %r in class %r (%r) clashes '
|
'Local field %r in class %r (%r) clashes '
|
||||||
'with field with similar name from '
|
'with field with similar name from '
|
||||||
'Interface %s (%r)' % (
|
'Interface %s (%r)' % (
|
||||||
field.name,
|
field.attname,
|
||||||
cls.__name__,
|
cls.__name__,
|
||||||
field.__class__,
|
field.__class__,
|
||||||
base.__name__,
|
base.__name__,
|
||||||
field_names[field.name].__class__)
|
field_names[field.attname].__class__)
|
||||||
)
|
)
|
||||||
new_field = copy.copy(field)
|
new_field = copy.copy(field)
|
||||||
cls.add_to_class(field.attname, new_field)
|
cls.add_to_class(field.attname, new_field)
|
||||||
|
@ -124,11 +128,4 @@ class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields_internal_types(cls, schema):
|
def fields_internal_types(cls, schema):
|
||||||
fields = []
|
return schema.T(cls._meta.fields_group_type)
|
||||||
for field in cls._meta.fields:
|
|
||||||
try:
|
|
||||||
fields.append((field.name, schema.T(field)))
|
|
||||||
except SkipField:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return OrderedDict(fields)
|
|
||||||
|
|
|
@ -24,4 +24,4 @@ def test_mutation():
|
||||||
assert list(object_type.get_fields().keys()) == ['name']
|
assert list(object_type.get_fields().keys()) == ['name']
|
||||||
assert MyMutation._meta.fields_map['name'].object_type == MyMutation
|
assert MyMutation._meta.fields_map['name'].object_type == MyMutation
|
||||||
assert isinstance(MyMutation.arguments, ArgumentsGroup)
|
assert isinstance(MyMutation.arguments, ArgumentsGroup)
|
||||||
assert 'argName' in MyMutation.arguments
|
assert 'argName' in schema.T(MyMutation.arguments)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from graphql.core.utils.schema_printer import print_schema
|
||||||
|
|
||||||
from graphene import signals
|
from graphene import signals
|
||||||
|
|
||||||
|
from ..plugins import CamelCase, PluginManager
|
||||||
from .classtypes.base import ClassType
|
from .classtypes.base import ClassType
|
||||||
from .types.base import InstanceType
|
from .types.base import InstanceType
|
||||||
|
|
||||||
|
@ -25,7 +26,7 @@ class Schema(object):
|
||||||
_executor = None
|
_executor = None
|
||||||
|
|
||||||
def __init__(self, query=None, mutation=None, subscription=None,
|
def __init__(self, query=None, mutation=None, subscription=None,
|
||||||
name='Schema', executor=None):
|
name='Schema', executor=None, plugins=None, auto_camelcase=True):
|
||||||
self._types_names = {}
|
self._types_names = {}
|
||||||
self._types = {}
|
self._types = {}
|
||||||
self.mutation = mutation
|
self.mutation = mutation
|
||||||
|
@ -33,11 +34,20 @@ class Schema(object):
|
||||||
self.subscription = subscription
|
self.subscription = subscription
|
||||||
self.name = name
|
self.name = name
|
||||||
self.executor = executor
|
self.executor = executor
|
||||||
|
plugins = plugins or []
|
||||||
|
if auto_camelcase:
|
||||||
|
plugins.append(CamelCase())
|
||||||
|
self.plugins = PluginManager(self, plugins)
|
||||||
signals.init_schema.send(self)
|
signals.init_schema.send(self)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Schema: %s (%s)>' % (str(self.name), hash(self))
|
return '<Schema: %s (%s)>' % (str(self.name), hash(self))
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self.plugins:
|
||||||
|
return getattr(self.plugins, name)
|
||||||
|
return super(Schema, self).__getattr__(name)
|
||||||
|
|
||||||
def T(self, _type):
|
def T(self, _type):
|
||||||
if not _type:
|
if not _type:
|
||||||
return
|
return
|
||||||
|
@ -108,17 +118,10 @@ class Schema(object):
|
||||||
def types(self):
|
def types(self):
|
||||||
return self._types_names
|
return self._types_names
|
||||||
|
|
||||||
def execute(self, request='', root=None, vars=None,
|
def execute(self, request='', root=None, args=None, **kwargs):
|
||||||
operation_name=None, **kwargs):
|
kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema)
|
||||||
root = root or object()
|
with self.plugins.context_execution(**kwargs) as execute_kwargs:
|
||||||
return self.executor.execute(
|
return self.executor.execute(**execute_kwargs)
|
||||||
self.schema,
|
|
||||||
request,
|
|
||||||
root=root,
|
|
||||||
args=vars,
|
|
||||||
operation_name=operation_name,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def introspect(self):
|
def introspect(self):
|
||||||
return self.execute(introspection_query).data
|
return self.execute(introspection_query).data
|
||||||
|
|
|
@ -34,10 +34,11 @@ def test_field_type():
|
||||||
assert schema.T(f).type == GraphQLString
|
assert schema.T(f).type == GraphQLString
|
||||||
|
|
||||||
|
|
||||||
def test_field_name_automatic_camelcase():
|
def test_field_name():
|
||||||
f = Field(GraphQLString)
|
f = Field(GraphQLString)
|
||||||
f.contribute_to_class(MyOt, 'field_name')
|
f.contribute_to_class(MyOt, 'field_name')
|
||||||
assert f.name == 'fieldName'
|
assert f.name is None
|
||||||
|
assert f.attname == 'field_name'
|
||||||
|
|
||||||
|
|
||||||
def test_field_name_use_name_if_exists():
|
def test_field_name_use_name_if_exists():
|
||||||
|
|
|
@ -1,19 +1,17 @@
|
||||||
from collections import OrderedDict
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
from graphql.core.type import GraphQLArgument
|
from graphql.core.type import GraphQLArgument
|
||||||
|
|
||||||
from ...utils import ProxySnakeDict, to_camel_case
|
from ...utils import ProxySnakeDict
|
||||||
from .base import ArgumentType, InstanceType, OrderedType
|
from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
|
||||||
|
|
||||||
|
|
||||||
class Argument(OrderedType):
|
class Argument(NamedType, OrderedType):
|
||||||
|
|
||||||
def __init__(self, type, description=None, default=None,
|
def __init__(self, type, description=None, default=None,
|
||||||
name=None, _creation_counter=None):
|
name=None, _creation_counter=None):
|
||||||
super(Argument, self).__init__(_creation_counter=_creation_counter)
|
super(Argument, self).__init__(name=name, _creation_counter=_creation_counter)
|
||||||
self.name = name
|
|
||||||
self.type = type
|
self.type = type
|
||||||
self.description = description
|
self.description = description
|
||||||
self.default = default
|
self.default = default
|
||||||
|
@ -27,47 +25,32 @@ class Argument(OrderedType):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class ArgumentsGroup(InstanceType):
|
class ArgumentsGroup(GroupNamedType):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
arguments = to_arguments(*args, **kwargs)
|
arguments = to_arguments(*args, **kwargs)
|
||||||
self.arguments = OrderedDict([(arg.name, arg) for arg in arguments])
|
super(ArgumentsGroup, self).__init__(*arguments)
|
||||||
|
|
||||||
def internal_type(self, schema):
|
|
||||||
return OrderedDict([(arg.name, schema.T(arg))
|
|
||||||
for arg in self.arguments.values()])
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.arguments)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self.arguments)
|
|
||||||
|
|
||||||
def __contains__(self, *args):
|
|
||||||
return self.arguments.__contains__(*args)
|
|
||||||
|
|
||||||
def __getitem__(self, *args):
|
|
||||||
return self.arguments.__getitem__(*args)
|
|
||||||
|
|
||||||
|
|
||||||
def to_arguments(*args, **kwargs):
|
def to_arguments(*args, **kwargs):
|
||||||
arguments = {}
|
arguments = {}
|
||||||
iter_arguments = chain(kwargs.items(), [(None, a) for a in args])
|
iter_arguments = chain(kwargs.items(), [(None, a) for a in args])
|
||||||
|
|
||||||
for name, arg in iter_arguments:
|
for default_name, arg in iter_arguments:
|
||||||
if isinstance(arg, Argument):
|
if isinstance(arg, Argument):
|
||||||
argument = arg
|
argument = arg
|
||||||
elif isinstance(arg, ArgumentType):
|
elif isinstance(arg, ArgumentType):
|
||||||
argument = arg.as_argument()
|
argument = arg.as_argument()
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unknown argument %s=%r' % (name, arg))
|
raise ValueError('Unknown argument %s=%r' % (default_name, arg))
|
||||||
|
|
||||||
if name:
|
if default_name:
|
||||||
argument.name = to_camel_case(name)
|
argument.default_name = default_name
|
||||||
assert argument.name, 'Argument in field must have a name'
|
|
||||||
assert argument.name not in arguments, 'Found more than one Argument with same name {}'.format(
|
name = argument.name or argument.default_name
|
||||||
argument.name)
|
assert name, 'Argument in field must have a name'
|
||||||
arguments[argument.name] = argument
|
assert name not in arguments, 'Found more than one Argument with same name {}'.format(name)
|
||||||
|
arguments[name] = argument
|
||||||
|
|
||||||
return sorted(arguments.values())
|
return sorted(arguments.values())
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from functools import total_ordering
|
from collections import OrderedDict
|
||||||
|
from functools import partial, total_ordering
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
@ -125,3 +126,39 @@ class FieldType(MirroredType):
|
||||||
|
|
||||||
class MountedType(FieldType, ArgumentType):
|
class MountedType(FieldType, ArgumentType):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NamedType(InstanceType):
|
||||||
|
|
||||||
|
def __init__(self, name=None, default_name=None, *args, **kwargs):
|
||||||
|
self.name = name
|
||||||
|
self.default_name = None
|
||||||
|
super(NamedType, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNamedType(InstanceType):
|
||||||
|
|
||||||
|
def __init__(self, *types):
|
||||||
|
self.types = types
|
||||||
|
|
||||||
|
def get_named_type(self, schema, type):
|
||||||
|
name = type.name or schema.get_default_namedtype_name(type.default_name)
|
||||||
|
return name, schema.T(type)
|
||||||
|
|
||||||
|
def iter_types(self, schema):
|
||||||
|
return map(partial(self.get_named_type, schema), self.types)
|
||||||
|
|
||||||
|
def internal_type(self, schema):
|
||||||
|
return OrderedDict(self.iter_types(schema))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.types)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.types)
|
||||||
|
|
||||||
|
def __contains__(self, *args):
|
||||||
|
return self.types.__contains__(*args)
|
||||||
|
|
||||||
|
def __getitem__(self, *args):
|
||||||
|
return self.types.__getitem__(*args)
|
||||||
|
|
|
@ -4,23 +4,22 @@ from functools import wraps
|
||||||
import six
|
import six
|
||||||
from graphql.core.type import GraphQLField, GraphQLInputObjectField
|
from graphql.core.type import GraphQLField, GraphQLInputObjectField
|
||||||
|
|
||||||
from ...utils import to_camel_case
|
|
||||||
from ..classtypes.base import FieldsClassType
|
from ..classtypes.base import FieldsClassType
|
||||||
from ..classtypes.inputobjecttype import InputObjectType
|
from ..classtypes.inputobjecttype import InputObjectType
|
||||||
from ..classtypes.mutation import Mutation
|
from ..classtypes.mutation import Mutation
|
||||||
|
from ..exceptions import SkipField
|
||||||
from .argument import ArgumentsGroup, snake_case_args
|
from .argument import ArgumentsGroup, snake_case_args
|
||||||
from .base import LazyType, MountType, OrderedType
|
from .base import GroupNamedType, LazyType, MountType, NamedType, OrderedType
|
||||||
from .definitions import NonNull
|
from .definitions import NonNull
|
||||||
|
|
||||||
|
|
||||||
class Field(OrderedType):
|
class Field(NamedType, OrderedType):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, type, description=None, args=None, name=None, resolver=None,
|
self, type, description=None, args=None, name=None, resolver=None,
|
||||||
required=False, default=None, *args_list, **kwargs):
|
required=False, default=None, *args_list, **kwargs):
|
||||||
_creation_counter = kwargs.pop('_creation_counter', None)
|
_creation_counter = kwargs.pop('_creation_counter', None)
|
||||||
super(Field, self).__init__(_creation_counter=_creation_counter)
|
super(Field, self).__init__(name=name, _creation_counter=_creation_counter)
|
||||||
self.name = name
|
|
||||||
if isinstance(type, six.string_types):
|
if isinstance(type, six.string_types):
|
||||||
type = LazyType(type)
|
type = LazyType(type)
|
||||||
self.required = required
|
self.required = required
|
||||||
|
@ -36,9 +35,8 @@ class Field(OrderedType):
|
||||||
assert issubclass(
|
assert issubclass(
|
||||||
cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format(
|
cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format(
|
||||||
self, cls)
|
self, cls)
|
||||||
if not self.name:
|
|
||||||
self.name = to_camel_case(attname)
|
|
||||||
self.attname = attname
|
self.attname = attname
|
||||||
|
self.default_name = attname
|
||||||
self.object_type = cls
|
self.object_type = cls
|
||||||
self.mount(cls)
|
self.mount(cls)
|
||||||
if isinstance(self.type, MountType):
|
if isinstance(self.type, MountType):
|
||||||
|
@ -117,12 +115,11 @@ class Field(OrderedType):
|
||||||
return hash((self.creation_counter, self.object_type))
|
return hash((self.creation_counter, self.object_type))
|
||||||
|
|
||||||
|
|
||||||
class InputField(OrderedType):
|
class InputField(NamedType, OrderedType):
|
||||||
|
|
||||||
def __init__(self, type, description=None, default=None,
|
def __init__(self, type, description=None, default=None,
|
||||||
name=None, _creation_counter=None, required=False):
|
name=None, _creation_counter=None, required=False):
|
||||||
super(InputField, self).__init__(_creation_counter=_creation_counter)
|
super(InputField, self).__init__(_creation_counter=_creation_counter)
|
||||||
self.name = name
|
|
||||||
if required:
|
if required:
|
||||||
type = NonNull(type)
|
type = NonNull(type)
|
||||||
self.type = type
|
self.type = type
|
||||||
|
@ -133,9 +130,8 @@ class InputField(OrderedType):
|
||||||
assert issubclass(
|
assert issubclass(
|
||||||
cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format(
|
cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format(
|
||||||
self, cls)
|
self, cls)
|
||||||
if not self.name:
|
|
||||||
self.name = to_camel_case(attname)
|
|
||||||
self.attname = attname
|
self.attname = attname
|
||||||
|
self.default_name = attname
|
||||||
self.object_type = cls
|
self.object_type = cls
|
||||||
self.mount(cls)
|
self.mount(cls)
|
||||||
if isinstance(self.type, MountType):
|
if isinstance(self.type, MountType):
|
||||||
|
@ -146,3 +142,13 @@ class InputField(OrderedType):
|
||||||
return GraphQLInputObjectField(
|
return GraphQLInputObjectField(
|
||||||
schema.T(self.type),
|
schema.T(self.type),
|
||||||
default_value=self.default, description=self.description)
|
default_value=self.default, description=self.description)
|
||||||
|
|
||||||
|
|
||||||
|
class FieldsGroupType(GroupNamedType):
|
||||||
|
|
||||||
|
def iter_types(self, schema):
|
||||||
|
for field in sorted(self.types):
|
||||||
|
try:
|
||||||
|
yield self.get_named_type(schema, field)
|
||||||
|
except SkipField:
|
||||||
|
continue
|
||||||
|
|
|
@ -5,6 +5,7 @@ from .base import MountedType
|
||||||
|
|
||||||
|
|
||||||
class ScalarType(MountedType):
|
class ScalarType(MountedType):
|
||||||
|
|
||||||
def internal_type(self, schema):
|
def internal_type(self, schema):
|
||||||
return self._internal_type
|
return self._internal_type
|
||||||
|
|
||||||
|
|
|
@ -27,8 +27,8 @@ def test_to_arguments():
|
||||||
other_kwarg=String(),
|
other_kwarg=String(),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert [a.name for a in arguments] == [
|
assert [a.name or a.default_name for a in arguments] == [
|
||||||
'myArg', 'otherArg', 'myKwarg', 'otherKwarg']
|
'myArg', 'otherArg', 'my_kwarg', 'other_kwarg']
|
||||||
|
|
||||||
|
|
||||||
def test_to_arguments_no_name():
|
def test_to_arguments_no_name():
|
||||||
|
|
|
@ -20,7 +20,7 @@ def test_field_internal_type():
|
||||||
schema = Schema(query=Query)
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
type = schema.T(field)
|
type = schema.T(field)
|
||||||
assert field.name == 'myField'
|
assert field.name is None
|
||||||
assert field.attname == 'my_field'
|
assert field.attname == 'my_field'
|
||||||
assert isinstance(type, GraphQLField)
|
assert isinstance(type, GraphQLField)
|
||||||
assert type.description == 'My argument'
|
assert type.description == 'My argument'
|
||||||
|
@ -98,9 +98,10 @@ def test_field_string_reference():
|
||||||
|
|
||||||
def test_field_custom_arguments():
|
def test_field_custom_arguments():
|
||||||
field = Field(None, name='my_customName', p=String())
|
field = Field(None, name='my_customName', p=String())
|
||||||
|
schema = Schema()
|
||||||
|
|
||||||
args = field.arguments
|
args = field.arguments
|
||||||
assert 'p' in args
|
assert 'p' in schema.T(args)
|
||||||
|
|
||||||
|
|
||||||
def test_inputfield_internal_type():
|
def test_inputfield_internal_type():
|
||||||
|
@ -115,7 +116,7 @@ def test_inputfield_internal_type():
|
||||||
schema = Schema(query=MyObjectType)
|
schema = Schema(query=MyObjectType)
|
||||||
|
|
||||||
type = schema.T(field)
|
type = schema.T(field)
|
||||||
assert field.name == 'myField'
|
assert field.name is None
|
||||||
assert field.attname == 'my_field'
|
assert field.attname == 'my_field'
|
||||||
assert isinstance(type, GraphQLInputObjectField)
|
assert isinstance(type, GraphQLInputObjectField)
|
||||||
assert type.description == 'My input field'
|
assert type.description == 'My input field'
|
||||||
|
|
6
graphene/plugins/__init__.py
Normal file
6
graphene/plugins/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from .base import Plugin, PluginManager
|
||||||
|
from .camel_case import CamelCase
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Plugin', 'PluginManager', 'CamelCase'
|
||||||
|
]
|
53
graphene/plugins/base.py
Normal file
53
graphene/plugins/base.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import partial, reduce
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(object):
|
||||||
|
|
||||||
|
def contribute_to_schema(self, schema):
|
||||||
|
self.schema = schema
|
||||||
|
|
||||||
|
|
||||||
|
def apply_function(a, b):
|
||||||
|
return b(a)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginManager(object):
|
||||||
|
|
||||||
|
PLUGIN_FUNCTIONS = ('get_default_namedtype_name', )
|
||||||
|
|
||||||
|
def __init__(self, schema, plugins=[]):
|
||||||
|
self.schema = schema
|
||||||
|
self.plugins = []
|
||||||
|
for plugin in plugins:
|
||||||
|
self.add_plugin(plugin)
|
||||||
|
|
||||||
|
def add_plugin(self, plugin):
|
||||||
|
if hasattr(plugin, 'contribute_to_schema'):
|
||||||
|
plugin.contribute_to_schema(self.schema)
|
||||||
|
self.plugins.append(plugin)
|
||||||
|
|
||||||
|
def get_plugin_functions(self, function):
|
||||||
|
for plugin in self.plugins:
|
||||||
|
if not hasattr(plugin, function):
|
||||||
|
continue
|
||||||
|
yield getattr(plugin, function)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
functions = self.get_plugin_functions(name)
|
||||||
|
return partial(reduce, apply_function, functions)
|
||||||
|
|
||||||
|
def __contains__(self, name):
|
||||||
|
return name in self.PLUGIN_FUNCTIONS
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def context_execution(self, **executor):
|
||||||
|
contexts = []
|
||||||
|
functions = self.get_plugin_functions('context_execution')
|
||||||
|
for f in functions:
|
||||||
|
context = f(executor)
|
||||||
|
executor = context.__enter__()
|
||||||
|
contexts.append((context, executor))
|
||||||
|
yield executor
|
||||||
|
for context, value in contexts[::-1]:
|
||||||
|
context.__exit__(None, None, None)
|
7
graphene/plugins/camel_case.py
Normal file
7
graphene/plugins/camel_case.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from ..utils import to_camel_case
|
||||||
|
|
||||||
|
|
||||||
|
class CamelCase(object):
|
||||||
|
|
||||||
|
def get_default_namedtype_name(self, value):
|
||||||
|
return to_camel_case(value)
|
|
@ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation)
|
||||||
|
|
||||||
def test_mutation_arguments():
|
def test_mutation_arguments():
|
||||||
assert ChangeNumber.arguments
|
assert ChangeNumber.arguments
|
||||||
assert list(ChangeNumber.arguments) == ['input']
|
assert 'input' in schema.T(ChangeNumber.arguments)
|
||||||
assert 'input' in ChangeNumber.arguments
|
|
||||||
inner_type = ChangeNumber.input_type
|
inner_type = ChangeNumber.input_type
|
||||||
client_mutation_id_field = inner_type._meta.fields_map[
|
client_mutation_id_field = inner_type._meta.fields_map[
|
||||||
'client_mutation_id']
|
'client_mutation_id']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user