mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-10-25 21:21:04 +03:00 
			
		
		
		
	Merge pull request #65 from graphql-python/features/plugins-autocamelcase
Create plugin structure
This commit is contained in:
		
						commit
						a161738f3d
					
				
							
								
								
									
										4
									
								
								graphene/contrib/django/debug/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								graphene/contrib/django/debug/__init__.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,4 @@ | ||||||
|  | from .plugin import DjangoDebugPlugin | ||||||
|  | from .types import DjangoDebug | ||||||
|  | 
 | ||||||
|  | __all__ = ['DjangoDebugPlugin', 'DjangoDebug'] | ||||||
							
								
								
									
										77
									
								
								graphene/contrib/django/debug/plugin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								graphene/contrib/django/debug/plugin.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,77 @@ | ||||||
|  | from contextlib import contextmanager | ||||||
|  | 
 | ||||||
|  | from django.db import connections | ||||||
|  | 
 | ||||||
|  | from ....core.types import Field | ||||||
|  | from ....plugins import Plugin | ||||||
|  | from .sql.tracking import unwrap_cursor, wrap_cursor | ||||||
|  | from .sql.types import DjangoDebugSQL | ||||||
|  | from .types import DjangoDebug | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class WrappedRoot(object): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, root): | ||||||
|  |         self._recorded = [] | ||||||
|  |         self._root = root | ||||||
|  | 
 | ||||||
|  |     def record(self, **log): | ||||||
|  |         self._recorded.append(DjangoDebugSQL(**log)) | ||||||
|  | 
 | ||||||
|  |     def debug(self): | ||||||
|  |         return DjangoDebug(sql=self._recorded) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class WrapRoot(object): | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def _root(self): | ||||||
|  |         return self._wrapped_root.root | ||||||
|  | 
 | ||||||
|  |     @_root.setter | ||||||
|  |     def _root(self, value): | ||||||
|  |         self._wrapped_root = value | ||||||
|  | 
 | ||||||
|  |     def resolve_debug(self, args, info): | ||||||
|  |         return self._wrapped_root.debug() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def debug_objecttype(objecttype): | ||||||
|  |     return type( | ||||||
|  |         'Debug{}'.format(objecttype._meta.type_name), | ||||||
|  |         (WrapRoot, objecttype), | ||||||
|  |         {'debug': Field(DjangoDebug, name='__debug')}) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DjangoDebugPlugin(Plugin): | ||||||
|  | 
 | ||||||
|  |     def transform_type(self, _type): | ||||||
|  |         if _type == self.schema.query: | ||||||
|  |             return | ||||||
|  |         return _type | ||||||
|  | 
 | ||||||
|  |     def enable_instrumentation(self, wrapped_root): | ||||||
|  |         # This is thread-safe because database connections are thread-local. | ||||||
|  |         for connection in connections.all(): | ||||||
|  |             wrap_cursor(connection, wrapped_root) | ||||||
|  | 
 | ||||||
|  |     def disable_instrumentation(self): | ||||||
|  |         for connection in connections.all(): | ||||||
|  |             unwrap_cursor(connection) | ||||||
|  | 
 | ||||||
|  |     def wrap_schema(self, schema_type): | ||||||
|  |         query = schema_type._query | ||||||
|  |         if query: | ||||||
|  |             class_type = self.schema.objecttype(schema_type._query) | ||||||
|  |             assert class_type, 'The query in schema is not constructed with graphene' | ||||||
|  |             _type = debug_objecttype(class_type) | ||||||
|  |             schema_type._query = self.schema.T(_type) | ||||||
|  |         return schema_type | ||||||
|  | 
 | ||||||
|  |     @contextmanager | ||||||
|  |     def context_execution(self, executor): | ||||||
|  |         executor['root'] = WrappedRoot(root=executor['root']) | ||||||
|  |         executor['schema'] = self.wrap_schema(executor['schema']) | ||||||
|  |         self.enable_instrumentation(executor['root']) | ||||||
|  |         yield executor | ||||||
|  |         self.disable_instrumentation() | ||||||
							
								
								
									
										0
									
								
								graphene/contrib/django/debug/sql/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graphene/contrib/django/debug/sql/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										165
									
								
								graphene/contrib/django/debug/sql/tracking.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								graphene/contrib/django/debug/sql/tracking.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,165 @@ | ||||||
|  | # Code obtained from django-debug-toolbar sql panel tracking | ||||||
|  | from __future__ import absolute_import, unicode_literals | ||||||
|  | 
 | ||||||
|  | import json | ||||||
|  | from threading import local | ||||||
|  | from time import time | ||||||
|  | 
 | ||||||
|  | from django.utils import six | ||||||
|  | from django.utils.encoding import force_text | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class SQLQueryTriggered(Exception): | ||||||
|  |     """Thrown when template panel triggers a query""" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ThreadLocalState(local): | ||||||
|  | 
 | ||||||
|  |     def __init__(self): | ||||||
|  |         self.enabled = True | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def Wrapper(self): | ||||||
|  |         if self.enabled: | ||||||
|  |             return NormalCursorWrapper | ||||||
|  |         return ExceptionCursorWrapper | ||||||
|  | 
 | ||||||
