diff --git a/.travis.yml b/.travis.yml index 93f4550f..3dbb00e0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -73,13 +73,20 @@ after_success: fi env: matrix: - - TEST_TYPE=build DJANGO_VERSION=1.8 - - TEST_TYPE=build DJANGO_VERSION=1.9 + - TEST_TYPE=build global: secure: SQC0eCWCWw8bZxbLE8vQn+UjJOp3Z1m779s9SMK3lCLwJxro/VCLBZ7hj4xsrq1MtcFO2U2Kqf068symw4Hr/0amYI3HFTCFiwXAC3PAKXeURca03eNO2heku+FtnQcOjBanExTsIBQRLDXMOaUkf3MIztpLJ4LHqMfUupKmw9YSB0v40jDbSN8khBnndFykmOnVVHznFp8USoN5F0CiPpnfEvHnJkaX76lNf7Kc9XNShBTTtJsnsHMhuYQeInt0vg9HSjoIYC38Tv2hmMj1myNdzyrHF+LgRjI6ceGi50ApAnGepXC/DNRhXROfECKez+LON/ZSqBGdJhUILqC8A4WmWmIjNcwitVFp3JGBqO7LULS0BI96EtSLe8rD1rkkdTbjivajkbykM1Q0Tnmg1adzGwLxRUbTq9tJQlTTkHBCuXIkpKb1mAtb/TY7A6BqfnPi2xTc/++qEawUG7ePhscdTj0IBrUfZsUNUYZqD8E8XbSWKIuS3SHE+cZ+s/kdAsm4q+FFAlpZKOYGxIkwvgyfu4/Plfol4b7X6iAP9J3r1Kv0DgBVFst5CXEwzZs19/g0CgokQbCXf1N+xeNnUELl6/fImaR3RKP22EaABoil4z8vzl4EqxqVoH1nfhE+WlpryXsuSaF/1R+WklR7aQ1FwoCk8V8HxM2zrj4tI8k= matrix: fast_finish: true include: + - python: '2.7' + env: DJANGO_VERSION=1.6 + - python: '2.7' + env: DJANGO_VERSION=1.7 + - python: '2.7' + env: DJANGO_VERSION=1.8 + - python: '2.7' + env: DJANGO_VERSION=1.9 - python: '2.7' env: TEST_TYPE=build_website - python: '2.7' diff --git a/graphene/__init__.py b/graphene/__init__.py index 88404d62..71066499 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -11,7 +11,7 @@ from .core import ( Interface, Mutation, Scalar, - BaseType, + InstanceType, LazyType, Argument, Field, @@ -51,7 +51,7 @@ __all__ = [ 'NonNull', 'signals', 'Schema', - 'BaseType', + 'InstanceType', 'LazyType', 'ObjectType', 'InputObjectType', diff --git a/graphene/contrib/django/compat.py b/graphene/contrib/django/compat.py new file mode 100644 index 00000000..a5b444c7 --- /dev/null +++ b/graphene/contrib/django/compat.py @@ -0,0 +1,15 @@ +from django.db import models + +try: + UUIDField = models.UUIDField +except AttributeError: + # Improved compatibility for Django 1.6 + class UUIDField(object): + pass + +try: + from django.db.models.related import RelatedObject +except: + # Improved compatibility for Django 1.6 + class RelatedObject(object): + pass diff --git a/graphene/contrib/django/converter.py b/graphene/contrib/django/converter.py index 0722643b..ef1265ac 100644 --- a/graphene/contrib/django/converter.py +++ b/graphene/contrib/django/converter.py @@ -1,17 +1,11 @@ from django.db import models -from .utils import import_single_dispatch from ...core.types.scalars import ID, Boolean, Float, Int, String +from .compat import RelatedObject, UUIDField +from .utils import get_related_model, import_single_dispatch singledispatch = import_single_dispatch() -try: - UUIDField = models.UUIDField -except AttributeError: - # Improved compatibility for Django 1.6 - class UUIDField(object): - pass - @singledispatch def convert_django_field(field): @@ -65,7 +59,15 @@ def convert_field_to_float(field): @convert_django_field.register(models.ManyToOneRel) def convert_field_to_list_or_connection(field): from .fields import DjangoModelField, ConnectionOrListField - model_field = DjangoModelField(field.related_model) + model_field = DjangoModelField(get_related_model(field)) + return ConnectionOrListField(model_field) + + +# For Django 1.6 +@convert_django_field.register(RelatedObject) +def convert_relatedfield_to_djangomodel(field): + from .fields import DjangoModelField, ConnectionOrListField + model_field = DjangoModelField(field.model) return ConnectionOrListField(model_field) @@ -73,4 +75,4 @@ def convert_field_to_list_or_connection(field): @convert_django_field.register(models.ForeignKey) def convert_field_to_djangomodel(field): from .fields import DjangoModelField - return DjangoModelField(field.related_model, description=field.help_text) + return DjangoModelField(get_related_model(field), description=field.help_text) diff --git a/graphene/contrib/django/debug/__init__.py b/graphene/contrib/django/debug/__init__.py new file mode 100644 index 00000000..4c76aeca --- /dev/null +++ b/graphene/contrib/django/debug/__init__.py @@ -0,0 +1,4 @@ +from .plugin import DjangoDebugPlugin +from .types import DjangoDebug + +__all__ = ['DjangoDebugPlugin', 'DjangoDebug'] diff --git a/graphene/contrib/django/debug/plugin.py b/graphene/contrib/django/debug/plugin.py new file mode 100644 index 00000000..70cd6741 --- /dev/null +++ b/graphene/contrib/django/debug/plugin.py @@ -0,0 +1,79 @@ +from contextlib import contextmanager + +from django.db import connections + +from ....core.schema import GraphQLSchema +from ....core.types import Field +from ....plugins import Plugin +from .sql.tracking import unwrap_cursor, wrap_cursor +from .sql.types import DjangoDebugSQL +from .types import DjangoDebug + + +class WrappedRoot(object): + + def __init__(self, root): + self._recorded = [] + self._root = root + + def record(self, **log): + self._recorded.append(DjangoDebugSQL(**log)) + + def debug(self): + return DjangoDebug(sql=self._recorded) + + +class WrapRoot(object): + + @property + def _root(self): + return self._wrapped_root.root + + @_root.setter + def _root(self, value): + self._wrapped_root = value + + def resolve_debug(self, args, info): + return self._wrapped_root.debug() + + +def debug_objecttype(objecttype): + return type( + 'Debug{}'.format(objecttype._meta.type_name), + (WrapRoot, objecttype), + {'debug': Field(DjangoDebug, name='__debug')}) + + +class DjangoDebugPlugin(Plugin): + + def enable_instrumentation(self, wrapped_root): + # This is thread-safe because database connections are thread-local. + for connection in connections.all(): + wrap_cursor(connection, wrapped_root) + + def disable_instrumentation(self): + for connection in connections.all(): + unwrap_cursor(connection) + + def wrap_schema(self, schema_type): + query = schema_type._query + if query: + class_type = self.schema.objecttype(schema_type.get_query_type()) + assert class_type, 'The query in schema is not constructed with graphene' + _type = debug_objecttype(class_type) + self.schema.register(_type, force=True) + return GraphQLSchema( + self.schema, + self.schema.T(_type), + schema_type.get_mutation_type(), + schema_type.get_subscription_type() + ) + return schema_type + + @contextmanager + def context_execution(self, executor): + executor['root'] = WrappedRoot(root=executor['root']) + executor['schema'] = self.wrap_schema(executor['schema']) + self.enable_instrumentation(executor['root']) + yield executor + self.disable_instrumentation() diff --git a/graphene/contrib/django/tests/filter/__init__.py b/graphene/contrib/django/debug/sql/__init__.py similarity index 100% rename from graphene/contrib/django/tests/filter/__init__.py rename to graphene/contrib/django/debug/sql/__init__.py diff --git a/graphene/contrib/django/debug/sql/tracking.py b/graphene/contrib/django/debug/sql/tracking.py new file mode 100644 index 00000000..8ed40492 --- /dev/null +++ b/graphene/contrib/django/debug/sql/tracking.py @@ -0,0 +1,165 @@ +# Code obtained from django-debug-toolbar sql panel tracking +from __future__ import absolute_import, unicode_literals + +import json +from threading import local +from time import time + +from django.utils import six +from django.utils.encoding import force_text + + +class SQLQueryTriggered(Exception): + """Thrown when template panel triggers a query""" + + +class ThreadLocalState(local): + + def __init__(self): + self.enabled = True + + @property + def Wrapper(self): + if self.enabled: + return NormalCursorWrapper + return ExceptionCursorWrapper + + def recording(self, v): + self.enabled = v + + +state = ThreadLocalState() +recording = state.recording # export function + + +def wrap_cursor(connection, panel): + if not hasattr(connection, '_djdt_cursor'): + connection._djdt_cursor = connection.cursor + + def cursor(): + return state.Wrapper(connection._djdt_cursor(), connection, panel) + + connection.cursor = cursor + return cursor + + +def unwrap_cursor(connection): + if hasattr(connection, '_djdt_cursor'): + del connection._djdt_cursor + del connection.cursor + + +class ExceptionCursorWrapper(object): + """ + Wraps a cursor and raises an exception on any operation. + Used in Templates panel. + """ + + def __init__(self, cursor, db, logger): + pass + + def __getattr__(self, attr): + raise SQLQueryTriggered() + + +class NormalCursorWrapper(object): + """ + Wraps a cursor and logs queries. + """ + + def __init__(self, cursor, db, logger): + self.cursor = cursor + # Instance of a BaseDatabaseWrapper subclass + self.db = db + # logger must implement a ``record`` method + self.logger = logger + + def _quote_expr(self, element): + if isinstance(element, six.string_types): + return "'%s'" % force_text(element).replace("'", "''") + else: + return repr(element) + + def _quote_params(self, params): + if not params: + return params + if isinstance(params, dict): + return dict((key, self._quote_expr(value)) + for key, value in params.items()) + return list(map(self._quote_expr, params)) + + def _decode(self, param): + try: + return force_text(param, strings_only=True) + except UnicodeDecodeError: + return '(encoded string)' + + def _record(self, method, sql, params): + start_time = time() + try: + return method(sql, params) + finally: + stop_time = time() + duration = (stop_time - start_time) + _params = '' + try: + _params = json.dumps(list(map(self._decode, params))) + except Exception: + pass # object not JSON serializable + + alias = getattr(self.db, 'alias', 'default') + conn = self.db.connection + vendor = getattr(conn, 'vendor', 'unknown') + + params = { + 'vendor': vendor, + 'alias': alias, + 'sql': self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params)), + 'duration': duration, + 'raw_sql': sql, + 'params': _params, + 'start_time': start_time, + 'stop_time': stop_time, + 'is_slow': duration > 10, + 'is_select': sql.lower().strip().startswith('select'), + } + + if vendor == 'postgresql': + # If an erroneous query was ran on the connection, it might + # be in a state where checking isolation_level raises an + # exception. + try: + iso_level = conn.isolation_level + except conn.InternalError: + iso_level = 'unknown' + params.update({ + 'trans_id': self.logger.get_transaction_id(alias), + 'trans_status': conn.get_transaction_status(), + 'iso_level': iso_level, + 'encoding': conn.encoding, + }) + + # We keep `sql` to maintain backwards compatibility + self.logger.record(**params) + + def callproc(self, procname, params=()): + return self._record(self.cursor.callproc, procname, params) + + def execute(self, sql, params=()): + return self._record(self.cursor.execute, sql, params) + + def executemany(self, sql, param_list): + return self._record(self.cursor.executemany, sql, param_list) + + def __getattr__(self, attr): + return getattr(self.cursor, attr) + + def __iter__(self): + return iter(self.cursor) + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() diff --git a/graphene/contrib/django/debug/sql/types.py b/graphene/contrib/django/debug/sql/types.py new file mode 100644 index 00000000..995aeaa2 --- /dev/null +++ b/graphene/contrib/django/debug/sql/types.py @@ -0,0 +1,19 @@ +from .....core import Boolean, Float, ObjectType, String + + +class DjangoDebugSQL(ObjectType): + vendor = String() + alias = String() + sql = String() + duration = Float() + raw_sql = String() + params = String() + start_time = Float() + stop_time = Float() + is_slow = Boolean() + is_select = Boolean() + + trans_id = String() + trans_status = String() + iso_level = String() + encoding = String() diff --git a/graphene/contrib/django/debug/tests/__init__.py b/graphene/contrib/django/debug/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/debug/tests/test_query.py b/graphene/contrib/django/debug/tests/test_query.py new file mode 100644 index 00000000..4df26e4f --- /dev/null +++ b/graphene/contrib/django/debug/tests/test_query.py @@ -0,0 +1,70 @@ +import pytest + +import graphene +from graphene.contrib.django import DjangoObjectType + +from ...tests.models import Reporter +from ..plugin import DjangoDebugPlugin + +# from examples.starwars_django.models import Character + +pytestmark = pytest.mark.django_db + + +def test_should_query_well(): + r1 = Reporter(last_name='ABA') + r1.save() + r2 = Reporter(last_name='Griffin') + r2.save() + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + all_reporters = ReporterType.List() + + def resolve_all_reporters(self, *args, **kwargs): + return Reporter.objects.all() + + def resolve_reporter(self, *args, **kwargs): + return Reporter.objects.first() + + query = ''' + query ReporterQuery { + reporter { + lastName + } + allReporters { + lastName + } + __debug { + sql { + rawSql + } + } + } + ''' + expected = { + 'reporter': { + 'lastName': 'ABA', + }, + 'allReporters': [{ + 'lastName': 'ABA', + }, { + 'lastName': 'Griffin', + }], + '__debug': { + 'sql': [{ + 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) + }, { + 'rawSql': str(Reporter.objects.all().query) + }] + } + } + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene/contrib/django/debug/types.py b/graphene/contrib/django/debug/types.py new file mode 100644 index 00000000..bceb54b0 --- /dev/null +++ b/graphene/contrib/django/debug/types.py @@ -0,0 +1,7 @@ +from ....core.classtypes.objecttype import ObjectType +from ....core.types import Field +from .sql.types import DjangoDebugSQL + + +class DjangoDebug(ObjectType): + sql = Field(DjangoDebugSQL.List()) diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 76a85580..d9d6f3da 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -1,12 +1,13 @@ import warnings -from .utils import get_type_for_model from ...core.exceptions import SkipField from ...core.fields import Field from ...core.types.base import FieldType from ...core.types.definitions import List from ...relay import ConnectionField from ...relay.utils import is_node +from .filter.fields import DjangoFilterConnectionField +from .utils import get_type_for_model class DjangoConnectionField(ConnectionField): @@ -20,7 +21,6 @@ class DjangoConnectionField(ConnectionField): class ConnectionOrListField(Field): - connection_field_class = ConnectionField def internal_type(self, schema): model_field = self.type @@ -28,7 +28,10 @@ class ConnectionOrListField(Field): if not field_object_type: raise SkipField() if is_node(field_object_type): - field = self.connection_field_class(field_object_type) + if field_object_type._meta.filter_fields: + field = DjangoFilterConnectionField(field_object_type) + else: + field = ConnectionField(field_object_type) else: field = Field(List(field_object_type)) field.contribute_to_class(self.object_type, self.attname) diff --git a/graphene/contrib/django/filter/fields.py b/graphene/contrib/django/filter/fields.py index 012ae00a..43196f6e 100644 --- a/graphene/contrib/django/filter/fields.py +++ b/graphene/contrib/django/filter/fields.py @@ -1,6 +1,6 @@ -from graphene.relay import ConnectionField from graphene.contrib.django.filter.resolvers import FilterConnectionResolver from graphene.contrib.django.utils import get_filtering_args_from_filterset +from graphene.relay import ConnectionField class DjangoFilterConnectionField(ConnectionField): diff --git a/graphene/contrib/django/filter/filterset.py b/graphene/contrib/django/filter/filterset.py index 4755eac2..3ecd9680 100644 --- a/graphene/contrib/django/filter/filterset.py +++ b/graphene/contrib/django/filter/filterset.py @@ -2,11 +2,12 @@ import six from django.conf import settings from django.db import models from django.utils.text import capfirst -from django_filters import Filter, MultipleChoiceFilter -from django_filters.filterset import FilterSetMetaclass, FilterSet -from graphql_relay.node.node import from_global_id -from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField +from django_filters import Filter, MultipleChoiceFilter +from django_filters.filterset import FilterSet, FilterSetMetaclass +from graphene.contrib.django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) +from graphql_relay.node.node import from_global_id class GlobalIDFilter(Filter): @@ -45,6 +46,7 @@ GRAPHENE_FILTER_SET_OVERRIDES = { class GrapheneFilterSetMetaclass(FilterSetMetaclass): + def __new__(cls, name, bases, attrs): new_class = super(GrapheneFilterSetMetaclass, cls).__new__(cls, name, bases, attrs) # Customise the filter_overrides for Graphene @@ -84,7 +86,6 @@ class GrapheneFilterSet(six.with_metaclass(GrapheneFilterSetMetaclass, GrapheneF DjangoFilterConnectionField will wrap FilterSets with this class as necessary """ - pass def setup_filterset(filterset_class): diff --git a/graphene/contrib/django/filter/resolvers.py b/graphene/contrib/django/filter/resolvers.py index c2204d6c..76b3e7ad 100644 --- a/graphene/contrib/django/filter/resolvers.py +++ b/graphene/contrib/django/filter/resolvers.py @@ -1,6 +1,7 @@ from django.core.exceptions import ImproperlyConfigured -from graphene.contrib.django.filter.filterset import setup_filterset, custom_filterset_factory +from graphene.contrib.django.filter.filterset import (custom_filterset_factory, + setup_filterset) from graphene.contrib.django.resolvers import BaseQuerySetConnectionResolver diff --git a/graphene/contrib/django/filter/tests/__init__.py b/graphene/contrib/django/filter/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/tests/filter/filters.py b/graphene/contrib/django/filter/tests/filters.py similarity index 83% rename from graphene/contrib/django/tests/filter/filters.py rename to graphene/contrib/django/filter/tests/filters.py index 4549a83e..94c0dffe 100644 --- a/graphene/contrib/django/tests/filter/filters.py +++ b/graphene/contrib/django/filter/tests/filters.py @@ -1,7 +1,5 @@ import django_filters - -from graphene.contrib.django.tests.models import Reporter -from graphene.contrib.django.tests.models import Article, Pet +from graphene.contrib.django.tests.models import Article, Pet, Reporter class ArticleFilter(django_filters.FilterSet): diff --git a/graphene/contrib/django/tests/filter/test_fields.py b/graphene/contrib/django/filter/tests/test_fields.py similarity index 95% rename from graphene/contrib/django/tests/filter/test_fields.py rename to graphene/contrib/django/filter/tests/test_fields.py index 33cc2421..45c1f0d0 100644 --- a/graphene/contrib/django/tests/filter/test_fields.py +++ b/graphene/contrib/django/filter/tests/test_fields.py @@ -1,20 +1,19 @@ import pytest from graphene import ObjectType, Schema +from graphene.contrib.django import DjangoNode +from graphene.contrib.django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) +from graphene.contrib.django.tests.models import Article, Pet, Reporter from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene.relay import NodeField - -from graphene.contrib.django import DjangoNode -from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from graphene.contrib.django.tests.models import Article, Pet, Reporter - pytestmark = [] if DJANGO_FILTER_INSTALLED: import django_filters from graphene.contrib.django.filter import (GlobalIDFilter, DjangoFilterConnectionField, GlobalIDMultipleChoiceFilter) - from graphene.contrib.django.tests.filter.filters import ArticleFilter, PetFilter + from graphene.contrib.django.filter.tests.filters import ArticleFilter, PetFilter else: pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed')) @@ -22,41 +21,46 @@ pytestmark.append(pytest.mark.django_db) class ArticleNode(DjangoNode): + class Meta: model = Article class ReporterNode(DjangoNode): + class Meta: model = Reporter class PetNode(DjangoNode): + class Meta: model = Pet +schema = Schema() + def assert_arguments(field, *arguments): ignore = ('after', 'before', 'first', 'last', 'orderBy') actual = [ name - for name in field.arguments.arguments.keys() + for name in schema.T(field.arguments) if name not in ignore and not name.startswith('_') ] assert set(arguments) == set(actual), \ 'Expected arguments ({}) did not match actual ({})'.format( arguments, actual - ) + ) def assert_orderable(field): - assert 'orderBy' in field.arguments.arguments.keys(), \ + assert 'orderBy' in schema.T(field.arguments), \ 'Field cannot be ordered' def assert_not_orderable(field): - assert 'orderBy' not in field.arguments.arguments.keys(), \ + assert 'orderBy' not in schema.T(field.arguments), \ 'Field can be ordered' @@ -117,6 +121,7 @@ def test_filter_shortcut_filterset_extra_meta(): def test_filter_filterset_information_on_meta(): class ReporterFilterNode(DjangoNode): + class Meta: model = Reporter filter_fields = ['first_name', 'articles'] @@ -129,12 +134,14 @@ def test_filter_filterset_information_on_meta(): def test_filter_filterset_information_on_meta_related(): class ReporterFilterNode(DjangoNode): + class Meta: model = Reporter filter_fields = ['first_name', 'articles'] filter_order_by = True class ArticleFilterNode(DjangoNode): + class Meta: model = Article filter_fields = ['headline', 'reporter'] @@ -163,6 +170,7 @@ def test_global_id_field_implicit(): def test_global_id_field_explicit(): class ArticleIdFilter(django_filters.FilterSet): + class Meta: model = Article fields = ['id'] @@ -192,6 +200,7 @@ def test_global_id_multiple_field_implicit(): def test_global_id_multiple_field_explicit(): class ReporterPetsFilter(django_filters.FilterSet): + class Meta: model = Reporter fields = ['pets'] @@ -213,6 +222,7 @@ def test_global_id_multiple_field_implicit_reverse(): def test_global_id_multiple_field_explicit_reverse(): class ReporterPetsFilter(django_filters.FilterSet): + class Meta: model = Reporter fields = ['articles'] diff --git a/graphene/contrib/django/tests/filter/test_resolvers.py b/graphene/contrib/django/filter/tests/test_resolvers.py similarity index 91% rename from graphene/contrib/django/tests/filter/test_resolvers.py rename to graphene/contrib/django/filter/tests/test_resolvers.py index a336cddf..670e87c8 100644 --- a/graphene/contrib/django/tests/filter/test_resolvers.py +++ b/graphene/contrib/django/filter/tests/test_resolvers.py @@ -1,17 +1,17 @@ import pytest from django.core.exceptions import ImproperlyConfigured +from graphene.contrib.django.tests.models import Article, Reporter +from graphene.contrib.django.tests.test_resolvers import (ArticleNode, + ReporterNode) from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED if DJANGO_FILTER_INSTALLED: from graphene.contrib.django.filter.resolvers import FilterConnectionResolver - from graphene.contrib.django.tests.filter.filters import ReporterFilter, ArticleFilter + from graphene.contrib.django.filter.tests.filters import ArticleFilter, ReporterFilter else: pytestmark = pytest.mark.skipif(True, reason='django_filters not installed') -from graphene.contrib.django.tests.models import Reporter, Article -from graphene.contrib.django.tests.test_resolvers import ReporterNode, ArticleNode - def test_filter_get_filterset_class_explicit(): reporter = Reporter(id=1, first_name='Cookie Monster') diff --git a/graphene/contrib/django/form_converter.py b/graphene/contrib/django/form_converter.py index 826c8c69..de2a40d8 100644 --- a/graphene/contrib/django/form_converter.py +++ b/graphene/contrib/django/form_converter.py @@ -1,9 +1,12 @@ from django import forms from django.forms.fields import BaseTemporalField -from graphene import String, Int, Boolean, Float, ID -from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField + +from graphene import ID, Boolean, Float, Int, String +from graphene.contrib.django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) from graphene.contrib.django.utils import import_single_dispatch from graphene.core.types.definitions import List + singledispatch = import_single_dispatch() try: diff --git a/graphene/contrib/django/forms.py b/graphene/contrib/django/forms.py index f971897b..88f1665e 100644 --- a/graphene/contrib/django/forms.py +++ b/graphene/contrib/django/forms.py @@ -1,7 +1,7 @@ import binascii from django.core.exceptions import ValidationError -from django.forms import Field, IntegerField, CharField, MultipleChoiceField +from django.forms import CharField, Field, IntegerField, MultipleChoiceField from django.utils.translation import ugettext_lazy as _ from graphql_relay import from_global_id diff --git a/graphene/contrib/django/options.py b/graphene/contrib/django/options.py index 55868dd7..dbd88aca 100644 --- a/graphene/contrib/django/options.py +++ b/graphene/contrib/django/options.py @@ -1,7 +1,7 @@ -from .utils import DJANGO_FILTER_INSTALLED from ...core.classtypes.objecttype import ObjectTypeOptions from ...relay.types import Node from ...relay.utils import is_node +from .utils import DJANGO_FILTER_INSTALLED VALID_ATTRS = ('model', 'only_fields', 'exclude_fields') diff --git a/graphene/contrib/django/tests/test_converter.py b/graphene/contrib/django/tests/test_converter.py index dcbb3e30..3a02b03a 100644 --- a/graphene/contrib/django/tests/test_converter.py +++ b/graphene/contrib/django/tests/test_converter.py @@ -9,8 +9,8 @@ from graphene.contrib.django.fields import (ConnectionOrListField, from .models import Article, Reporter -def assert_conversion(django_field, graphene_field, *args): - field = django_field(*args, help_text='Custom Help Text') +def assert_conversion(django_field, graphene_field, *args, **kwargs): + field = django_field(help_text='Custom Help Text', *args, **kwargs) graphene_type = convert_django_field(field) assert isinstance(graphene_type, graphene_field) field = graphene_type.as_field() @@ -49,7 +49,7 @@ def test_should_url_convert_string(): def test_should_auto_convert_id(): - assert_conversion(models.AutoField, graphene.ID) + assert_conversion(models.AutoField, graphene.ID, primary_key=True) def test_should_positive_integer_convert_int(): diff --git a/graphene/contrib/django/tests/test_form_converter.py b/graphene/contrib/django/tests/test_form_converter.py index 7492fc51..44d9bec3 100644 --- a/graphene/contrib/django/tests/test_form_converter.py +++ b/graphene/contrib/django/tests/test_form_converter.py @@ -1,10 +1,9 @@ from django import forms -from graphene.core.types import List, ID from py.test import raises import graphene from graphene.contrib.django.form_converter import convert_form_field - +from graphene.core.types import ID, List from .models import Reporter diff --git a/graphene/contrib/django/tests/test_query.py b/graphene/contrib/django/tests/test_query.py index 090c8695..460c8e22 100644 --- a/graphene/contrib/django/tests/test_query.py +++ b/graphene/contrib/django/tests/test_query.py @@ -1,3 +1,4 @@ +import pytest from py.test import raises import graphene @@ -6,6 +7,8 @@ from graphene.contrib.django import DjangoNode, DjangoObjectType from .models import Article, Reporter +pytestmark = pytest.mark.django_db + def test_should_query_only_fields(): with raises(Exception): diff --git a/graphene/contrib/django/tests/test_resolvers.py b/graphene/contrib/django/tests/test_resolvers.py index fe617666..db1610c9 100644 --- a/graphene/contrib/django/tests/test_resolvers.py +++ b/graphene/contrib/django/tests/test_resolvers.py @@ -3,15 +3,17 @@ from django.db.models.query import QuerySet from graphene.contrib.django import DjangoNode from graphene.contrib.django.resolvers import SimpleQuerySetConnectionResolver -from graphene.contrib.django.tests.models import Reporter, Article +from graphene.contrib.django.tests.models import Article, Reporter class ReporterNode(DjangoNode): + class Meta: model = Reporter class ArticleNode(DjangoNode): + class Meta: model = Article @@ -34,7 +36,7 @@ def test_simple_get_manager_all(): reporter = Reporter(id=1, first_name='Cookie Monster') resolver = SimpleQuerySetConnectionResolver(ReporterNode) resolver(inst=reporter, args={}, info=None) - assert type(resolver.get_manager()) == Manager, 'Resolver did not return a Manager' + assert isinstance(resolver.get_manager(), Manager), 'Resolver did not return a Manager' def test_simple_filter(): diff --git a/graphene/contrib/django/tests/test_urls.py b/graphene/contrib/django/tests/test_urls.py index 409471d4..9d38980e 100644 --- a/graphene/contrib/django/tests/test_urls.py +++ b/graphene/contrib/django/tests/test_urls.py @@ -29,7 +29,15 @@ class Human(DjangoNode): def get_node(self, id): pass -schema = Schema(query=Human) + +class Query(graphene.ObjectType): + human = graphene.Field(Human) + + def resolve_human(self, args, info): + return Human() + + +schema = Schema(query=Query) urlpatterns = [ diff --git a/graphene/contrib/django/tests/test_views.py b/graphene/contrib/django/tests/test_views.py index f82be99f..b4e1b367 100644 --- a/graphene/contrib/django/tests/test_views.py +++ b/graphene/contrib/django/tests/test_views.py @@ -7,11 +7,13 @@ def format_response(response): def test_client_get_good_query(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' - response = client.get('/graphql', {'query': '{ headline }'}) + response = client.get('/graphql', {'query': '{ human { headline } }'}) json_response = format_response(response) expected_json = { 'data': { - 'headline': None + 'human': { + 'headline': None + } } } assert json_response == expected_json @@ -19,20 +21,22 @@ def test_client_get_good_query(settings, client): def test_client_get_good_query_with_raise(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' - response = client.get('/graphql', {'query': '{ raises }'}) + response = client.get('/graphql', {'query': '{ human { raises } }'}) json_response = format_response(response) assert json_response['errors'][0]['message'] == 'This field should raise exception' - assert json_response['data']['raises'] is None + assert json_response['data']['human']['raises'] is None def test_client_post_good_query_json(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' response = client.post( - '/graphql', json.dumps({'query': '{ headline }'}), 'application/json') + '/graphql', json.dumps({'query': '{ human { headline } }'}), 'application/json') json_response = format_response(response) expected_json = { 'data': { - 'headline': None + 'human': { + 'headline': None + } } } assert json_response == expected_json @@ -41,11 +45,13 @@ def test_client_post_good_query_json(settings, client): def test_client_post_good_query_graphql(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' response = client.post( - '/graphql', '{ headline }', 'application/graphql') + '/graphql', '{ human { headline } }', 'application/graphql') json_response = format_response(response) expected_json = { 'data': { - 'headline': None + 'human': { + 'headline': None + } } } assert json_response == expected_json diff --git a/graphene/contrib/django/types.py b/graphene/contrib/django/types.py index e25130df..5b68ebbb 100644 --- a/graphene/contrib/django/types.py +++ b/graphene/contrib/django/types.py @@ -5,7 +5,6 @@ from django.db import models from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta from ...relay.types import Connection, Node, NodeMeta -from .utils import DJANGO_FILTER_INSTALLED from .converter import convert_django_field from .options import DjangoOptions from .utils import get_reverse_fields, maybe_queryset @@ -30,12 +29,9 @@ class DjangoObjectTypeMeta(ObjectTypeMeta): # We skip this field if we specify only_fields and is not # in there. Or when we exclude this field in exclude_fields continue - converted_field = cls.convert_django_field(field) + converted_field = convert_django_field(field) cls.add_to_class(field.name, converted_field) - def convert_django_field(cls, field): - return convert_django_field(field) - def construct(cls, *args, **kwargs): cls = super(DjangoObjectTypeMeta, cls).construct(*args, **kwargs) if not cls._meta.abstract: @@ -50,15 +46,6 @@ class DjangoObjectTypeMeta(ObjectTypeMeta): return cls -class DjangoFilterObjectTypeMeta(ObjectTypeMeta): - - def convert_django_field(cls, field): - from graphene.contrib.django.filter import DjangoFilterConnectionField - field = super(DjangoFilterObjectTypeMeta, cls).convert_django_field(field) - field.connection_field_class = DjangoFilterConnectionField - return field - - class InstanceObjectType(ObjectType): class Meta: @@ -102,13 +89,7 @@ class DjangoConnection(Connection): return super(DjangoConnection, cls).from_list(iterable, *args, **kwargs) -django_node_meta_bases = (DjangoObjectTypeMeta, NodeMeta) -# Only include filter functionality if available -if DJANGO_FILTER_INSTALLED: - django_node_meta_bases = (DjangoFilterObjectTypeMeta,) + django_node_meta_bases - - -class DjangoNodeMeta(*django_node_meta_bases): +class DjangoNodeMeta(DjangoObjectTypeMeta, NodeMeta): pass diff --git a/graphene/contrib/django/utils.py b/graphene/contrib/django/utils.py index 38bf7546..76f4477c 100644 --- a/graphene/contrib/django/utils.py +++ b/graphene/contrib/django/utils.py @@ -3,9 +3,10 @@ from django.db import models from django.db.models.manager import Manager from django.db.models.query import QuerySet +from graphene import Argument, String from graphene.utils import LazyList -from graphene import Argument, String +from .compat import RelatedObject try: import django_filters # noqa @@ -29,7 +30,12 @@ def get_reverse_fields(model): # Django =>1.9 uses 'rel', django <1.9 uses 'related' related = getattr(attr, 'rel', None) or \ getattr(attr, 'related', None) - if isinstance(related, models.ManyToOneRel): + if isinstance(related, RelatedObject): + # Hack for making it compatible with Django 1.6 + new_related = RelatedObject(related.parent_model, related.model, related.field) + new_related.name = name + yield new_related + elif isinstance(related, models.ManyToOneRel): yield related @@ -66,10 +72,17 @@ def get_filtering_args_from_filterset(filterset_class, type): # Also add the 'order_by' field if filterset_class._meta.order_by: - args[filterset_class.order_by_field] = Argument(String) + args[filterset_class.order_by_field] = Argument(String()) return args +def get_related_model(field): + if hasattr(field, 'rel'): + # Django 1.6, 1.7 + return field.rel.to + return field.related_model + + def import_single_dispatch(): try: from functools import singledispatch diff --git a/graphene/contrib/django/views.py b/graphene/contrib/django/views.py index ad245d72..9a4bd96e 100644 --- a/graphene/contrib/django/views.py +++ b/graphene/contrib/django/views.py @@ -12,5 +12,5 @@ class GraphQLView(BaseGraphQLView): **kwargs ) - def get_root_value(self, request): - return self.graphene_schema.query(super(GraphQLView, self).get_root_value(request)) + def execute(self, *args, **kwargs): + return self.graphene_schema.execute(*args, **kwargs) diff --git a/graphene/core/__init__.py b/graphene/core/__init__.py index d27a72bb..9e8c7108 100644 --- a/graphene/core/__init__.py +++ b/graphene/core/__init__.py @@ -11,7 +11,7 @@ from .classtypes import ( ) from .types import ( - BaseType, + InstanceType, LazyType, Argument, Field, @@ -35,7 +35,7 @@ __all__ = [ 'List', 'NonNull', 'Schema', - 'BaseType', + 'InstanceType', 'LazyType', 'ObjectType', 'InputObjectType', diff --git a/graphene/core/classtypes/base.py b/graphene/core/classtypes/base.py index 4f2de009..cabb909a 100644 --- a/graphene/core/classtypes/base.py +++ b/graphene/core/classtypes/base.py @@ -1,11 +1,10 @@ import copy import inspect -from functools import partial from collections import OrderedDict +from functools import partial import six -from ..exceptions import SkipField from .options import Options @@ -82,13 +81,18 @@ class FieldsOptions(Options): def fields_map(self): return OrderedDict([(f.attname, f) for f in self.fields]) + @property + def fields_group_type(self): + from ..types.field import FieldsGroupType + return FieldsGroupType(*self.local_fields) + class FieldsClassTypeMeta(ClassTypeMeta): options_class = FieldsOptions def extend_fields(cls, bases): new_fields = cls._meta.local_fields - field_names = {f.name: f for f in new_fields} + field_names = {f.attname: f for f in new_fields} for base in bases: if not isinstance(base, FieldsClassTypeMeta): @@ -96,17 +100,17 @@ class FieldsClassTypeMeta(ClassTypeMeta): parent_fields = base._meta.local_fields for field in parent_fields: - if field.name in field_names and field.type.__class__ != field_names[ - field.name].type.__class__: + if field.attname in field_names and field.type.__class__ != field_names[ + field.attname].type.__class__: raise Exception( 'Local field %r in class %r (%r) clashes ' 'with field with similar name from ' 'Interface %s (%r)' % ( - field.name, + field.attname, cls.__name__, field.__class__, base.__name__, - field_names[field.name].__class__) + field_names[field.attname].__class__) ) new_field = copy.copy(field) cls.add_to_class(field.attname, new_field) @@ -124,11 +128,4 @@ class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)): @classmethod def fields_internal_types(cls, schema): - fields = [] - for field in cls._meta.fields: - try: - fields.append((field.name, schema.T(field))) - except SkipField: - continue - - return OrderedDict(fields) + return schema.T(cls._meta.fields_group_type) diff --git a/graphene/core/classtypes/tests/test_mutation.py b/graphene/core/classtypes/tests/test_mutation.py index ac32585e..85dd2368 100644 --- a/graphene/core/classtypes/tests/test_mutation.py +++ b/graphene/core/classtypes/tests/test_mutation.py @@ -24,4 +24,4 @@ def test_mutation(): assert list(object_type.get_fields().keys()) == ['name'] assert MyMutation._meta.fields_map['name'].object_type == MyMutation assert isinstance(MyMutation.arguments, ArgumentsGroup) - assert 'argName' in MyMutation.arguments + assert 'argName' in schema.T(MyMutation.arguments) diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 1b0ce8f9..c8695317 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -10,8 +10,9 @@ from graphql.core.utils.schema_printer import print_schema from graphene import signals +from ..plugins import CamelCase, PluginManager from .classtypes.base import ClassType -from .types.base import BaseType +from .types.base import InstanceType class GraphQLSchema(_GraphQLSchema): @@ -25,7 +26,7 @@ class Schema(object): _executor = None def __init__(self, query=None, mutation=None, subscription=None, - name='Schema', executor=None): + name='Schema', executor=None, plugins=None, auto_camelcase=True): self._types_names = {} self._types = {} self.mutation = mutation @@ -33,27 +34,34 @@ class Schema(object): self.subscription = subscription self.name = name self.executor = executor + plugins = plugins or [] + if auto_camelcase: + plugins.append(CamelCase()) + self.plugins = PluginManager(self, plugins) signals.init_schema.send(self) def __repr__(self): return '' % (str(self.name), hash(self)) - def T(self, object_type): - if not object_type: + def __getattr__(self, name): + if name in self.plugins: + return getattr(self.plugins, name) + return super(Schema, self).__getattr__(name) + + def T(self, _type): + if not _type: return - if inspect.isclass(object_type) and issubclass( - object_type, (BaseType, ClassType)) or isinstance( - object_type, BaseType): - if object_type not in self._types: - internal_type = object_type.internal_type(self) - self._types[object_type] = internal_type - is_objecttype = inspect.isclass( - object_type) and issubclass(object_type, ClassType) - if is_objecttype: - self.register(object_type) - return self._types[object_type] + is_classtype = inspect.isclass(_type) and issubclass(_type, ClassType) + is_instancetype = isinstance(_type, InstanceType) + if is_classtype or is_instancetype: + if _type not in self._types: + internal_type = _type.internal_type(self) + self._types[_type] = internal_type + if is_classtype: + self.register(_type) + return self._types[_type] else: - return object_type + return _type @property def executor(self): @@ -76,9 +84,9 @@ class Schema(object): mutation=self.T(self.mutation), subscription=self.T(self.subscription)) - def register(self, object_type): + def register(self, object_type, force=False): type_name = object_type._meta.type_name - registered_object_type = self._types_names.get(type_name, None) + registered_object_type = not force and self._types_names.get(type_name, None) if registered_object_type: assert registered_object_type == object_type, 'Type {} already registered with other object type'.format( type_name) @@ -110,17 +118,10 @@ class Schema(object): def types(self): return self._types_names - def execute(self, request='', root=None, vars=None, - operation_name=None, **kwargs): - root = root or object() - return self.executor.execute( - self.schema, - request, - root=root, - args=vars, - operation_name=operation_name, - **kwargs - ) + def execute(self, request='', root=None, args=None, **kwargs): + kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema) + with self.plugins.context_execution(**kwargs) as execute_kwargs: + return self.executor.execute(**execute_kwargs) def introspect(self): return self.execute(introspection_query).data diff --git a/graphene/core/tests/test_old_fields.py b/graphene/core/tests/test_old_fields.py index 95bf9aab..3f24aedf 100644 --- a/graphene/core/tests/test_old_fields.py +++ b/graphene/core/tests/test_old_fields.py @@ -34,10 +34,11 @@ def test_field_type(): assert schema.T(f).type == GraphQLString -def test_field_name_automatic_camelcase(): +def test_field_name(): f = Field(GraphQLString) f.contribute_to_class(MyOt, 'field_name') - assert f.name == 'fieldName' + assert f.name is None + assert f.attname == 'field_name' def test_field_name_use_name_if_exists(): diff --git a/graphene/core/types/__init__.py b/graphene/core/types/__init__.py index 9260476c..51512ec4 100644 --- a/graphene/core/types/__init__.py +++ b/graphene/core/types/__init__.py @@ -1,14 +1,14 @@ -from .base import BaseType, LazyType, OrderedType +from .base import InstanceType, LazyType, OrderedType from .argument import Argument, ArgumentsGroup, to_arguments from .definitions import List, NonNull # Compatibility import from .objecttype import Interface, ObjectType, Mutation, InputObjectType -from .scalars import String, ID, Boolean, Int, Float, Scalar +from .scalars import String, ID, Boolean, Int, Float from .field import Field, InputField __all__ = [ - 'BaseType', + 'InstanceType', 'LazyType', 'OrderedType', 'Argument', @@ -26,5 +26,4 @@ __all__ = [ 'ID', 'Boolean', 'Int', - 'Float', - 'Scalar'] + 'Float'] diff --git a/graphene/core/types/argument.py b/graphene/core/types/argument.py index 0892c446..b10aff21 100644 --- a/graphene/core/types/argument.py +++ b/graphene/core/types/argument.py @@ -1,19 +1,17 @@ -from collections import OrderedDict from functools import wraps from itertools import chain from graphql.core.type import GraphQLArgument -from ...utils import ProxySnakeDict, to_camel_case -from .base import ArgumentType, BaseType, OrderedType +from ...utils import ProxySnakeDict +from .base import ArgumentType, GroupNamedType, NamedType, OrderedType -class Argument(OrderedType): +class Argument(NamedType, OrderedType): def __init__(self, type, description=None, default=None, name=None, _creation_counter=None): - super(Argument, self).__init__(_creation_counter=_creation_counter) - self.name = name + super(Argument, self).__init__(name=name, _creation_counter=_creation_counter) self.type = type self.description = description self.default = default @@ -27,47 +25,32 @@ class Argument(OrderedType): return self.name -class ArgumentsGroup(BaseType): +class ArgumentsGroup(GroupNamedType): def __init__(self, *args, **kwargs): arguments = to_arguments(*args, **kwargs) - self.arguments = OrderedDict([(arg.name, arg) for arg in arguments]) - - def internal_type(self, schema): - return OrderedDict([(arg.name, schema.T(arg)) - for arg in self.arguments.values()]) - - def __len__(self): - return len(self.arguments) - - def __iter__(self): - return iter(self.arguments) - - def __contains__(self, *args): - return self.arguments.__contains__(*args) - - def __getitem__(self, *args): - return self.arguments.__getitem__(*args) + super(ArgumentsGroup, self).__init__(*arguments) def to_arguments(*args, **kwargs): arguments = {} iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) - for name, arg in iter_arguments: + for default_name, arg in iter_arguments: if isinstance(arg, Argument): argument = arg elif isinstance(arg, ArgumentType): argument = arg.as_argument() else: - raise ValueError('Unknown argument %s=%r' % (name, arg)) + raise ValueError('Unknown argument %s=%r' % (default_name, arg)) - if name: - argument.name = to_camel_case(name) - assert argument.name, 'Argument in field must have a name' - assert argument.name not in arguments, 'Found more than one Argument with same name {}'.format( - argument.name) - arguments[argument.name] = argument + if default_name: + argument.default_name = default_name + + name = argument.name or argument.default_name + assert name, 'Argument in field must have a name' + assert name not in arguments, 'Found more than one Argument with same name {}'.format(name) + arguments[name] = argument return sorted(arguments.values()) diff --git a/graphene/core/types/base.py b/graphene/core/types/base.py index 2b4078e4..ec8c7b3b 100644 --- a/graphene/core/types/base.py +++ b/graphene/core/types/base.py @@ -1,16 +1,16 @@ -from functools import total_ordering +from collections import OrderedDict +from functools import partial, total_ordering import six -class BaseType(object): +class InstanceType(object): - @classmethod - def internal_type(cls, schema): - return getattr(cls, 'T', None) + def internal_type(self, schema): + raise NotImplementedError("internal_type for type {} is not implemented".format(self.__class__.__name__)) -class MountType(BaseType): +class MountType(InstanceType): parent = None def mount(self, cls): @@ -126,3 +126,39 @@ class FieldType(MirroredType): class MountedType(FieldType, ArgumentType): pass + + +class NamedType(InstanceType): + + def __init__(self, name=None, default_name=None, *args, **kwargs): + self.name = name + self.default_name = None + super(NamedType, self).__init__(*args, **kwargs) + + +class GroupNamedType(InstanceType): + + def __init__(self, *types): + self.types = types + + def get_named_type(self, schema, type): + name = type.name or schema.get_default_namedtype_name(type.default_name) + return name, schema.T(type) + + def iter_types(self, schema): + return map(partial(self.get_named_type, schema), self.types) + + def internal_type(self, schema): + return OrderedDict(self.iter_types(schema)) + + def __len__(self): + return len(self.types) + + def __iter__(self): + return iter(self.types) + + def __contains__(self, *args): + return self.types.__contains__(*args) + + def __getitem__(self, *args): + return self.types.__getitem__(*args) diff --git a/graphene/core/types/field.py b/graphene/core/types/field.py index c3fa712f..6cbfff96 100644 --- a/graphene/core/types/field.py +++ b/graphene/core/types/field.py @@ -4,23 +4,26 @@ from functools import wraps import six from graphql.core.type import GraphQLField, GraphQLInputObjectField -from ...utils import to_camel_case from ..classtypes.base import FieldsClassType from ..classtypes.inputobjecttype import InputObjectType from ..classtypes.mutation import Mutation -from .argument import ArgumentsGroup, snake_case_args -from .base import LazyType, MountType, OrderedType +from ..exceptions import SkipField +from .argument import Argument, ArgumentsGroup, snake_case_args +from .base import (ArgumentType, GroupNamedType, LazyType, MountType, + NamedType, OrderedType) from .definitions import NonNull -class Field(OrderedType): +class Field(NamedType, OrderedType): def __init__( self, type, description=None, args=None, name=None, resolver=None, required=False, default=None, *args_list, **kwargs): _creation_counter = kwargs.pop('_creation_counter', None) - super(Field, self).__init__(_creation_counter=_creation_counter) - self.name = name + if isinstance(name, (Argument, ArgumentType)): + kwargs['name'] = name + name = None + super(Field, self).__init__(name=name, _creation_counter=_creation_counter) if isinstance(type, six.string_types): type = LazyType(type) self.required = required @@ -36,9 +39,8 @@ class Field(OrderedType): assert issubclass( cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( self, cls) - if not self.name: - self.name = to_camel_case(attname) self.attname = attname + self.default_name = attname self.object_type = cls self.mount(cls) if isinstance(self.type, MountType): @@ -117,12 +119,11 @@ class Field(OrderedType): return hash((self.creation_counter, self.object_type)) -class InputField(OrderedType): +class InputField(NamedType, OrderedType): def __init__(self, type, description=None, default=None, name=None, _creation_counter=None, required=False): super(InputField, self).__init__(_creation_counter=_creation_counter) - self.name = name if required: type = NonNull(type) self.type = type @@ -133,9 +134,8 @@ class InputField(OrderedType): assert issubclass( cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( self, cls) - if not self.name: - self.name = to_camel_case(attname) self.attname = attname + self.default_name = attname self.object_type = cls self.mount(cls) if isinstance(self.type, MountType): @@ -146,3 +146,13 @@ class InputField(OrderedType): return GraphQLInputObjectField( schema.T(self.type), default_value=self.default, description=self.description) + + +class FieldsGroupType(GroupNamedType): + + def iter_types(self, schema): + for field in sorted(self.types): + try: + yield self.get_named_type(schema, field) + except SkipField: + continue diff --git a/graphene/core/types/scalars.py b/graphene/core/types/scalars.py index 75cd70a3..9d7f5aeb 100644 --- a/graphene/core/types/scalars.py +++ b/graphene/core/types/scalars.py @@ -1,41 +1,30 @@ from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, - GraphQLInt, GraphQLScalarType, GraphQLString) + GraphQLInt, GraphQLString) from .base import MountedType -class String(MountedType): - T = GraphQLString +class ScalarType(MountedType): + + def internal_type(self, schema): + return self._internal_type -class Int(MountedType): - T = GraphQLInt +class String(ScalarType): + _internal_type = GraphQLString -class Boolean(MountedType): - T = GraphQLBoolean +class Int(ScalarType): + _internal_type = GraphQLInt -class ID(MountedType): - T = GraphQLID +class Boolean(ScalarType): + _internal_type = GraphQLBoolean -class Float(MountedType): - T = GraphQLFloat +class ID(ScalarType): + _internal_type = GraphQLID -class Scalar(MountedType): - - @classmethod - def internal_type(cls, schema): - serialize = getattr(cls, 'serialize') - parse_literal = getattr(cls, 'parse_literal') - parse_value = getattr(cls, 'parse_value') - - return GraphQLScalarType( - name=cls.__name__, - description=cls.__doc__, - serialize=serialize, - parse_value=parse_value, - parse_literal=parse_literal - ) +class Float(ScalarType): + _internal_type = GraphQLFloat diff --git a/graphene/core/types/tests/test_argument.py b/graphene/core/types/tests/test_argument.py index 26bbb310..b2f5e239 100644 --- a/graphene/core/types/tests/test_argument.py +++ b/graphene/core/types/tests/test_argument.py @@ -27,8 +27,8 @@ def test_to_arguments(): other_kwarg=String(), ) - assert [a.name for a in arguments] == [ - 'myArg', 'otherArg', 'myKwarg', 'otherKwarg'] + assert [a.name or a.default_name for a in arguments] == [ + 'myArg', 'otherArg', 'my_kwarg', 'other_kwarg'] def test_to_arguments_no_name(): diff --git a/graphene/core/types/tests/test_field.py b/graphene/core/types/tests/test_field.py index 8253ed20..706cbc59 100644 --- a/graphene/core/types/tests/test_field.py +++ b/graphene/core/types/tests/test_field.py @@ -13,14 +13,14 @@ from ..scalars import String def test_field_internal_type(): resolver = lambda *args: 'RESOLVED' - field = Field(String, description='My argument', resolver=resolver) + field = Field(String(), description='My argument', resolver=resolver) class Query(ObjectType): my_field = field schema = Schema(query=Query) type = schema.T(field) - assert field.name == 'myField' + assert field.name is None assert field.attname == 'my_field' assert isinstance(type, GraphQLField) assert type.description == 'My argument' @@ -98,9 +98,18 @@ def test_field_string_reference(): def test_field_custom_arguments(): field = Field(None, name='my_customName', p=String()) + schema = Schema() args = field.arguments - assert 'p' in args + assert 'p' in schema.T(args) + + +def test_field_name_as_argument(): + field = Field(None, name=String()) + schema = Schema() + + args = field.arguments + assert 'name' in schema.T(args) def test_inputfield_internal_type(): @@ -115,8 +124,43 @@ def test_inputfield_internal_type(): schema = Schema(query=MyObjectType) type = schema.T(field) - assert field.name == 'myField' + assert field.name is None assert field.attname == 'my_field' assert isinstance(type, GraphQLInputObjectField) assert type.description == 'My input field' assert type.default_value == '3' + + +def test_field_resolve_argument(): + resolver = lambda instance, args, info: args.get('first_name') + + field = Field(String(), first_name=String(), description='My argument', resolver=resolver) + + class Query(ObjectType): + my_field = field + schema = Schema(query=Query) + + type = schema.T(field) + assert type.resolver(None, {'firstName': 'Peter'}, None) == 'Peter' + + +def test_field_resolve_vars(): + class Query(ObjectType): + hello = String(first_name=String()) + + def resolve_hello(self, args, info): + return 'Hello ' + args.get('first_name') + + schema = Schema(query=Query) + + result = schema.execute(""" + query foo($firstName:String) + { + hello(firstName:$firstName) + } + """, args={"firstName": "Serkan"}) + + expected = { + 'hello': 'Hello Serkan' + } + assert result.data == expected diff --git a/graphene/core/types/tests/test_scalars.py b/graphene/core/types/tests/test_scalars.py index 8b8d930f..39fd6063 100644 --- a/graphene/core/types/tests/test_scalars.py +++ b/graphene/core/types/tests/test_scalars.py @@ -1,9 +1,9 @@ from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, - GraphQLInt, GraphQLScalarType, GraphQLString) + GraphQLInt, GraphQLString) from graphene.core.schema import Schema -from ..scalars import ID, Boolean, Float, Int, Scalar, String +from ..scalars import ID, Boolean, Float, Int, String schema = Schema() @@ -26,29 +26,3 @@ def test_id_scalar(): def test_float_scalar(): assert schema.T(Float()) == GraphQLFloat - - -def test_custom_scalar(): - import datetime - from graphql.core.language import ast - - class DateTimeScalar(Scalar): - '''DateTimeScalar Documentation''' - @staticmethod - def serialize(dt): - return dt.isoformat() - - @staticmethod - def parse_literal(node): - if isinstance(node, ast.StringValue): - return datetime.datetime.strptime( - node.value, "%Y-%m-%dT%H:%M:%S.%f") - - @staticmethod - def parse_value(value): - return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f") - - scalar_type = schema.T(DateTimeScalar) - assert isinstance(scalar_type, GraphQLScalarType) - assert scalar_type.name == 'DateTimeScalar' - assert scalar_type.description == 'DateTimeScalar Documentation' diff --git a/graphene/plugins/__init__.py b/graphene/plugins/__init__.py new file mode 100644 index 00000000..160bffba --- /dev/null +++ b/graphene/plugins/__init__.py @@ -0,0 +1,6 @@ +from .base import Plugin, PluginManager +from .camel_case import CamelCase + +__all__ = [ + 'Plugin', 'PluginManager', 'CamelCase' +] diff --git a/graphene/plugins/base.py b/graphene/plugins/base.py new file mode 100644 index 00000000..2347beba --- /dev/null +++ b/graphene/plugins/base.py @@ -0,0 +1,53 @@ +from contextlib import contextmanager +from functools import partial, reduce + + +class Plugin(object): + + def contribute_to_schema(self, schema): + self.schema = schema + + +def apply_function(a, b): + return b(a) + + +class PluginManager(object): + + PLUGIN_FUNCTIONS = ('get_default_namedtype_name', ) + + def __init__(self, schema, plugins=[]): + self.schema = schema + self.plugins = [] + for plugin in plugins: + self.add_plugin(plugin) + + def add_plugin(self, plugin): + if hasattr(plugin, 'contribute_to_schema'): + plugin.contribute_to_schema(self.schema) + self.plugins.append(plugin) + + def get_plugin_functions(self, function): + for plugin in self.plugins: + if not hasattr(plugin, function): + continue + yield getattr(plugin, function) + + def __getattr__(self, name): + functions = self.get_plugin_functions(name) + return partial(reduce, apply_function, functions) + + def __contains__(self, name): + return name in self.PLUGIN_FUNCTIONS + + @contextmanager + def context_execution(self, **executor): + contexts = [] + functions = self.get_plugin_functions('context_execution') + for f in functions: + context = f(executor) + executor = context.__enter__() + contexts.append((context, executor)) + yield executor + for context, value in contexts[::-1]: + context.__exit__(None, None, None) diff --git a/graphene/plugins/camel_case.py b/graphene/plugins/camel_case.py new file mode 100644 index 00000000..d9a9084f --- /dev/null +++ b/graphene/plugins/camel_case.py @@ -0,0 +1,7 @@ +from ..utils import to_camel_case + + +class CamelCase(object): + + def get_default_namedtype_name(self, value): + return to_camel_case(value) diff --git a/graphene/relay/tests/test_mutations.py b/graphene/relay/tests/test_mutations.py index 4356a1ec..02287725 100644 --- a/graphene/relay/tests/test_mutations.py +++ b/graphene/relay/tests/test_mutations.py @@ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation) def test_mutation_arguments(): assert ChangeNumber.arguments - assert list(ChangeNumber.arguments) == ['input'] - assert 'input' in ChangeNumber.arguments + assert 'input' in schema.T(ChangeNumber.arguments) inner_type = ChangeNumber.input_type client_mutation_id_field = inner_type._meta.fields_map[ 'client_mutation_id'] diff --git a/setup.py b/setup.py index 2c65336b..ac0e3d0f 100644 --- a/setup.py +++ b/setup.py @@ -24,9 +24,9 @@ class PyTest(TestCommand): setup( name='graphene', - version='0.4.3', + version='0.5.0', - description='Graphene: Python DSL for GraphQL', + description='GraphQL Framework for Python', long_description=open('README.rst').read(), url='https://github.com/graphql-python/graphene', @@ -66,9 +66,9 @@ setup( ], extras_require={ 'django': [ - 'Django>=1.8.0', + 'Django>=1.6.0', 'singledispatch>=3.4.0.3', - 'graphql-django-view>=1.0.0', + 'graphql-django-view>=1.1.0', ], },