Merge pull request #64 from graphql-python/features/django-debug

Add plugin to debug Django SQL queries in the Query (similar to DjangoDebugToolbar)
This commit is contained in:
Syrus Akbary 2015-12-09 19:48:57 -08:00
commit 64460839be
11 changed files with 360 additions and 11 deletions

View File

@ -0,0 +1,4 @@
from .plugin import DjangoDebugPlugin
from .types import DjangoDebug
__all__ = ['DjangoDebugPlugin', 'DjangoDebug']

View File

@ -0,0 +1,77 @@
from contextlib import contextmanager
from django.db import connections
from ....core.types import Field
from ....plugins import Plugin
from .sql.tracking import unwrap_cursor, wrap_cursor
from .sql.types import DjangoDebugSQL
from .types import DjangoDebug
class WrappedRoot(object):
def __init__(self, root):
self._recorded = []
self._root = root
def record(self, **log):
self._recorded.append(DjangoDebugSQL(**log))
def debug(self):
return DjangoDebug(sql=self._recorded)
class WrapRoot(object):
@property
def _root(self):
return self._wrapped_root.root
@_root.setter
def _root(self, value):
self._wrapped_root = value
def resolve_debug(self, args, info):
return self._wrapped_root.debug()
def debug_objecttype(objecttype):
return type(
'Debug{}'.format(objecttype._meta.type_name),
(WrapRoot, objecttype),
{'debug': Field(DjangoDebug, name='__debug')})
class DjangoDebugPlugin(Plugin):
def transform_type(self, _type):
if _type == self.schema.query:
return
return _type
def enable_instrumentation(self, wrapped_root):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, wrapped_root)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
def wrap_schema(self, schema_type):
query = schema_type._query
if query:
class_type = self.schema.objecttype(schema_type._query)
assert class_type, 'The query in schema is not constructed with graphene'
_type = debug_objecttype(class_type)
schema_type._query = self.schema.T(_type)
return schema_type
@contextmanager
def context_execution(self, executor):
executor['root'] = WrappedRoot(root=executor['root'])
executor['schema'] = self.wrap_schema(executor['schema'])
self.enable_instrumentation(executor['root'])
yield executor
self.disable_instrumentation()

View File

@ -0,0 +1,165 @@
# Code obtained from django-debug-toolbar sql panel tracking
from __future__ import absolute_import, unicode_literals
import json
from threading import local
from time import time
from django.utils import six
from django.utils.encoding import force_text
class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@property
def Wrapper(self):
if self.enabled:
return NormalCursorWrapper
return ExceptionCursorWrapper
def recording(self, v):
self.enabled = v
state = ThreadLocalState()
recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, '_djdt_cursor'):
connection._djdt_cursor = connection.cursor
def cursor():
return state.Wrapper(connection._djdt_cursor(), connection, panel)
connection.cursor = cursor
return cursor
def unwrap_cursor(connection):
if hasattr(connection, '_djdt_cursor'):
del connection._djdt_cursor
del connection.cursor
class ExceptionCursorWrapper(object):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
"""
def __init__(self, cursor, db, logger):
pass
def __getattr__(self, attr):
raise SQLQueryTriggered()
class NormalCursorWrapper(object):
"""
Wraps a cursor and logs queries.
"""
def __init__(self, cursor, db, logger):
self.cursor = cursor
# Instance of a BaseDatabaseWrapper subclass
self.db = db
# logger must implement a ``record`` method
self.logger = logger
def _quote_expr(self, element):
if isinstance(element, six.string_types):
return "'%s'" % force_text(element).replace("'", "''")
else:
return repr(element)
def _quote_params(self, params):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
return '(encoded string)'
def _record(self, method, sql, params):
start_time = time()
try:
return method(sql, params)
finally:
stop_time = time()
duration = (stop_time - start_time)
_params = ''
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default')
conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown')
params = {
'vendor': vendor,
'alias': alias,
'sql': self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'start_time': start_time,
'stop_time': stop_time,
'is_slow': duration > 10,
'is_select': sql.lower().strip().startswith('select'),
}
if vendor == 'postgresql':
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = 'unknown'
params.update({
'trans_id': self.logger.get_transaction_id(alias),
'trans_status': conn.get_transaction_status(),
'iso_level': iso_level,
'encoding': conn.encoding,
})
# We keep `sql` to maintain backwards compatibility
self.logger.record(**params)
def callproc(self, procname, params=()):
return self._record(self.cursor.callproc, procname, params)
def execute(self, sql, params=()):
return self._record(self.cursor.execute, sql, params)
def executemany(self, sql, param_list):
return self._record(self.cursor.executemany, sql, param_list)
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()

View File

@ -0,0 +1,19 @@
from .....core import Float, ObjectType, String
class DjangoDebugSQL(ObjectType):
vendor = String()
alias = String()
sql = String()
duration = Float()
raw_sql = String()
params = String()
start_time = Float()
stop_time = Float()
is_slow = String()
is_select = String()
trans_id = String()
trans_status = String()
iso_level = String()
encoding = String()

View File

@ -0,0 +1,70 @@
import pytest
import graphene
from graphene.contrib.django import DjangoObjectType
from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_well():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
query = '''
query ReporterQuery {
reporter {
lastName
}
allReporters {
lastName
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'reporter': {
'lastName': 'ABA',
},
'allReporters': [{
'lastName': 'ABA',
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}, {
'rawSql': str(Reporter.objects.all().query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -0,0 +1,7 @@
from ....core.classtypes.objecttype import ObjectType
from ....core.types import Field
from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType):
sql = Field(DjangoDebugSQL.List())

View File

@ -118,17 +118,10 @@ class Schema(object):
def types(self):
return self._types_names
def execute(self, request='', root=None, vars=None,
operation_name=None, **kwargs):
root = root or object()
return self.executor.execute(
self.schema,
request,
root=root,
args=vars,
operation_name=operation_name,
**kwargs
)
def execute(self, request='', root=None, args=None, **kwargs):
kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema)
with self.plugins.context_execution(**kwargs) as execute_kwargs:
return self.executor.execute(**execute_kwargs)
def introspect(self):
return self.execute(introspection_query).data

View File

@ -129,6 +129,7 @@ class MountedType(FieldType, ArgumentType):
class NamedType(InstanceType):
def __init__(self, name=None, default_name=None, *args, **kwargs):
self.name = name
self.default_name = None

View File

@ -1,3 +1,4 @@
from contextlib import contextmanager
from functools import partial, reduce
@ -38,3 +39,15 @@ class PluginManager(object):
def __contains__(self, name):
return name in self.PLUGIN_FUNCTIONS
@contextmanager
def context_execution(self, **executor):
contexts = []
functions = self.get_plugin_functions('context_execution')
for f in functions:
context = f(executor)
executor = context.__enter__()
contexts.append((context, executor))
yield executor
for context, value in contexts[::-1]:
context.__exit__(None, None, None)