|  |     def recording(self, v): | ||||||
|  |         self.enabled = v | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | state = ThreadLocalState() | ||||||
|  | recording = state.recording  # export function | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def wrap_cursor(connection, panel): | ||||||
|  |     if not hasattr(connection, '_djdt_cursor'): | ||||||
|  |         connection._djdt_cursor = connection.cursor | ||||||
|  | 
 | ||||||
|  |         def cursor(): | ||||||
|  |             return state.Wrapper(connection._djdt_cursor(), connection, panel) | ||||||
|  | 
 | ||||||
|  |         connection.cursor = cursor | ||||||
|  |         return cursor | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def unwrap_cursor(connection): | ||||||
|  |     if hasattr(connection, '_djdt_cursor'): | ||||||
|  |         del connection._djdt_cursor | ||||||
|  |         del connection.cursor | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ExceptionCursorWrapper(object): | ||||||
|  |     """ | ||||||
|  |     Wraps a cursor and raises an exception on any operation. | ||||||
|  |     Used in Templates panel. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, cursor, db, logger): | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     def __getattr__(self, attr): | ||||||
|  |         raise SQLQueryTriggered() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class NormalCursorWrapper(object): | ||||||
|  |     """ | ||||||
|  |     Wraps a cursor and logs queries. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, cursor, db, logger): | ||||||
|  |         self.cursor = cursor | ||||||
|  |         # Instance of a BaseDatabaseWrapper subclass | ||||||
|  |         self.db = db | ||||||
|  |         # logger must implement a ``record`` method | ||||||
|  |         self.logger = logger | ||||||
|  | 
 | ||||||
|  |     def _quote_expr(self, element): | ||||||
|  |         if isinstance(element, six.string_types): | ||||||
|  |             return "'%s'" % force_text(element).replace("'", "''") | ||||||
|  |         else: | ||||||
|  |             return repr(element) | ||||||
|  | 
 | ||||||
|  |     def _quote_params(self, params): | ||||||
|  |         if not params: | ||||||
|  |             return params | ||||||
|  |         if isinstance(params, dict): | ||||||
|  |             return dict((key, self._quote_expr(value)) | ||||||
|  |                         for key, value in params.items()) | ||||||
|  |         return list(map(self._quote_expr, params)) | ||||||
|  | 
 | ||||||
|  |     def _decode(self, param): | ||||||
|  |         try: | ||||||
|  |             return force_text(param, strings_only=True) | ||||||
|  |         except UnicodeDecodeError: | ||||||
|  |             return '(encoded string)' | ||||||
|  | 
 | ||||||
|  |     def _record(self, method, sql, params): | ||||||
|  |         start_time = time() | ||||||
|  |         try: | ||||||
|  |             return method(sql, params) | ||||||
|  |         finally: | ||||||
|  |             stop_time = time() | ||||||
|  |             duration = (stop_time - start_time) | ||||||
|  |             _params = '' | ||||||
|  |             try: | ||||||
|  |                 _params = json.dumps(list(map(self._decode, params))) | ||||||
|  |             except Exception: | ||||||
|  |                 pass  # object not JSON serializable | ||||||
|  | 
 | ||||||
|  |             alias = getattr(self.db, 'alias', 'default') | ||||||
|  |             conn = self.db.connection | ||||||
|  |             vendor = getattr(conn, 'vendor', 'unknown') | ||||||
|  | 
 | ||||||
