First phase of middlewares

This commit is contained in:
Syrus Akbary 2016-05-20 23:25:54 -07:00
parent 8421b59d3a
commit 3428725314
9 changed files with 96 additions and 102 deletions

View File

@ -1,4 +1,4 @@
from .plugin import DjangoDebugPlugin from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug from .types import DjangoDebug
__all__ = ['DjangoDebugPlugin', 'DjangoDebug'] __all__ = ['DjangoDebugMiddleware', 'DjangoDebug']

View File

@ -0,0 +1,56 @@
from promise import Promise
from django.db import connections
from ....core.schema import GraphQLSchema
from ....core.types import Field
from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug
class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
self.enable_instrumentation()
self.object = DjangoDebug(sql=[])
def get_debug_promise(self):
if not self.debug_promise:
self.debug_promise = Promise.all(self.promises)
return self.debug_promise.then(self.on_resolve_all_promises)
def on_resolve_all_promises(self, values):
self.disable_instrumentation()
return self.object
def add_promise(self, promise):
if self.debug_promise and not self.debug_promise.is_fulfilled:
self.promises.append(promise)
def enable_instrumentation(self):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, self)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
class DjangoDebugMiddleware(object):
def resolve(self, next, root, args, context, info):
django_debug = getattr(context, 'django_debug', None)
if not django_debug:
if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts')
try:
context.django_debug = DjangoDebugContext()
except Exception, e:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(context.__class__.__name__))
if info.schema.graphene_schema.T(DjangoDebug) == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, args, context, info)
context.django_debug.add_promise(promise)
return promise

View File

@ -1,77 +0,0 @@
from contextlib import contextmanager
from django.db import connections
from graphene import with_context
from ....core.schema import GraphQLSchema
from ....core.types import Field
from ....plugins import Plugin
from .sql.tracking import unwrap_cursor, wrap_cursor
from .sql.types import DjangoDebugSQL
from .types import DjangoDebug
class EmptyContext(object):
pass
class DjangoDebugContext(object):
def __init__(self):
self._recorded = []
def record(self, **log):
self._recorded.append(DjangoDebugSQL(**log))
def debug(self):
return DjangoDebug(sql=self._recorded)
class WrapRoot(object):
@with_context
def resolve_debug(self, args, context, info):
return context.django_debug.debug()
def debug_objecttype(objecttype):
return type(
'Debug{}'.format(objecttype._meta.type_name),
(WrapRoot, objecttype),
{'debug': Field(DjangoDebug, name='__debug')})
class DjangoDebugPlugin(Plugin):
def enable_instrumentation(self, wrapped_root):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, wrapped_root)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
def wrap_schema(self, schema_type):
query = schema_type._query
if query:
class_type = self.schema.objecttype(schema_type.get_query_type())
assert class_type, 'The query in schema is not constructed with graphene'
_type = debug_objecttype(class_type)
self.schema.register(_type, force=True)
return GraphQLSchema(
self.schema,
self.schema.T(_type),
schema_type.get_mutation_type(),
schema_type.get_subscription_type()
)
return schema_type
@contextmanager
def context_execution(self, executor):
context_value = executor.get('context_value') or EmptyContext()
context_value.django_debug = DjangoDebugContext()
executor['context_value'] = context_value
executor['schema'] = self.wrap_schema(executor['schema'])
self.enable_instrumentation(context_value.django_debug)
yield executor
self.disable_instrumentation()

View File