|  |             params = { | ||||||
|  |                 'vendor': vendor, | ||||||
|  |                 'alias': alias, | ||||||
|  |                 'sql': self.db.ops.last_executed_query( | ||||||
|  |                     self.cursor, sql, self._quote_params(params)), | ||||||
|  |                 'duration': duration, | ||||||
|  |                 'raw_sql': sql, | ||||||
|  |                 'params': _params, | ||||||
|  |                 'start_time': start_time, | ||||||
|  |                 'stop_time': stop_time, | ||||||
|  |                 'is_slow': duration > 10, | ||||||
|  |                 'is_select': sql.lower().strip().startswith('select'), | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             if vendor == 'postgresql': | ||||||
|  |                 # If an erroneous query was ran on the connection, it might | ||||||
|  |                 # be in a state where checking isolation_level raises an | ||||||
|  |                 # exception. | ||||||
|  |                 try: | ||||||
|  |                     iso_level = conn.isolation_level | ||||||
|  |                 except conn.InternalError: | ||||||
|  |                     iso_level = 'unknown' | ||||||
|  |                 params.update({ | ||||||
|  |                     'trans_id': self.logger.get_transaction_id(alias), | ||||||
|  |                     'trans_status': conn.get_transaction_status(), | ||||||
|  |                     'iso_level': iso_level, | ||||||
|  |                     'encoding': conn.encoding, | ||||||
|  |                 }) | ||||||
|  | 
 | ||||||
|  |             # We keep `sql` to maintain backwards compatibility | ||||||
|  |             self.logger.record(**params) | ||||||
|  | 
 | ||||||
|  |     def callproc(self, procname, params=()): | ||||||
|  |         return self._record(self.cursor.callproc, procname, params) | ||||||
|  | 
 | ||||||
|  |     def execute(self, sql, params=()): | ||||||
|  |         return self._record(self.cursor.execute, sql, params) | ||||||
|  | 
 | ||||||
|  |     def executemany(self, sql, param_list): | ||||||
|  |         return self._record(self.cursor.executemany, sql, param_list) | ||||||
|  | 
 | ||||||
|  |     def __getattr__(self, attr): | ||||||
|  |         return getattr(self.cursor, attr) | ||||||
|  | 
 | ||||||
|  |     def __iter__(self): | ||||||
|  |         return iter(self.cursor) | ||||||
|  | 
 | ||||||
|  |     def __enter__(self): | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     def __exit__(self, type, value, traceback): | ||||||
|  |         self.close() | ||||||
							
								
								
									
										19
									
								
								graphene/contrib/django/debug/sql/types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								graphene/contrib/django/debug/sql/types.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,19 @@ | ||||||
|  | from .....core import Float, ObjectType, String | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DjangoDebugSQL(ObjectType): | ||||||
|  |     vendor = String() | ||||||
|  |     alias = String() | ||||||
|  |     sql = String() | ||||||
|  |     duration = Float() | ||||||
|  |     raw_sql = String() | ||||||
|  |     params = String() | ||||||
|  |     start_time = Float() | ||||||
|  |     stop_time = Float() | ||||||
|  |     is_slow = String() | ||||||
|  |     is_select = String() | ||||||
|  | 
 | ||||||
|  |     trans_id = String() | ||||||
|  |     trans_status = String() | ||||||
|  |     iso_level = String() | ||||||
|  |     encoding = String() | ||||||
							
								
								
									
										0
									
								
								graphene/contrib/django/debug/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graphene/contrib/django/debug/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										70
									
								
								graphene/contrib/django/debug/tests/test_query.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								graphene/contrib/django/debug/tests/test_query.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,70 @@ | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | import graphene | ||||||
|  | from graphene.contrib.django import DjangoObjectType | ||||||
|  | 
 | ||||||
|  | from ...tests.models import Reporter | ||||||
|  | from ..plugin import DjangoDebugPlugin | ||||||
|  | 
 | ||||||
|  | # from examples.starwars_django.models import Character | ||||||
|  | 
 | ||||||
|  | pytestmark = pytest.mark.django_db | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_should_query_well(): | ||||||
|  |     r1 = Reporter(last_name='ABA') | ||||||
|  |     r1.save() | ||||||
|  |     r2 = Reporter(last_name='Griffin') | ||||||
|  |     r2.save() | ||||||
|  | 
 | ||||||
|  |     class ReporterType(DjangoObjectType): | ||||||
|  | 
 | ||||||
|  |         class Meta: | ||||||
|  |             model = Reporter | ||||||
|  | 
 | ||||||
|  |     class Query(graphene.ObjectType): | ||||||
|  |         reporter = graphene.Field(ReporterType) | ||||||
|  |         all_reporters = ReporterType.List() | ||||||
|  | 
 | ||||||
|  |         def resolve_all_reporters(self, *args, **kwargs): | ||||||
|  |             return Reporter.objects.all() | ||||||
|  | 
 | ||||||
|  |         def resolve_reporter(self, *args, **kwargs): | ||||||
|  |             return Reporter.objects.first() | ||||||
|  | 
 | ||||||
|  |     query = ''' | ||||||
|  |         query ReporterQuery { | ||||||
|  |           reporter { | ||||||
|  |             lastName | ||||||
|  |           } | ||||||
|  |           allReporters { | ||||||
|  |             lastName | ||||||
|  |           } | ||||||
|  |           __debug { | ||||||
|  |             sql { | ||||||
|  |               rawSql | ||||||
|  |             } | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |     ''' | ||||||
|  |     expected = { | ||||||
|  |         'reporter': { | ||||||
|  |             'lastName': 'ABA', | ||||||
|  |         }, | ||||||
|  |         'allReporters': [{ | ||||||
|  |             'lastName': 'ABA', | ||||||
|  |         }, { | ||||||
|  |             'lastName': 'Griffin', | ||||||
|  |         }], | ||||||
|  |         '__debug': { | ||||||
|  |             'sql': [{ | ||||||
|  |                 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) | ||||||
|  |             }, { | ||||||
|  |                 'rawSql': str(Reporter.objects.all().query) | ||||||
|  |             }] | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) | ||||||
|  |     result = schema.execute(query) | ||||||
|  |     assert not result.errors | ||||||
|  |     assert result.data == expected | ||||||
							
								
								
									
										7
									
								
								graphene/contrib/django/debug/types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								graphene/contrib/django/debug/types.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,7 @@ | ||||||
|  | from ....core.classtypes.objecttype import ObjectType | ||||||
|  | from ....core.types import Field | ||||||
|  | from .sql.types import DjangoDebugSQL | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DjangoDebug(ObjectType): | ||||||
|  |     sql = Field(DjangoDebugSQL.List()) | ||||||
|  | @ -5,7 +5,6 @@ from functools import partial | ||||||
| 
 | 
 | ||||||
| import six | import six | ||||||
| 
 | 
 | ||||||
| from ..exceptions import SkipField |  | ||||||
| from .options import Options | from .options import Options | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -82,13 +81,18 @@ class FieldsOptions(Options): | ||||||
|     def fields_map(self): |     def fields_map(self): | ||||||
|         return OrderedDict([(f.attname, f) for f in self.fields]) |         return OrderedDict([(f.attname, f) for f in self.fields]) | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def fields_group_type(self): | ||||||
|  |         from ..types.field import FieldsGroupType | ||||||
|  |         return FieldsGroupType(*self.local_fields) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class FieldsClassTypeMeta(ClassTypeMeta): | class FieldsClassTypeMeta(ClassTypeMeta): | ||||||
|     options_class = FieldsOptions |     options_class = FieldsOptions | ||||||
| 
 | 
 | ||||||
|     def extend_fields(cls, bases): |     def extend_fields(cls, bases): | ||||||
|         new_fields = cls._meta.local_fields |         new_fields = cls._meta.local_fields | ||||||
|         field_names = {f.name: f for f in new_fields} |         field_names = {f.attname: f for f in new_fields} | ||||||
| 
 | 
 | ||||||
|         for base in bases: |         for base in bases: | ||||||
|             if not isinstance(base, FieldsClassTypeMeta): |             if not isinstance(base, FieldsClassTypeMeta): | ||||||
|  | @ -96,17 +100,17 @@ class FieldsClassTypeMeta(ClassTypeMeta): | ||||||
| 
 | 
 | ||||||
|             parent_fields = base._meta.local_fields |             parent_fields = base._meta.local_fields | ||||||
|             for field in parent_fields: |             for field in parent_fields: | ||||||
|                 if field.name in field_names and field.type.__class__ != field_names[ |                 if field.attname in field_names and field.type.__class__ != field_names[ | ||||||
|                         field.name].type.__class__: |                         field.attname].type.__class__: | ||||||
|                     raise Exception( |                     raise Exception( | ||||||
|                         'Local field %r in class %r (%r) clashes ' |                         'Local field %r in class %r (%r) clashes ' | ||||||
|                         'with field with similar name from ' |                         'with field with similar name from ' | ||||||
|                         'Interface %s (%r)' % ( |                         'Interface %s (%r)' % ( | ||||||
|                             field.name, |                             field.attname, | ||||||
|                             cls.__name__, |                             cls.__name__, | ||||||
|                             field.__class__, |                             field.__class__, | ||||||
|                             base.__name__, |                             base.__name__, | ||||||
|                             field_names[field.name].__class__) |                             field_names[field.attname].__class__) | ||||||
|                     ) |                     ) | ||||||
|                 new_field = copy.copy(field) |                 new_field = copy.copy(field) | ||||||
|                 cls.add_to_class(field.attname, new_field) |                 cls.add_to_class(field.attname, new_field) | ||||||
|  | @ -124,11 +128,4 @@ class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)): | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def fields_internal_types(cls, schema): |     def fields_internal_types(cls, schema): | ||||||
|         fields = [] |         return schema.T(cls._meta.fields_group_type) | ||||||
|         for field in cls._meta.fields: |  | ||||||
|             try: |  | ||||||
|                 fields.append((field.name, schema.T(field))) |  | ||||||
|             except SkipField: |  | ||||||
|                 continue |  | ||||||
| 
 |  | ||||||
|         return OrderedDict(fields) |  | ||||||
|  |  | ||||||
|  | @ -24,4 +24,4 @@ def test_mutation(): | ||||||
|     assert list(object_type.get_fields().keys()) == ['name'] |     assert list(object_type.get_fields().keys()) == ['name'] | ||||||
|     assert MyMutation._meta.fields_map['name'].object_type == MyMutation |     assert MyMutation._meta.fields_map['name'].object_type == MyMutation | ||||||
|     assert isinstance(MyMutation.arguments, ArgumentsGroup) |     assert isinstance(MyMutation.arguments, ArgumentsGroup) | ||||||
|     assert 'argName' in MyMutation.arguments |     assert 'argName' in schema.T(MyMutation.arguments) | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ from graphql.core.utils.schema_printer import print_schema | ||||||
| 
 | 
 | ||||||
| from graphene import signals | from graphene import signals | ||||||
| 
 | 
 | ||||||
|  | from ..plugins import CamelCase, PluginManager | ||||||
| from .classtypes.base import ClassType | from .classtypes.base import ClassType | ||||||
| from .types.base import InstanceType | from .types.base import InstanceType | ||||||
| 
 | 
 | ||||||
|  | @ -25,7 +26,7 @@ class Schema(object): | ||||||
|     _executor = None |     _executor = None | ||||||
| 
 | 
 | ||||||