@ -8,6 +8,7 @@ from time import time
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from .types import DjangoDebugSQL, DjangoDebugPostgreSQL
class SQLQueryTriggered(Exception): class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query""" """Thrown when template panel triggers a query"""
@ -139,9 +140,11 @@ class NormalCursorWrapper(object):
'iso_level': iso_level, 'iso_level': iso_level,
'encoding': conn.encoding, 'encoding': conn.encoding,
}) })
_sql = DjangoDebugPostgreSQL(**params)
else:
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility # We keep `sql` to maintain backwards compatibility
self.logger.record(**params) self.logger.object.sql.append(_sql)
def callproc(self, procname, params=()): def callproc(self, procname, params=()):
return self._record(self.cursor.callproc, procname, params) return self._record(self.cursor.callproc, procname, params)

View File

@ -1,7 +1,7 @@
from .....core import Boolean, Float, ObjectType, String from .....core import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType): class DjangoDebugBaseSQL(ObjectType):
vendor = String() vendor = String()
alias = String() alias = String()
sql = String() sql = String()
@ -13,6 +13,12 @@ class DjangoDebugSQL(ObjectType):
is_slow = Boolean() is_slow = Boolean()
is_select = Boolean() is_select = Boolean()
class DjangoDebugSQL(DjangoDebugBaseSQL):
pass
class DjangoDebugPostgreSQL(DjangoDebugBaseSQL):
trans_id = String() trans_id = String()
trans_status = String() trans_status = String()
iso_level = String() iso_level = String()

View File

@ -5,7 +5,11 @@ from graphene.contrib.django import DjangoConnectionField, DjangoNode
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
from ...tests.models import Reporter from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin from ..middleware import DjangoDebugMiddleware
from ..types import DjangoDebug
class context(object):
pass
# from examples.starwars_django.models import Character # from examples.starwars_django.models import Character
@ -25,6 +29,7 @@ def test_should_query_field():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first() return Reporter.objects.first()
@ -51,8 +56,8 @@ def test_should_query_field():
}] }]
} }
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, plugins=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -70,6 +75,7 @@ def test_should_query_list():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = ReporterType.List() all_reporters = ReporterType.List()
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all() return Reporter.objects.all()
@ -98,8 +104,8 @@ def test_should_query_list():
}] }]
} }
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, plugins=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -117,6 +123,7 @@ def test_should_query_connection():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all() return Reporter.objects.all()
@ -146,8 +153,8 @@ def test_should_query_connection():
}] }]
}, },
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, plugins=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
@ -172,6 +179,7 @@ def test_should_query_connectionfilter():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType) all_reporters = DjangoFilterConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all() return Reporter.objects.all()
@ -201,8 +209,8 @@ def test_should_query_connectionfilter():
}] }]
}, },
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, plugins=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']

View File

@ -1,7 +1,7 @@
from ....core.classtypes.objecttype import ObjectType from ....core.classtypes.objecttype import ObjectType
from ....core.types import Field from ....core.types import Field
from .sql.types import DjangoDebugSQL from .sql.types import DjangoDebugBaseSQL
class DjangoDebug(ObjectType): class DjangoDebug(ObjectType):
sql = Field(DjangoDebugSQL.List()) sql = Field(DjangoDebugBaseSQL.List())

View File

@ -35,6 +35,7 @@ class Schema(object):
plugins = plugins or [] plugins = plugins or []
if auto_camelcase: if auto_camelcase:
plugins.append(CamelCase()) plugins.append(CamelCase())
self.auto_camelcase = auto_camelcase
self.plugins = PluginManager(self, plugins) self.plugins = PluginManager(self, plugins)
self.options = options self.options = options
signals.init_schema.send(self) signals.init_schema.send(self)
@ -42,11 +43,6 @@ class Schema(object):
def __repr__(self): def __repr__(self):
return '<Schema: %s (%s)>' % (str(self.name), hash(self)) return '<Schema: %s (%s)>' % (str(self.name), hash(self))
def __getattr__(self, name):
if name in self.plugins:
return getattr(self.plugins, name)
return super(Schema, self).__getattr__(name)
def T(self, _type): def T(self, _type):
if not _type: if not _type:
return return
@ -122,7 +118,7 @@ class Schema(object):
def execute(self, request_string='', root_value=None, variable_values=None, def execute(self, request_string='', root_value=None, variable_values=None,
context_value=None, operation_name=None, executor=None): context_value=None, operation_name=None, executor=None):
kwargs = dict( return graphql(
schema=self.schema, schema=self.schema,
request_string=request_string, request_string=request_string,
root_value=root_value, root_value=root_value,
@ -131,8 +127,6 @@ class Schema(object):
operation_name=operation_name, operation_name=operation_name,
executor=executor or self._executor executor=executor or self._executor
) )
with self.plugins.context_execution(**kwargs) as execute_kwargs:
return graphql(**execute_kwargs)
def introspect(self): def introspect(self):
return graphql(self.schema, introspection_query).data return graphql(self.schema, introspection_query).data

View File

@ -3,6 +3,8 @@ from functools import partial, total_ordering
import six import six
from ...utils import to_camel_case
class InstanceType(object): class InstanceType(object):
@ -142,7 +144,9 @@ class GroupNamedType(InstanceType):
self.types = types self.types = types
def get_named_type(self, schema, type): def get_named_type(self, schema, type):
name = type.name or schema.get_default_namedtype_name(type.default_name) name = type.name
if not name and schema.auto_camelcase:
name = to_camel_case(type.default_name)
return name, schema.T(type) return name, schema.T(type)
def iter_types(self, schema): def iter_types(self, schema):