|     def __init__(self, query=None, mutation=None, subscription=None, |     def __init__(self, query=None, mutation=None, subscription=None, | ||||||
|                  name='Schema', executor=None): |                  name='Schema', executor=None, plugins=None, auto_camelcase=True): | ||||||
|         self._types_names = {} |         self._types_names = {} | ||||||
|         self._types = {} |         self._types = {} | ||||||
|         self.mutation = mutation |         self.mutation = mutation | ||||||
|  | @ -33,11 +34,20 @@ class Schema(object): | ||||||
|         self.subscription = subscription |         self.subscription = subscription | ||||||
|         self.name = name |         self.name = name | ||||||
|         self.executor = executor |         self.executor = executor | ||||||
|  |         plugins = plugins or [] | ||||||
|  |         if auto_camelcase: | ||||||
|  |             plugins.append(CamelCase()) | ||||||
|  |         self.plugins = PluginManager(self, plugins) | ||||||
|         signals.init_schema.send(self) |         signals.init_schema.send(self) | ||||||
| 
 | 
 | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return '<Schema: %s (%s)>' % (str(self.name), hash(self)) |         return '<Schema: %s (%s)>' % (str(self.name), hash(self)) | ||||||
| 
 | 
 | ||||||
|  |     def __getattr__(self, name): | ||||||
|  |         if name in self.plugins: | ||||||
|  |             return getattr(self.plugins, name) | ||||||
|  |         return super(Schema, self).__getattr__(name) | ||||||
|  | 
 | ||||||
|     def T(self, _type): |     def T(self, _type): | ||||||
|         if not _type: |         if not _type: | ||||||
|             return |             return | ||||||
|  | @ -108,17 +118,10 @@ class Schema(object): | ||||||
|     def types(self): |     def types(self): | ||||||
|         return self._types_names |         return self._types_names | ||||||
| 
 | 
 | ||||||
|     def execute(self, request='', root=None, vars=None, |     def execute(self, request='', root=None, args=None, **kwargs): | ||||||
|                 operation_name=None, **kwargs): |         kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema) | ||||||
|         root = root or object() |         with self.plugins.context_execution(**kwargs) as execute_kwargs: | ||||||
|         return self.executor.execute( |             return self.executor.execute(**execute_kwargs) | ||||||
|             self.schema, |  | ||||||
|             request, |  | ||||||
|             root=root, |  | ||||||
|             args=vars, |  | ||||||
|             operation_name=operation_name, |  | ||||||
|             **kwargs |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|     def introspect(self): |     def introspect(self): | ||||||
|         return self.execute(introspection_query).data |         return self.execute(introspection_query).data | ||||||
|  |  | ||||||
|  | @ -34,10 +34,11 @@ def test_field_type(): | ||||||
|     assert schema.T(f).type == GraphQLString |     assert schema.T(f).type == GraphQLString | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_field_name_automatic_camelcase(): | def test_field_name(): | ||||||
|     f = Field(GraphQLString) |     f = Field(GraphQLString) | ||||||
|     f.contribute_to_class(MyOt, 'field_name') |     f.contribute_to_class(MyOt, 'field_name') | ||||||
|     assert f.name == 'fieldName' |     assert f.name is None | ||||||
|  |     assert f.attname == 'field_name' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_field_name_use_name_if_exists(): | def test_field_name_use_name_if_exists(): | ||||||
|  |  | ||||||
|  | @ -1,19 +1,17 @@ | ||||||
| from collections import OrderedDict |  | ||||||
| from functools import wraps | from functools import wraps | ||||||
| from itertools import chain | from itertools import chain | ||||||
| 
 | 
 | ||||||
| from graphql.core.type import GraphQLArgument | from graphql.core.type import GraphQLArgument | ||||||
| 
 | 
 | ||||||
| from ...utils import ProxySnakeDict, to_camel_case | from ...utils import ProxySnakeDict | ||||||
| from .base import ArgumentType, InstanceType, OrderedType | from .base import ArgumentType, GroupNamedType, NamedType, OrderedType | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Argument(OrderedType): | class Argument(NamedType, OrderedType): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, type, description=None, default=None, |     def __init__(self, type, description=None, default=None, | ||||||
|                  name=None, _creation_counter=None): |                  name=None, _creation_counter=None): | ||||||
|         super(Argument, self).__init__(_creation_counter=_creation_counter) |         super(Argument, self).__init__(name=name, _creation_counter=_creation_counter) | ||||||
|         self.name = name |  | ||||||
|         self.type = type |         self.type = type | ||||||
|         self.description = description |         self.description = description | ||||||
|         self.default = default |         self.default = default | ||||||
|  | @ -27,47 +25,32 @@ class Argument(OrderedType): | ||||||
|         return self.name |         return self.name | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ArgumentsGroup(InstanceType): | class ArgumentsGroup(GroupNamedType): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         arguments = to_arguments(*args, **kwargs) |         arguments = to_arguments(*args, **kwargs) | ||||||
|         self.arguments = OrderedDict([(arg.name, arg) for arg in arguments]) |         super(ArgumentsGroup, self).__init__(*arguments) | ||||||
| 
 |  | ||||||
|     def internal_type(self, schema): |  | ||||||
|         return OrderedDict([(arg.name, schema.T(arg)) |  | ||||||
|                             for arg in self.arguments.values()]) |  | ||||||
| 
 |  | ||||||
|     def __len__(self): |  | ||||||
|         return len(self.arguments) |  | ||||||
| 
 |  | ||||||
|     def __iter__(self): |  | ||||||
|         return iter(self.arguments) |  | ||||||
| 
 |  | ||||||
|     def __contains__(self, *args): |  | ||||||
|         return self.arguments.__contains__(*args) |  | ||||||
| 
 |  | ||||||
|     def __getitem__(self, *args): |  | ||||||
|         return self.arguments.__getitem__(*args) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def to_arguments(*args, **kwargs): | def to_arguments(*args, **kwargs): | ||||||
|     arguments = {} |     arguments = {} | ||||||
|     iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) |     iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) | ||||||
| 
 | 
 | ||||||
|     for name, arg in iter_arguments: |     for default_name, arg in iter_arguments: | ||||||
|         if isinstance(arg, Argument): |         if isinstance(arg, Argument): | ||||||
|             argument = arg |             argument = arg | ||||||
|         elif isinstance(arg, ArgumentType): |         elif isinstance(arg, ArgumentType): | ||||||
|             argument = arg.as_argument() |             argument = arg.as_argument() | ||||||
|         else: |         else: | ||||||
|             raise ValueError('Unknown argument %s=%r' % (name, arg)) |             raise ValueError('Unknown argument %s=%r' % (default_name, arg)) | ||||||
| 
 | 
 | ||||||
|         if name: |         if default_name: | ||||||
|             argument.name = to_camel_case(name) |             argument.default_name = default_name | ||||||
|         assert argument.name, 'Argument in field must have a name' | 
 | ||||||
|         assert argument.name not in arguments, 'Found more than one Argument with same name {}'.format( |         name = argument.name or argument.default_name | ||||||
|             argument.name) |         assert name, 'Argument in field must have a name' | ||||||
|         arguments[argument.name] = argument |         assert name not in arguments, 'Found more than one Argument with same name {}'.format(name) | ||||||
|  |         arguments[name] = argument | ||||||
| 
 | 
 | ||||||
|     return sorted(arguments.values()) |     return sorted(arguments.values()) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,5 @@ | ||||||
| from functools import total_ordering | from collections import OrderedDict | ||||||
|  | from functools import partial, total_ordering | ||||||
| 
 | 
 | ||||||
| import six | import six | ||||||
| 
 | 
 | ||||||
|  | @ -125,3 +126,39 @@ class FieldType(MirroredType): | ||||||
| 
 | 
 | ||||||
| class MountedType(FieldType, ArgumentType): | class MountedType(FieldType, ArgumentType): | ||||||
|     pass |     pass | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class NamedType(InstanceType): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, name=None, default_name=None, *args, **kwargs): | ||||||
|  |         self.name = name | ||||||
|  |         self.default_name = None | ||||||
|  |         super(NamedType, self).__init__(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupNamedType(InstanceType): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, *types): | ||||||
|  |         self.types = types | ||||||
|  | 
 | ||||||
|  |     def get_named_type(self, schema, type): | ||||||
|  |         name = type.name or schema.get_default_namedtype_name(type.default_name) | ||||||
|  |         return name, schema.T(type) | ||||||
|  | 
 | ||||||
|  |     def iter_types(self, schema): | ||||||
|  |         return map(partial(self.get_named_type, schema), self.types) | ||||||
|  | 
 | ||||||
|  |     def internal_type(self, schema): | ||||||
|  |         return OrderedDict(self.iter_types(schema)) | ||||||
|  | 
 | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.types) | ||||||
|  | 
 | ||||||
|  |     def __iter__(self): | ||||||
|  |         return iter(self.types) | ||||||
|  | 
 | ||||||
|  |     def __contains__(self, *args): | ||||||
|  |         return self.types.__contains__(*args) | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, *args): | ||||||
|  |         return self.types.__getitem__(*args) | ||||||
|  |  | ||||||
|  | @ -4,23 +4,22 @@ from functools import wraps | ||||||
| import six | import six | ||||||
| from graphql.core.type import GraphQLField, GraphQLInputObjectField | from graphql.core.type import GraphQLField, GraphQLInputObjectField | ||||||
| 
 | 
 | ||||||
| from ...utils import to_camel_case |  | ||||||
| from ..classtypes.base import FieldsClassType | from ..classtypes.base import FieldsClassType | ||||||
| from ..classtypes.inputobjecttype import InputObjectType | from ..classtypes.inputobjecttype import InputObjectType | ||||||
| from ..classtypes.mutation import Mutation | from ..classtypes.mutation import Mutation | ||||||
|  | from ..exceptions import SkipField | ||||||
| from .argument import ArgumentsGroup, snake_case_args | from .argument import ArgumentsGroup, snake_case_args | ||||||
| from .base import LazyType, MountType, OrderedType | from .base import GroupNamedType, LazyType, MountType, NamedType, OrderedType | ||||||
| from .definitions import NonNull | from .definitions import NonNull | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Field(OrderedType): | class Field(NamedType, OrderedType): | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|             self, type, description=None, args=None, name=None, resolver=None, |             self, type, description=None, args=None, name=None, resolver=None, | ||||||
|             required=False, default=None, *args_list, **kwargs): |             required=False, default=None, *args_list, **kwargs): | ||||||
|         _creation_counter = kwargs.pop('_creation_counter', None) |         _creation_counter = kwargs.pop('_creation_counter', None) | ||||||
|         super(Field, self).__init__(_creation_counter=_creation_counter) |         super(Field, self).__init__(name=name, _creation_counter=_creation_counter) | ||||||
|         self.name = name |  | ||||||
|         if isinstance(type, six.string_types): |         if isinstance(type, six.string_types): | ||||||
|             type = LazyType(type) |             type = LazyType(type) | ||||||
|         self.required = required |         self.required = required | ||||||
|  | @ -36,9 +35,8 @@ class Field(OrderedType): | ||||||
|         assert issubclass( |         assert issubclass( | ||||||
|             cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( |             cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( | ||||||
|             self, cls) |             self, cls) | ||||||
|         if not self.name: |  | ||||||
|             self.name = to_camel_case(attname) |  | ||||||
|         self.attname = attname |         self.attname = attname | ||||||
|  |         self.default_name = attname | ||||||
|         self.object_type = cls |         self.object_type = cls | ||||||
|         self.mount(cls) |         self.mount(cls) | ||||||
|         if isinstance(self.type, MountType): |         if isinstance(self.type, MountType): | ||||||
|  | @ -117,12 +115,11 @@ class Field(OrderedType): | ||||||
|         return hash((self.creation_counter, self.object_type)) |         return hash((self.creation_counter, self.object_type)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class InputField(OrderedType): | class InputField(NamedType, OrderedType): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, type, description=None, default=None, |     def __init__(self, type, description=None, default=None, | ||||||
|                  name=None, _creation_counter=None, required=False): |                  name=None, _creation_counter=None, required=False): | ||||||
|         super(InputField, self).__init__(_creation_counter=_creation_counter) |         super(InputField, self).__init__(_creation_counter=_creation_counter) | ||||||
|         self.name = name |  | ||||||
|         if required: |         if required: | ||||||
|             type = NonNull(type) |             type = NonNull(type) | ||||||
|         self.type = type |         self.type = type | ||||||
|  | @ -133,9 +130,8 @@ class InputField(OrderedType): | ||||||
|         assert issubclass( |         assert issubclass( | ||||||
|             cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( |             cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( | ||||||
|             self, cls) |             self, cls) | ||||||
|         if not self.name: |  | ||||||
|             self.name = to_camel_case(attname) |  | ||||||
|         self.attname = attname |         self.attname = attname | ||||||
|  |         self.default_name = attname | ||||||
|         self.object_type = cls |         self.object_type = cls | ||||||
|         self.mount(cls) |         self.mount(cls) | ||||||
|         if isinstance(self.type, MountType): |         if isinstance(self.type, MountType): | ||||||
|  | @ -146,3 +142,13 @@ class InputField(OrderedType): | ||||||
|         return GraphQLInputObjectField( |         return GraphQLInputObjectField( | ||||||
|             schema.T(self.type), |             schema.T(self.type), | ||||||
|             default_value=self.default, description=self.description) |             default_value=self.default, description=self.description) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class FieldsGroupType(GroupNamedType): | ||||||
|  | 
 | ||||||
|  |     def iter_types(self, schema): | ||||||
|  |         for field in sorted(self.types): | ||||||
|  |             try: | ||||||
|  |                 yield self.get_named_type(schema, field) | ||||||
|  |             except SkipField: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ from .base import MountedType | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ScalarType(MountedType): | class ScalarType(MountedType): | ||||||
|  | 
 | ||||||
|     def internal_type(self, schema): |     def internal_type(self, schema): | ||||||
|         return self._internal_type |         return self._internal_type | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -27,8 +27,8 @@ def test_to_arguments(): | ||||||
|         other_kwarg=String(), |         other_kwarg=String(), | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     assert [a.name for a in arguments] == [ |     assert [a.name or a.default_name for a in arguments] == [ | ||||||
|         'myArg', 'otherArg', 'myKwarg', 'otherKwarg'] |         'myArg', 'otherArg', 'my_kwarg', 'other_kwarg'] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_to_arguments_no_name(): | def test_to_arguments_no_name(): | ||||||
|  |  | ||||||
|  | @ -20,7 +20,7 @@ def test_field_internal_type(): | ||||||
|     schema = Schema(query=Query) |     schema = Schema(query=Query) | ||||||
| 
 | 
 | ||||||
|     type = schema.T(field) |     type = schema.T(field) | ||||||
|     assert field.name == 'myField' |     assert field.name is None | ||||||
|     assert field.attname == 'my_field' |     assert field.attname == 'my_field' | ||||||
|     assert isinstance(type, GraphQLField) |     assert isinstance(type, GraphQLField) | ||||||
|     assert type.description == 'My argument' |     assert type.description == 'My argument' | ||||||
|  | @ -98,9 +98,10 @@ def test_field_string_reference(): | ||||||
| 
 | 
 | ||||||
| def test_field_custom_arguments(): | def test_field_custom_arguments(): | ||||||
|     field = Field(None, name='my_customName', p=String()) |     field = Field(None, name='my_customName', p=String()) | ||||||
|  |     schema = Schema() | ||||||
| 
 | 
 | ||||||
|     args = field.arguments |     args = field.arguments | ||||||
|     assert 'p' in args |     assert 'p' in schema.T(args) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_inputfield_internal_type(): | def test_inputfield_internal_type(): | ||||||
|  | @ -115,7 +116,7 @@ def test_inputfield_internal_type(): | ||||||
|     schema = Schema(query=MyObjectType) |     schema = Schema(query=MyObjectType) | ||||||
| 
 | 
 | ||||||
|     type = schema.T(field) |     type = schema.T(field) | ||||||
|     assert field.name == 'myField' |     assert field.name is None | ||||||
|     assert field.attname == 'my_field' |     assert field.attname == 'my_field' | ||||||
|     assert isinstance(type, GraphQLInputObjectField) |     assert isinstance(type, GraphQLInputObjectField) | ||||||
|     assert type.description == 'My input field' |     assert type.description == 'My input field' | ||||||
|  |  | ||||||
							
								
								
									
										6
									
								
								graphene/plugins/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								graphene/plugins/__init__.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,6 @@ | ||||||
|  | from .base import Plugin, PluginManager | ||||||
|  | from .camel_case import CamelCase | ||||||
|  | 
 | ||||||
|  | __all__ = [ | ||||||
|  |     'Plugin', 'PluginManager', 'CamelCase' | ||||||
|  | ] | ||||||
							
								
								
									
										53
									
								
								graphene/plugins/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								graphene/plugins/base.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,53 @@ | ||||||
|  | from contextlib import contextmanager | ||||||
|  | from functools import partial, reduce | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Plugin(object): | ||||||
|  | 
 | ||||||
|  |     def contribute_to_schema(self, schema): | ||||||
|  |         self.schema = schema | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def apply_function(a, b): | ||||||
|  |     return b(a) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PluginManager(object): | ||||||
|  | 
 | ||||||
|  |     PLUGIN_FUNCTIONS = ('get_default_namedtype_name', ) | ||||||
|  | 
 | ||||||
|  |     def __init__(self, schema, plugins=[]): | ||||||
|  |         self.schema = schema | ||||||
|  |         self.plugins = [] | ||||||
|  |         for plugin in plugins: | ||||||
|  |             self.add_plugin(plugin) | ||||||
|  | 
 | ||||||
|  |     def add_plugin(self, plugin): | ||||||
|  |         if hasattr(plugin, 'contribute_to_schema'): | ||||||
|  |             plugin.contribute_to_schema(self.schema) | ||||||
|  |         self.plugins.append(plugin) | ||||||
|  | 
 | ||||||
|  |     def get_plugin_functions(self, function): | ||||||
|  |         for plugin in self.plugins: | ||||||
|  |             if not hasattr(plugin, function): | ||||||
|  |                 continue | ||||||
|  |             yield getattr(plugin, function) | ||||||
|  | 
 | ||||||
|  |     def __getattr__(self, name): | ||||||
|  |         functions = self.get_plugin_functions(name) | ||||||
|  |         return partial(reduce, apply_function, functions) | ||||||
|  | 
 | ||||||
|  |     def __contains__(self, name): | ||||||
|  |         return name in self.PLUGIN_FUNCTIONS | ||||||
|  | 
 | ||||||
|  |     @contextmanager | ||||||
|  |     def context_execution(self, **executor): | ||||||
|  |         contexts = [] | ||||||
|  |         functions = self.get_plugin_functions('context_execution') | ||||||
|  |         for f in functions: | ||||||
|  |             context = f(executor) | ||||||
|  |             executor = context.__enter__() | ||||||
|  |             contexts.append((context, executor)) | ||||||
|  |         yield executor | ||||||
|  |         for context, value in contexts[::-1]: | ||||||
|  |             context.__exit__(None, None, None) | ||||||
							
								
								
									
										7
									
								
								graphene/plugins/camel_case.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								graphene/plugins/camel_case.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,7 @@ | ||||||
|  | from ..utils import to_camel_case | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CamelCase(object): | ||||||
|  | 
 | ||||||
|  |     def get_default_namedtype_name(self, value): | ||||||
|  |         return to_camel_case(value) | ||||||
|  | @ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation) | ||||||
| 
 | 
 | ||||||
| def test_mutation_arguments(): | def test_mutation_arguments(): | ||||||
|     assert ChangeNumber.arguments |     assert ChangeNumber.arguments | ||||||
|     assert list(ChangeNumber.arguments) == ['input'] |     assert 'input' in schema.T(ChangeNumber.arguments) | ||||||
|     assert 'input' in ChangeNumber.arguments |  | ||||||
|     inner_type = ChangeNumber.input_type |     inner_type = ChangeNumber.input_type | ||||||
|     client_mutation_id_field = inner_type._meta.fields_map[ |     client_mutation_id_field = inner_type._meta.fields_map[ | ||||||
|         'client_mutation_id'] |         'client_mutation_id'] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user