diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..5ebeb47b --- /dev/null +++ b/.editorconfig @@ -0,0 +1,14 @@ +# http://editorconfig.org + +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.{py,rst,ini}] +indent_style = space +indent_size = 4 + diff --git a/.travis.yml b/.travis.yml index fba19076..3dbb00e0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ language: python sudo: false python: - 2.7 -- 3.3 - 3.4 - 3.5 - pypy @@ -14,8 +13,9 @@ cache: - $HOME/docs/node_modules before_install: - | - if [ "$TEST_TYPE" != build_website ] && \ - ! git diff --name-only $TRAVIS_COMMIT_RANGE | grep -qvE '(\.md$)|(^(docs))/' + git_diff=$(git diff --name-only $TRAVIS_COMMIT_RANGE) + if [ "$?" == 0 ] && [ "$TEST_TYPE" != build_website ] && \ + ! echo "$git_diff" | grep -qvE '(\.md$)|(^(docs))/' then echo "Only docs were updated, stopping build process." exit @@ -25,6 +25,7 @@ install: if [ "$TEST_TYPE" = build ]; then pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter pip install --download-cache $HOME/.cache/pip/ -e .[django] + pip install django==$DJANGO_VERSION python setup.py develop elif [ "$TEST_TYPE" = build_website ]; then pip install --download-cache $HOME/.cache/pip/ -e . @@ -78,6 +79,14 @@ env: matrix: fast_finish: true include: + - python: '2.7' + env: DJANGO_VERSION=1.6 + - python: '2.7' + env: DJANGO_VERSION=1.7 + - python: '2.7' + env: DJANGO_VERSION=1.8 + - python: '2.7' + env: DJANGO_VERSION=1.9 - python: '2.7' env: TEST_TYPE=build_website - python: '2.7' diff --git a/docs/pages/docs/basic-types.md b/docs/pages/docs/basic-types.md index b1acb2f6..defd78b5 100644 --- a/docs/pages/docs/basic-types.md +++ b/docs/pages/docs/basic-types.md @@ -82,3 +82,21 @@ graphene.Field(graphene.String(), to=graphene.String()) # Is equivalent to: graphene.Field(graphene.String(), to=graphene.Argument(graphene.String())) ``` + + +## Using custom object types as argument + +To use a custom object type as an argument, you need to inherit `graphene.InputObjectType`, not `graphene.ObjectType`. + +```python +class CustomArgumentObjectType(graphene.InputObjectType): + field1 = graphene.String() + field2 = graphene.String() + +``` + +Then, when defining this in an argument, you need to wrap it in an `Argument` object. + +```python +graphene.Field(graphene.String(), to=graphene.Argument(CustomArgumentObjectType)) +``` diff --git a/examples/starwars_django/schema.py b/examples/starwars_django/schema.py index 501ccce6..88407103 100644 --- a/examples/starwars_django/schema.py +++ b/examples/starwars_django/schema.py @@ -63,15 +63,15 @@ class Query(graphene.ObjectType): @resolve_only_args def resolve_ships(self): - return [Ship(s) for s in get_ships()] + return get_ships() @resolve_only_args def resolve_rebels(self): - return Faction(get_rebels()) + return get_rebels() @resolve_only_args def resolve_empire(self): - return Faction(get_empire()) + return get_empire() class Mutation(graphene.ObjectType): diff --git a/graphene/__init__.py b/graphene/__init__.py index 88d3a365..71066499 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -4,20 +4,14 @@ from graphql.core.type import ( from graphene import signals -from graphene.core.schema import ( - Schema -) - -from graphene.core.classtypes import ( +from .core import ( + Schema, ObjectType, InputObjectType, Interface, Mutation, - Scalar -) - -from graphene.core.types import ( - BaseType, + Scalar, + InstanceType, LazyType, Argument, Field, @@ -57,7 +51,7 @@ __all__ = [ 'NonNull', 'signals', 'Schema', - 'BaseType', + 'InstanceType', 'LazyType', 'ObjectType', 'InputObjectType', diff --git a/graphene/contrib/django/compat.py b/graphene/contrib/django/compat.py new file mode 100644 index 00000000..a5b444c7 --- /dev/null +++ b/graphene/contrib/django/compat.py @@ -0,0 +1,15 @@ +from django.db import models + +try: + UUIDField = models.UUIDField +except AttributeError: + # Improved compatibility for Django 1.6 + class UUIDField(object): + pass + +try: + from django.db.models.related import RelatedObject +except: + # Improved compatibility for Django 1.6 + class RelatedObject(object): + pass diff --git a/graphene/contrib/django/converter.py b/graphene/contrib/django/converter.py index 255ec8a5..ffdcfc5a 100644 --- a/graphene/contrib/django/converter.py +++ b/graphene/contrib/django/converter.py @@ -1,14 +1,10 @@ from django.db import models -from singledispatch import singledispatch from ...core.types.scalars import ID, Boolean, Float, Int, String +from .compat import RelatedObject, UUIDField +from .utils import get_related_model, import_single_dispatch -try: - UUIDField = models.UUIDField -except AttributeError: - # Improved compatibility for Django 1.6 - class UUIDField(object): - pass +singledispatch = import_single_dispatch() @singledispatch @@ -24,6 +20,7 @@ def convert_django_field(field): @convert_django_field.register(models.EmailField) @convert_django_field.register(models.SlugField) @convert_django_field.register(models.URLField) +@convert_django_field.register(models.GenericIPAddressField) @convert_django_field.register(UUIDField) def convert_field_to_string(field): return String(description=field.help_text) @@ -63,7 +60,15 @@ def convert_field_to_float(field): @convert_django_field.register(models.ManyToOneRel) def convert_field_to_list_or_connection(field): from .fields import DjangoModelField, ConnectionOrListField - model_field = DjangoModelField(field.related_model) + model_field = DjangoModelField(get_related_model(field)) + return ConnectionOrListField(model_field) + + +# For Django 1.6 +@convert_django_field.register(RelatedObject) +def convert_relatedfield_to_djangomodel(field): + from .fields import DjangoModelField, ConnectionOrListField + model_field = DjangoModelField(field.model) return ConnectionOrListField(model_field) @@ -71,4 +76,4 @@ def convert_field_to_list_or_connection(field): @convert_django_field.register(models.ForeignKey) def convert_field_to_djangomodel(field): from .fields import DjangoModelField - return DjangoModelField(field.related_model, description=field.help_text) + return DjangoModelField(get_related_model(field), description=field.help_text) diff --git a/graphene/contrib/django/debug/__init__.py b/graphene/contrib/django/debug/__init__.py new file mode 100644 index 00000000..4c76aeca --- /dev/null +++ b/graphene/contrib/django/debug/__init__.py @@ -0,0 +1,4 @@ +from .plugin import DjangoDebugPlugin +from .types import DjangoDebug + +__all__ = ['DjangoDebugPlugin', 'DjangoDebug'] diff --git a/graphene/contrib/django/debug/plugin.py b/graphene/contrib/django/debug/plugin.py new file mode 100644 index 00000000..70cd6741 --- /dev/null +++ b/graphene/contrib/django/debug/plugin.py @@ -0,0 +1,79 @@ +from contextlib import contextmanager + +from django.db import connections + +from ....core.schema import GraphQLSchema +from ....core.types import Field +from ....plugins import Plugin +from .sql.tracking import unwrap_cursor, wrap_cursor +from .sql.types import DjangoDebugSQL +from .types import DjangoDebug + + +class WrappedRoot(object): + + def __init__(self, root): + self._recorded = [] + self._root = root + + def record(self, **log): + self._recorded.append(DjangoDebugSQL(**log)) + + def debug(self): + return DjangoDebug(sql=self._recorded) + + +class WrapRoot(object): + + @property + def _root(self): + return self._wrapped_root.root + + @_root.setter + def _root(self, value): + self._wrapped_root = value + + def resolve_debug(self, args, info): + return self._wrapped_root.debug() + + +def debug_objecttype(objecttype): + return type( + 'Debug{}'.format(objecttype._meta.type_name), + (WrapRoot, objecttype), + {'debug': Field(DjangoDebug, name='__debug')}) + + +class DjangoDebugPlugin(Plugin): + + def enable_instrumentation(self, wrapped_root): + # This is thread-safe because database connections are thread-local. + for connection in connections.all(): + wrap_cursor(connection, wrapped_root) + + def disable_instrumentation(self): + for connection in connections.all(): + unwrap_cursor(connection) + + def wrap_schema(self, schema_type): + query = schema_type._query + if query: + class_type = self.schema.objecttype(schema_type.get_query_type()) + assert class_type, 'The query in schema is not constructed with graphene' + _type = debug_objecttype(class_type) + self.schema.register(_type, force=True) + return GraphQLSchema( + self.schema, + self.schema.T(_type), + schema_type.get_mutation_type(), + schema_type.get_subscription_type() + ) + return schema_type + + @contextmanager + def context_execution(self, executor): + executor['root'] = WrappedRoot(root=executor['root']) + executor['schema'] = self.wrap_schema(executor['schema']) + self.enable_instrumentation(executor['root']) + yield executor + self.disable_instrumentation() diff --git a/graphene/contrib/django/tests/filter/__init__.py b/graphene/contrib/django/debug/sql/__init__.py similarity index 100% rename from graphene/contrib/django/tests/filter/__init__.py rename to graphene/contrib/django/debug/sql/__init__.py diff --git a/graphene/contrib/django/debug/sql/tracking.py b/graphene/contrib/django/debug/sql/tracking.py new file mode 100644 index 00000000..8ed40492 --- /dev/null +++ b/graphene/contrib/django/debug/sql/tracking.py @@ -0,0 +1,165 @@ +# Code obtained from django-debug-toolbar sql panel tracking +from __future__ import absolute_import, unicode_literals + +import json +from threading import local +from time import time + +from django.utils import six +from django.utils.encoding import force_text + + +class SQLQueryTriggered(Exception): + """Thrown when template panel triggers a query""" + + +class ThreadLocalState(local): + + def __init__(self): + self.enabled = True + + @property + def Wrapper(self): + if self.enabled: + return NormalCursorWrapper + return ExceptionCursorWrapper + + def recording(self, v): + self.enabled = v + + +state = ThreadLocalState() +recording = state.recording # export function + + +def wrap_cursor(connection, panel): + if not hasattr(connection, '_djdt_cursor'): + connection._djdt_cursor = connection.cursor + + def cursor(): + return state.Wrapper(connection._djdt_cursor(), connection, panel) + + connection.cursor = cursor + return cursor + + +def unwrap_cursor(connection): + if hasattr(connection, '_djdt_cursor'): + del connection._djdt_cursor + del connection.cursor + + +class ExceptionCursorWrapper(object): + """ + Wraps a cursor and raises an exception on any operation. + Used in Templates panel. + """ + + def __init__(self, cursor, db, logger): + pass + + def __getattr__(self, attr): + raise SQLQueryTriggered() + + +class NormalCursorWrapper(object): + """ + Wraps a cursor and logs queries. + """ + + def __init__(self, cursor, db, logger): + self.cursor = cursor + # Instance of a BaseDatabaseWrapper subclass + self.db = db + # logger must implement a ``record`` method + self.logger = logger + + def _quote_expr(self, element): + if isinstance(element, six.string_types): + return "'%s'" % force_text(element).replace("'", "''") + else: + return repr(element) + + def _quote_params(self, params): + if not params: + return params + if isinstance(params, dict): + return dict((key, self._quote_expr(value)) + for key, value in params.items()) + return list(map(self._quote_expr, params)) + + def _decode(self, param): + try: + return force_text(param, strings_only=True) + except UnicodeDecodeError: + return '(encoded string)' + + def _record(self, method, sql, params): + start_time = time() + try: + return method(sql, params) + finally: + stop_time = time() + duration = (stop_time - start_time) + _params = '' + try: + _params = json.dumps(list(map(self._decode, params))) + except Exception: + pass # object not JSON serializable + + alias = getattr(self.db, 'alias', 'default') + conn = self.db.connection + vendor = getattr(conn, 'vendor', 'unknown') + + params = { + 'vendor': vendor, + 'alias': alias, + 'sql': self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params)), + 'duration': duration, + 'raw_sql': sql, + 'params': _params, + 'start_time': start_time, + 'stop_time': stop_time, + 'is_slow': duration > 10, + 'is_select': sql.lower().strip().startswith('select'), + } + + if vendor == 'postgresql': + # If an erroneous query was ran on the connection, it might + # be in a state where checking isolation_level raises an + # exception. + try: + iso_level = conn.isolation_level + except conn.InternalError: + iso_level = 'unknown' + params.update({ + 'trans_id': self.logger.get_transaction_id(alias), + 'trans_status': conn.get_transaction_status(), + 'iso_level': iso_level, + 'encoding': conn.encoding, + }) + + # We keep `sql` to maintain backwards compatibility + self.logger.record(**params) + + def callproc(self, procname, params=()): + return self._record(self.cursor.callproc, procname, params) + + def execute(self, sql, params=()): + return self._record(self.cursor.execute, sql, params) + + def executemany(self, sql, param_list): + return self._record(self.cursor.executemany, sql, param_list) + + def __getattr__(self, attr): + return getattr(self.cursor, attr) + + def __iter__(self): + return iter(self.cursor) + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() diff --git a/graphene/contrib/django/debug/sql/types.py b/graphene/contrib/django/debug/sql/types.py new file mode 100644 index 00000000..995aeaa2 --- /dev/null +++ b/graphene/contrib/django/debug/sql/types.py @@ -0,0 +1,19 @@ +from .....core import Boolean, Float, ObjectType, String + + +class DjangoDebugSQL(ObjectType): + vendor = String() + alias = String() + sql = String() + duration = Float() + raw_sql = String() + params = String() + start_time = Float() + stop_time = Float() + is_slow = Boolean() + is_select = Boolean() + + trans_id = String() + trans_status = String() + iso_level = String() + encoding = String() diff --git a/graphene/contrib/django/debug/tests/__init__.py b/graphene/contrib/django/debug/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/debug/tests/test_query.py b/graphene/contrib/django/debug/tests/test_query.py new file mode 100644 index 00000000..4df26e4f --- /dev/null +++ b/graphene/contrib/django/debug/tests/test_query.py @@ -0,0 +1,70 @@ +import pytest + +import graphene +from graphene.contrib.django import DjangoObjectType + +from ...tests.models import Reporter +from ..plugin import DjangoDebugPlugin + +# from examples.starwars_django.models import Character + +pytestmark = pytest.mark.django_db + + +def test_should_query_well(): + r1 = Reporter(last_name='ABA') + r1.save() + r2 = Reporter(last_name='Griffin') + r2.save() + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + all_reporters = ReporterType.List() + + def resolve_all_reporters(self, *args, **kwargs): + return Reporter.objects.all() + + def resolve_reporter(self, *args, **kwargs): + return Reporter.objects.first() + + query = ''' + query ReporterQuery { + reporter { + lastName + } + allReporters { + lastName + } + __debug { + sql { + rawSql + } + } + } + ''' + expected = { + 'reporter': { + 'lastName': 'ABA', + }, + 'allReporters': [{ + 'lastName': 'ABA', + }, { + 'lastName': 'Griffin', + }], + '__debug': { + 'sql': [{ + 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) + }, { + 'rawSql': str(Reporter.objects.all().query) + }] + } + } + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene/contrib/django/debug/types.py b/graphene/contrib/django/debug/types.py new file mode 100644 index 00000000..bceb54b0 --- /dev/null +++ b/graphene/contrib/django/debug/types.py @@ -0,0 +1,7 @@ +from ....core.classtypes.objecttype import ObjectType +from ....core.types import Field +from .sql.types import DjangoDebugSQL + + +class DjangoDebug(ObjectType): + sql = Field(DjangoDebugSQL.List()) diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 16568883..d9d6f3da 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -1,12 +1,13 @@ import warnings -from .utils import get_type_for_model from ...core.exceptions import SkipField from ...core.fields import Field from ...core.types.base import FieldType from ...core.types.definitions import List from ...relay import ConnectionField from ...relay.utils import is_node +from .filter.fields import DjangoFilterConnectionField +from .utils import get_type_for_model class DjangoConnectionField(ConnectionField): @@ -27,7 +28,10 @@ class ConnectionOrListField(Field): if not field_object_type: raise SkipField() if is_node(field_object_type): - field = DjangoConnectionField(field_object_type) + if field_object_type._meta.filter_fields: + field = DjangoFilterConnectionField(field_object_type) + else: + field = ConnectionField(field_object_type) else: field = Field(List(field_object_type)) field.contribute_to_class(self.object_type, self.attname) diff --git a/graphene/contrib/django/filter/__init__.py b/graphene/contrib/django/filter/__init__.py index 21e65b56..95b28aff 100644 --- a/graphene/contrib/django/filter/__init__.py +++ b/graphene/contrib/django/filter/__init__.py @@ -1,3 +1,11 @@ +from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED + +if not DJANGO_FILTER_INSTALLED: + raise Exception( + "Use of django filtering requires the django-filter package " + "be installed. You can do so using `pip install django-filter`" + ) + from .fields import DjangoFilterConnectionField from .filterset import GrapheneFilterSet, GlobalIDFilter, GlobalIDMultipleChoiceFilter from .resolvers import FilterConnectionResolver diff --git a/graphene/contrib/django/filter/fields.py b/graphene/contrib/django/filter/fields.py index 8c402fbe..43196f6e 100644 --- a/graphene/contrib/django/filter/fields.py +++ b/graphene/contrib/django/filter/fields.py @@ -1,9 +1,9 @@ -from graphene.contrib.django import DjangoConnectionField from graphene.contrib.django.filter.resolvers import FilterConnectionResolver from graphene.contrib.django.utils import get_filtering_args_from_filterset +from graphene.relay import ConnectionField -class DjangoFilterConnectionField(DjangoConnectionField): +class DjangoFilterConnectionField(ConnectionField): def __init__(self, type, on=None, fields=None, order_by=None, extra_filter_meta=None, filterset_class=None, resolver=None, diff --git a/graphene/contrib/django/filter/filterset.py b/graphene/contrib/django/filter/filterset.py index c71478ee..3ecd9680 100644 --- a/graphene/contrib/django/filter/filterset.py +++ b/graphene/contrib/django/filter/filterset.py @@ -2,11 +2,12 @@ import six from django.conf import settings from django.db import models from django.utils.text import capfirst -from django_filters import Filter, MultipleChoiceFilter -from django_filters.filterset import FilterSetMetaclass, FilterSet -from graphql_relay.node.node import from_global_id -from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField +from django_filters import Filter, MultipleChoiceFilter +from django_filters.filterset import FilterSet, FilterSetMetaclass +from graphene.contrib.django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) +from graphql_relay.node.node import from_global_id class GlobalIDFilter(Filter): @@ -25,7 +26,7 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids) -ORDER_BY_FIELD = getattr(settings, 'GRAPHENE_ORDER_BY_FIELD', 'order') +ORDER_BY_FIELD = getattr(settings, 'GRAPHENE_ORDER_BY_FIELD', 'order_by') GRAPHENE_FILTER_SET_OVERRIDES = { @@ -45,6 +46,7 @@ GRAPHENE_FILTER_SET_OVERRIDES = { class GrapheneFilterSetMetaclass(FilterSetMetaclass): + def __new__(cls, name, bases, attrs): new_class = super(GrapheneFilterSetMetaclass, cls).__new__(cls, name, bases, attrs) # Customise the filter_overrides for Graphene @@ -84,7 +86,6 @@ class GrapheneFilterSet(six.with_metaclass(GrapheneFilterSetMetaclass, GrapheneF DjangoFilterConnectionField will wrap FilterSets with this class as necessary """ - pass def setup_filterset(filterset_class): diff --git a/graphene/contrib/django/filter/resolvers.py b/graphene/contrib/django/filter/resolvers.py index 5f696a08..76b3e7ad 100644 --- a/graphene/contrib/django/filter/resolvers.py +++ b/graphene/contrib/django/filter/resolvers.py @@ -1,6 +1,7 @@ from django.core.exceptions import ImproperlyConfigured -from graphene.contrib.django.filter.filterset import setup_filterset, custom_filterset_factory +from graphene.contrib.django.filter.filterset import (custom_filterset_factory, + setup_filterset) from graphene.contrib.django.resolvers import BaseQuerySetConnectionResolver @@ -10,8 +11,8 @@ class FilterConnectionResolver(BaseQuerySetConnectionResolver): def __init__(self, node, on=None, filterset_class=None, fields=None, order_by=None, extra_filter_meta=None): self.filterset_class = filterset_class - self.fields = fields - self.order_by = order_by + self.fields = fields or node._meta.filter_fields + self.order_by = order_by or node._meta.filter_order_by self.extra_filter_meta = extra_filter_meta or {} self._filterset_class = None super(FilterConnectionResolver, self).__init__(node, on) diff --git a/graphene/contrib/django/filter/tests/__init__.py b/graphene/contrib/django/filter/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/tests/filter/filters.py b/graphene/contrib/django/filter/tests/filters.py similarity index 83% rename from graphene/contrib/django/tests/filter/filters.py rename to graphene/contrib/django/filter/tests/filters.py index 4549a83e..94c0dffe 100644 --- a/graphene/contrib/django/tests/filter/filters.py +++ b/graphene/contrib/django/filter/tests/filters.py @@ -1,7 +1,5 @@ import django_filters - -from graphene.contrib.django.tests.models import Reporter -from graphene.contrib.django.tests.models import Article, Pet +from graphene.contrib.django.tests.models import Article, Pet, Reporter class ArticleFilter(django_filters.FilterSet): diff --git a/graphene/contrib/django/tests/filter/test_fields.py b/graphene/contrib/django/filter/tests/test_fields.py similarity index 72% rename from graphene/contrib/django/tests/filter/test_fields.py rename to graphene/contrib/django/filter/tests/test_fields.py index fc99a273..45c1f0d0 100644 --- a/graphene/contrib/django/tests/filter/test_fields.py +++ b/graphene/contrib/django/filter/tests/test_fields.py @@ -1,55 +1,66 @@ import pytest -try: +from graphene import ObjectType, Schema +from graphene.contrib.django import DjangoNode +from graphene.contrib.django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) +from graphene.contrib.django.tests.models import Article, Pet, Reporter +from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED +from graphene.relay import NodeField + +pytestmark = [] +if DJANGO_FILTER_INSTALLED: import django_filters -except ImportError: - pytestmark = pytest.mark.skipif(True, reason='django_filters not installed') -else: from graphene.contrib.django.filter import (GlobalIDFilter, DjangoFilterConnectionField, GlobalIDMultipleChoiceFilter) - from graphene.contrib.django.tests.filter.filters import ArticleFilter, PetFilter + from graphene.contrib.django.filter.tests.filters import ArticleFilter, PetFilter +else: + pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed')) -from graphene.contrib.django import DjangoNode -from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from graphene.contrib.django.tests.models import Article, Pet, Reporter +pytestmark.append(pytest.mark.django_db) class ArticleNode(DjangoNode): + class Meta: model = Article class ReporterNode(DjangoNode): + class Meta: model = Reporter class PetNode(DjangoNode): + class Meta: model = Pet +schema = Schema() + def assert_arguments(field, *arguments): - ignore = ('after', 'before', 'first', 'last', 'order') + ignore = ('after', 'before', 'first', 'last', 'orderBy') actual = [ name - for name in field.arguments.arguments.keys() + for name in schema.T(field.arguments) if name not in ignore and not name.startswith('_') ] assert set(arguments) == set(actual), \ 'Expected arguments ({}) did not match actual ({})'.format( arguments, actual - ) + ) def assert_orderable(field): - assert 'order' in field.arguments.arguments.keys(), \ + assert 'orderBy' in schema.T(field.arguments), \ 'Field cannot be ordered' def assert_not_orderable(field): - assert 'order' in field.arguments.arguments.keys(), \ + assert 'orderBy' not in schema.T(field.arguments), \ 'Field can be ordered' @@ -103,11 +114,52 @@ def test_filter_explicit_filterset_not_orderable(): def test_filter_shortcut_filterset_extra_meta(): field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ - 'ordering': True + 'order_by': True }) assert_orderable(field) +def test_filter_filterset_information_on_meta(): + class ReporterFilterNode(DjangoNode): + + class Meta: + model = Reporter + filter_fields = ['first_name', 'articles'] + filter_order_by = True + + field = DjangoFilterConnectionField(ReporterFilterNode) + assert_arguments(field, 'firstName', 'articles') + assert_orderable(field) + + +def test_filter_filterset_information_on_meta_related(): + class ReporterFilterNode(DjangoNode): + + class Meta: + model = Reporter + filter_fields = ['first_name', 'articles'] + filter_order_by = True + + class ArticleFilterNode(DjangoNode): + + class Meta: + model = Article + filter_fields = ['headline', 'reporter'] + filter_order_by = True + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + all_articles = DjangoFilterConnectionField(ArticleFilterNode) + reporter = NodeField(ReporterFilterNode) + article = NodeField(ArticleFilterNode) + + schema = Schema(query=Query) + schema.schema # Trigger the schema loading + articles_field = schema.get_type('ReporterFilterNode')._meta.fields_map['articles'] + assert_arguments(articles_field, 'headline', 'reporter') + assert_orderable(articles_field) + + def test_global_id_field_implicit(): field = DjangoFilterConnectionField(ArticleNode, fields=['id']) filterset_class = field.resolver_fn.get_filterset_class() @@ -118,6 +170,7 @@ def test_global_id_field_implicit(): def test_global_id_field_explicit(): class ArticleIdFilter(django_filters.FilterSet): + class Meta: model = Article fields = ['id'] @@ -147,6 +200,7 @@ def test_global_id_multiple_field_implicit(): def test_global_id_multiple_field_explicit(): class ReporterPetsFilter(django_filters.FilterSet): + class Meta: model = Reporter fields = ['pets'] @@ -158,9 +212,6 @@ def test_global_id_multiple_field_explicit(): assert multiple_filter.field_class == GlobalIDMultipleChoiceField -@pytest.mark.skipif(True, reason="Trying to test GrapheneFilterSetMixin.filter_for_reverse_field" - "but django has not loaded the models, so the test fails as " - "reverse relations are not ready yet") def test_global_id_multiple_field_implicit_reverse(): field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) filterset_class = field.resolver_fn.get_filterset_class() @@ -169,11 +220,9 @@ def test_global_id_multiple_field_implicit_reverse(): assert multiple_filter.field_class == GlobalIDMultipleChoiceField -@pytest.mark.skipif(True, reason="Trying to test GrapheneFilterSetMixin.filter_for_reverse_field" - "but django has not loaded the models, so the test fails as " - "reverse relations are not ready yet") def test_global_id_multiple_field_explicit_reverse(): class ReporterPetsFilter(django_filters.FilterSet): + class Meta: model = Reporter fields = ['articles'] diff --git a/graphene/contrib/django/tests/filter/test_resolvers.py b/graphene/contrib/django/filter/tests/test_resolvers.py similarity index 86% rename from graphene/contrib/django/tests/filter/test_resolvers.py rename to graphene/contrib/django/filter/tests/test_resolvers.py index af8bfc47..670e87c8 100644 --- a/graphene/contrib/django/tests/filter/test_resolvers.py +++ b/graphene/contrib/django/filter/tests/test_resolvers.py @@ -1,16 +1,16 @@ import pytest from django.core.exceptions import ImproperlyConfigured -try: - import django_filters # noqa -except ImportError: - pytestmark = pytest.mark.skipif(True, reason='django_filters not installed') -else: - from graphene.contrib.django.filter.resolvers import FilterConnectionResolver - from graphene.contrib.django.tests.filter.filters import ReporterFilter, ArticleFilter +from graphene.contrib.django.tests.models import Article, Reporter +from graphene.contrib.django.tests.test_resolvers import (ArticleNode, + ReporterNode) +from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED -from graphene.contrib.django.tests.models import Reporter, Article -from graphene.contrib.django.tests.test_resolvers import ReporterNode, ArticleNode +if DJANGO_FILTER_INSTALLED: + from graphene.contrib.django.filter.resolvers import FilterConnectionResolver + from graphene.contrib.django.filter.tests.filters import ArticleFilter, ReporterFilter +else: + pytestmark = pytest.mark.skipif(True, reason='django_filters not installed') def test_filter_get_filterset_class_explicit(): @@ -64,7 +64,7 @@ def test_filter_order(): resolver = FilterConnectionResolver(ArticleNode, filterset_class=ArticleFilter) resolved = resolver(inst=article, args={ - 'order': 'headline' + 'order_by': 'headline' }, info=None) assert 'WHERE' not in str(resolved.query) assert 'ORDER BY' in str(resolved.query) @@ -76,7 +76,7 @@ def test_filter_order_not_available(): resolver = FilterConnectionResolver(ReporterNode, filterset_class=ReporterFilter) resolved = resolver(inst=reporter, args={ - 'order': 'last_name' + 'order_by': 'last_name' }, info=None) assert 'WHERE' not in str(resolved.query) assert 'ORDER BY' not in str(resolved.query) diff --git a/graphene/contrib/django/form_converter.py b/graphene/contrib/django/form_converter.py index f5acf202..de2a40d8 100644 --- a/graphene/contrib/django/form_converter.py +++ b/graphene/contrib/django/form_converter.py @@ -1,11 +1,14 @@ from django import forms from django.forms.fields import BaseTemporalField -from singledispatch import singledispatch -from graphene import String, Int, Boolean, Float, ID -from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField +from graphene import ID, Boolean, Float, Int, String +from graphene.contrib.django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) +from graphene.contrib.django.utils import import_single_dispatch from graphene.core.types.definitions import List +singledispatch = import_single_dispatch() + try: UUIDField = forms.UUIDField except AttributeError: @@ -60,11 +63,11 @@ def convert_form_field_to_float(field): @convert_form_field.register(forms.ModelMultipleChoiceField) @convert_form_field.register(GlobalIDMultipleChoiceField) -def convert_form_field_to_list_or_connection(field): +def convert_form_field_to_list(field): return List(ID()) @convert_form_field.register(forms.ModelChoiceField) @convert_form_field.register(GlobalIDFormField) -def convert_form_field_to_djangomodel(field): +def convert_form_field_to_id(field): return ID() diff --git a/graphene/contrib/django/forms.py b/graphene/contrib/django/forms.py index f971897b..88f1665e 100644 --- a/graphene/contrib/django/forms.py +++ b/graphene/contrib/django/forms.py @@ -1,7 +1,7 @@ import binascii from django.core.exceptions import ValidationError -from django.forms import Field, IntegerField, CharField, MultipleChoiceField +from django.forms import CharField, Field, IntegerField, MultipleChoiceField from django.utils.translation import ugettext_lazy as _ from graphql_relay import from_global_id diff --git a/graphene/contrib/django/management/__init__.py b/graphene/contrib/django/management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/management/commands/__init__.py b/graphene/contrib/django/management/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/management/commands/graphql_schema.py b/graphene/contrib/django/management/commands/graphql_schema.py new file mode 100644 index 00000000..35eb2772 --- /dev/null +++ b/graphene/contrib/django/management/commands/graphql_schema.py @@ -0,0 +1,38 @@ +from django.core.management.base import BaseCommand, CommandError + +import importlib +import json + + +class Command(BaseCommand): + help = 'Dump Graphene schema JSON to file' + can_import_settings = True + + def add_arguments(self, parser): + from django.conf import settings + parser.add_argument( + '--schema', + type=str, + dest='schema', + default=getattr(settings, 'GRAPHENE_SCHEMA', ''), + help='Django app containing schema to dump, e.g. myproject.core.schema') + + parser.add_argument( + '--out', + type=str, + dest='out', + default=getattr(settings, 'GRAPHENE_SCHEMA_OUTPUT', 'schema.json'), + help='Output file (default: schema.json)') + + def handle(self, *args, **options): + schema_module = options['schema'] + if schema_module == '': + raise CommandError('Specify schema on GRAPHENE_SCHEMA setting or by using --schema') + i = importlib.import_module(schema_module) + + schema_dict = {'data': i.schema.introspect()} + + with open(options['out'], 'w') as outfile: + json.dump(schema_dict, outfile) + + self.stdout.write(self.style.SUCCESS('Successfully dumped GraphQL schema to %s' % options['out'])) diff --git a/graphene/contrib/django/options.py b/graphene/contrib/django/options.py index 61dd37a3..dbd88aca 100644 --- a/graphene/contrib/django/options.py +++ b/graphene/contrib/django/options.py @@ -1,9 +1,13 @@ from ...core.classtypes.objecttype import ObjectTypeOptions from ...relay.types import Node from ...relay.utils import is_node +from .utils import DJANGO_FILTER_INSTALLED VALID_ATTRS = ('model', 'only_fields', 'exclude_fields') +if DJANGO_FILTER_INSTALLED: + VALID_ATTRS += ('filter_fields', 'filter_order_by') + class DjangoOptions(ObjectTypeOptions): @@ -13,6 +17,8 @@ class DjangoOptions(ObjectTypeOptions): self.valid_attrs += VALID_ATTRS self.only_fields = None self.exclude_fields = [] + self.filter_fields = None + self.filter_order_by = None def contribute_to_class(self, cls, name): super(DjangoOptions, self).contribute_to_class(cls, name) diff --git a/graphene/contrib/django/resolvers.py b/graphene/contrib/django/resolvers.py index 0499acc5..a5494bfb 100644 --- a/graphene/contrib/django/resolvers.py +++ b/graphene/contrib/django/resolvers.py @@ -36,8 +36,8 @@ class SimpleQuerySetConnectionResolver(BaseQuerySetConnectionResolver): return query def get_filter_kwargs(self): - ignore = ['first', 'last', 'before', 'after', 'order'] + ignore = ['first', 'last', 'before', 'after', 'order_by'] return {k: v for k, v in self.args.items() if k not in ignore} def get_order(self): - return self.args.get('order', None) + return self.args.get('order_by', None) diff --git a/graphene/contrib/django/tests/models.py b/graphene/contrib/django/tests/models.py index 172352e4..a4ff3386 100644 --- a/graphene/contrib/django/tests/models.py +++ b/graphene/contrib/django/tests/models.py @@ -16,9 +16,6 @@ class Reporter(models.Model): def __str__(self): # __unicode__ on Python 2 return "%s %s" % (self.first_name, self.last_name) - class Meta: - app_label = 'contrib_django' - class Article(models.Model): headline = models.CharField(max_length=100) @@ -30,4 +27,3 @@ class Article(models.Model): class Meta: ordering = ('headline',) - app_label = 'contrib_django' diff --git a/graphene/contrib/django/tests/test_converter.py b/graphene/contrib/django/tests/test_converter.py index d868ec97..59f3aa29 100644 --- a/graphene/contrib/django/tests/test_converter.py +++ b/graphene/contrib/django/tests/test_converter.py @@ -9,8 +9,8 @@ from graphene.contrib.django.fields import (ConnectionOrListField, from .models import Article, Reporter -def assert_conversion(django_field, graphene_field, *args): - field = django_field(*args, help_text='Custom Help Text') +def assert_conversion(django_field, graphene_field, *args, **kwargs): + field = django_field(help_text='Custom Help Text', *args, **kwargs) graphene_type = convert_django_field(field) assert isinstance(graphene_type, graphene_field) field = graphene_type.as_field() @@ -48,8 +48,12 @@ def test_should_url_convert_string(): assert_conversion(models.URLField, graphene.String) +def test_should_ipaddress_convert_string(): + assert_conversion(models.GenericIPAddressField, graphene.String) + + def test_should_auto_convert_id(): - assert_conversion(models.AutoField, graphene.ID) + assert_conversion(models.AutoField, graphene.ID, primary_key=True) def test_should_positive_integer_convert_int(): @@ -94,7 +98,10 @@ def test_should_manytomany_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist(): - graphene_type = convert_django_field(Reporter.articles.related) + # Django 1.9 uses 'rel', <1.9 uses 'related + related = getattr(Reporter.articles, 'rel', None) or \ + getattr(Reporter.articles, 'related') + graphene_type = convert_django_field(related) assert isinstance(graphene_type, ConnectionOrListField) assert isinstance(graphene_type.type, DjangoModelField) assert graphene_type.type.model == Article diff --git a/graphene/contrib/django/tests/test_form_converter.py b/graphene/contrib/django/tests/test_form_converter.py index 7492fc51..44d9bec3 100644 --- a/graphene/contrib/django/tests/test_form_converter.py +++ b/graphene/contrib/django/tests/test_form_converter.py @@ -1,10 +1,9 @@ from django import forms -from graphene.core.types import List, ID from py.test import raises import graphene from graphene.contrib.django.form_converter import convert_form_field - +from graphene.core.types import ID, List from .models import Reporter diff --git a/graphene/contrib/django/tests/test_query.py b/graphene/contrib/django/tests/test_query.py index 090c8695..460c8e22 100644 --- a/graphene/contrib/django/tests/test_query.py +++ b/graphene/contrib/django/tests/test_query.py @@ -1,3 +1,4 @@ +import pytest from py.test import raises import graphene @@ -6,6 +7,8 @@ from graphene.contrib.django import DjangoNode, DjangoObjectType from .models import Article, Reporter +pytestmark = pytest.mark.django_db + def test_should_query_only_fields(): with raises(Exception): diff --git a/graphene/contrib/django/tests/test_resolvers.py b/graphene/contrib/django/tests/test_resolvers.py index 38d98aba..db1610c9 100644 --- a/graphene/contrib/django/tests/test_resolvers.py +++ b/graphene/contrib/django/tests/test_resolvers.py @@ -3,15 +3,17 @@ from django.db.models.query import QuerySet from graphene.contrib.django import DjangoNode from graphene.contrib.django.resolvers import SimpleQuerySetConnectionResolver -from graphene.contrib.django.tests.models import Reporter, Article +from graphene.contrib.django.tests.models import Article, Reporter class ReporterNode(DjangoNode): + class Meta: model = Reporter class ArticleNode(DjangoNode): + class Meta: model = Article @@ -34,7 +36,7 @@ def test_simple_get_manager_all(): reporter = Reporter(id=1, first_name='Cookie Monster') resolver = SimpleQuerySetConnectionResolver(ReporterNode) resolver(inst=reporter, args={}, info=None) - assert type(resolver.get_manager()) == Manager, 'Resolver did not return a Manager' + assert isinstance(resolver.get_manager(), Manager), 'Resolver did not return a Manager' def test_simple_filter(): @@ -51,7 +53,7 @@ def test_simple_order(): reporter = Reporter(id=1, first_name='Cookie Monster') resolver = SimpleQuerySetConnectionResolver(ReporterNode) resolved = resolver(inst=reporter, args={ - 'order': 'last_name' + 'order_by': 'last_name' }, info=None) assert 'WHERE' not in str(resolved.query) assert 'ORDER BY' in str(resolved.query) diff --git a/graphene/contrib/django/tests/test_urls.py b/graphene/contrib/django/tests/test_urls.py index 409471d4..9d38980e 100644 --- a/graphene/contrib/django/tests/test_urls.py +++ b/graphene/contrib/django/tests/test_urls.py @@ -29,7 +29,15 @@ class Human(DjangoNode): def get_node(self, id): pass -schema = Schema(query=Human) + +class Query(graphene.ObjectType): + human = graphene.Field(Human) + + def resolve_human(self, args, info): + return Human() + + +schema = Schema(query=Query) urlpatterns = [ diff --git a/graphene/contrib/django/tests/test_views.py b/graphene/contrib/django/tests/test_views.py index f82be99f..b4e1b367 100644 --- a/graphene/contrib/django/tests/test_views.py +++ b/graphene/contrib/django/tests/test_views.py @@ -7,11 +7,13 @@ def format_response(response): def test_client_get_good_query(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' - response = client.get('/graphql', {'query': '{ headline }'}) + response = client.get('/graphql', {'query': '{ human { headline } }'}) json_response = format_response(response) expected_json = { 'data': { - 'headline': None + 'human': { + 'headline': None + } } } assert json_response == expected_json @@ -19,20 +21,22 @@ def test_client_get_good_query(settings, client): def test_client_get_good_query_with_raise(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' - response = client.get('/graphql', {'query': '{ raises }'}) + response = client.get('/graphql', {'query': '{ human { raises } }'}) json_response = format_response(response) assert json_response['errors'][0]['message'] == 'This field should raise exception' - assert json_response['data']['raises'] is None + assert json_response['data']['human']['raises'] is None def test_client_post_good_query_json(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' response = client.post( - '/graphql', json.dumps({'query': '{ headline }'}), 'application/json') + '/graphql', json.dumps({'query': '{ human { headline } }'}), 'application/json') json_response = format_response(response) expected_json = { 'data': { - 'headline': None + 'human': { + 'headline': None + } } } assert json_response == expected_json @@ -41,11 +45,13 @@ def test_client_post_good_query_json(settings, client): def test_client_post_good_query_graphql(settings, client): settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' response = client.post( - '/graphql', '{ headline }', 'application/graphql') + '/graphql', '{ human { headline } }', 'application/graphql') json_response = format_response(response) expected_json = { 'data': { - 'headline': None + 'human': { + 'headline': None + } } } assert json_response == expected_json diff --git a/graphene/contrib/django/utils.py b/graphene/contrib/django/utils.py index 2b4519fc..76f4477c 100644 --- a/graphene/contrib/django/utils.py +++ b/graphene/contrib/django/utils.py @@ -1,9 +1,18 @@ import six from django.db import models from django.db.models.manager import Manager +from django.db.models.query import QuerySet from graphene import Argument, String -from graphene.contrib.django.form_converter import convert_form_field +from graphene.utils import LazyList + +from .compat import RelatedObject + +try: + import django_filters # noqa + DJANGO_FILTER_INSTALLED = True +except ImportError: + DJANGO_FILTER_INSTALLED = False def get_type_for_model(schema, model): @@ -18,14 +27,32 @@ def get_type_for_model(schema, model): def get_reverse_fields(model): for name, attr in model.__dict__.items(): - related = getattr(attr, 'related', None) - if isinstance(related, models.ManyToOneRel): + # Django =>1.9 uses 'rel', django <1.9 uses 'related' + related = getattr(attr, 'rel', None) or \ + getattr(attr, 'related', None) + if isinstance(related, RelatedObject): + # Hack for making it compatible with Django 1.6 + new_related = RelatedObject(related.parent_model, related.model, related.field) + new_related.name = name + yield new_related + elif isinstance(related, models.ManyToOneRel): yield related +class WrappedQueryset(LazyList): + + def __len__(self): + # Dont calculate the length using len(queryset), as this will + # evaluate the whole queryset and return it's length. + # Use .count() instead + return self._origin.count() + + def maybe_queryset(value): if isinstance(value, Manager): value = value.get_queryset() + if isinstance(value, QuerySet): + return WrappedQueryset(value) return value @@ -34,6 +61,8 @@ def get_filtering_args_from_filterset(filterset_class, type): a Graphene Field. These arguments will be available to filter against in the GraphQL """ + from graphene.contrib.django.form_converter import convert_form_field + args = {} for name, filter_field in six.iteritems(filterset_class.base_filters): field_type = Argument(convert_form_field(filter_field.field)) @@ -42,5 +71,36 @@ def get_filtering_args_from_filterset(filterset_class, type): args[name] = field_type # Also add the 'order_by' field - args[filterset_class.order_by_field] = Argument(String) + if filterset_class._meta.order_by: + args[filterset_class.order_by_field] = Argument(String()) return args + + +def get_related_model(field): + if hasattr(field, 'rel'): + # Django 1.6, 1.7 + return field.rel.to + return field.related_model + + +def import_single_dispatch(): + try: + from functools import singledispatch + except ImportError: + singledispatch = None + + if not singledispatch: + try: + from singledispatch import singledispatch + except ImportError: + pass + + if not singledispatch: + raise Exception( + "It seems your python version does not include " + "functools.singledispatch. Please install the 'singledispatch' " + "package. More information here: " + "https://pypi.python.org/pypi/singledispatch" + ) + + return singledispatch diff --git a/graphene/contrib/django/views.py b/graphene/contrib/django/views.py index ad245d72..9a4bd96e 100644 --- a/graphene/contrib/django/views.py +++ b/graphene/contrib/django/views.py @@ -12,5 +12,5 @@ class GraphQLView(BaseGraphQLView): **kwargs ) - def get_root_value(self, request): - return self.graphene_schema.query(super(GraphQLView, self).get_root_value(request)) + def execute(self, *args, **kwargs): + return self.graphene_schema.execute(*args, **kwargs) diff --git a/graphene/core/__init__.py b/graphene/core/__init__.py index e69de29b..9e8c7108 100644 --- a/graphene/core/__init__.py +++ b/graphene/core/__init__.py @@ -0,0 +1,46 @@ +from .schema import ( + Schema +) + +from .classtypes import ( + ObjectType, + InputObjectType, + Interface, + Mutation, + Scalar +) + +from .types import ( + InstanceType, + LazyType, + Argument, + Field, + InputField, + String, + Int, + Boolean, + ID, + Float, + List, + NonNull +) + +__all__ = [ + 'Argument', + 'String', + 'Int', + 'Boolean', + 'Float', + 'ID', + 'List', + 'NonNull', + 'Schema', + 'InstanceType', + 'LazyType', + 'ObjectType', + 'InputObjectType', + 'Interface', + 'Mutation', + 'Scalar', + 'Field', + 'InputField'] diff --git a/graphene/core/classtypes/base.py b/graphene/core/classtypes/base.py index 31a2c8d2..cabb909a 100644 --- a/graphene/core/classtypes/base.py +++ b/graphene/core/classtypes/base.py @@ -1,10 +1,10 @@ import copy import inspect from collections import OrderedDict +from functools import partial import six -from ..exceptions import SkipField from .options import Options @@ -48,8 +48,8 @@ class ClassTypeMeta(type): if not cls._meta.abstract: from ..types import List, NonNull - setattr(cls, 'NonNull', NonNull(cls)) - setattr(cls, 'List', List(cls)) + setattr(cls, 'NonNull', partial(NonNull, cls)) + setattr(cls, 'List', partial(List, cls)) return cls @@ -81,13 +81,18 @@ class FieldsOptions(Options): def fields_map(self): return OrderedDict([(f.attname, f) for f in self.fields]) + @property + def fields_group_type(self): + from ..types.field import FieldsGroupType + return FieldsGroupType(*self.local_fields) + class FieldsClassTypeMeta(ClassTypeMeta): options_class = FieldsOptions def extend_fields(cls, bases): new_fields = cls._meta.local_fields - field_names = {f.name: f for f in new_fields} + field_names = {f.attname: f for f in new_fields} for base in bases: if not isinstance(base, FieldsClassTypeMeta): @@ -95,17 +100,17 @@ class FieldsClassTypeMeta(ClassTypeMeta): parent_fields = base._meta.local_fields for field in parent_fields: - if field.name in field_names and field.type.__class__ != field_names[ - field.name].type.__class__: + if field.attname in field_names and field.type.__class__ != field_names[ + field.attname].type.__class__: raise Exception( 'Local field %r in class %r (%r) clashes ' 'with field with similar name from ' 'Interface %s (%r)' % ( - field.name, + field.attname, cls.__name__, field.__class__, base.__name__, - field_names[field.name].__class__) + field_names[field.attname].__class__) ) new_field = copy.copy(field) cls.add_to_class(field.attname, new_field) @@ -123,11 +128,4 @@ class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)): @classmethod def fields_internal_types(cls, schema): - fields = [] - for field in cls._meta.fields: - try: - fields.append((field.name, schema.T(field))) - except SkipField: - continue - - return OrderedDict(fields) + return schema.T(cls._meta.fields_group_type) diff --git a/graphene/core/classtypes/tests/test_base.py b/graphene/core/classtypes/tests/test_base.py index 5017f94d..4666fdc2 100644 --- a/graphene/core/classtypes/tests/test_base.py +++ b/graphene/core/classtypes/tests/test_base.py @@ -23,15 +23,26 @@ def test_classtype_advanced(): def test_classtype_definition_list(): class Character(ClassType): '''Character description''' - assert isinstance(Character.List, List) - assert Character.List.of_type == Character + assert isinstance(Character.List(), List) + assert Character.List().of_type == Character def test_classtype_definition_nonnull(): class Character(ClassType): '''Character description''' - assert isinstance(Character.NonNull, NonNull) - assert Character.NonNull.of_type == Character + assert isinstance(Character.NonNull(), NonNull) + assert Character.NonNull().of_type == Character + + +def test_fieldsclasstype_definition_order(): + class Character(ClassType): + '''Character description''' + + class Query(FieldsClassType): + name = String() + char = Character.NonNull() + + assert list(Query._meta.fields_map.keys()) == ['name', 'char'] def test_fieldsclasstype(): diff --git a/graphene/core/classtypes/tests/test_mutation.py b/graphene/core/classtypes/tests/test_mutation.py index ac32585e..85dd2368 100644 --- a/graphene/core/classtypes/tests/test_mutation.py +++ b/graphene/core/classtypes/tests/test_mutation.py @@ -24,4 +24,4 @@ def test_mutation(): assert list(object_type.get_fields().keys()) == ['name'] assert MyMutation._meta.fields_map['name'].object_type == MyMutation assert isinstance(MyMutation.arguments, ArgumentsGroup) - assert 'argName' in MyMutation.arguments + assert 'argName' in schema.T(MyMutation.arguments) diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 1b0ce8f9..c8695317 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -10,8 +10,9 @@ from graphql.core.utils.schema_printer import print_schema from graphene import signals +from ..plugins import CamelCase, PluginManager from .classtypes.base import ClassType -from .types.base import BaseType +from .types.base import InstanceType class GraphQLSchema(_GraphQLSchema): @@ -25,7 +26,7 @@ class Schema(object): _executor = None def __init__(self, query=None, mutation=None, subscription=None, - name='Schema', executor=None): + name='Schema', executor=None, plugins=None, auto_camelcase=True): self._types_names = {} self._types = {} self.mutation = mutation @@ -33,27 +34,34 @@ class Schema(object): self.subscription = subscription self.name = name self.executor = executor + plugins = plugins or [] + if auto_camelcase: + plugins.append(CamelCase()) + self.plugins = PluginManager(self, plugins) signals.init_schema.send(self) def __repr__(self): return '' % (str(self.name), hash(self)) - def T(self, object_type): - if not object_type: + def __getattr__(self, name): + if name in self.plugins: + return getattr(self.plugins, name) + return super(Schema, self).__getattr__(name) + + def T(self, _type): + if not _type: return - if inspect.isclass(object_type) and issubclass( - object_type, (BaseType, ClassType)) or isinstance( - object_type, BaseType): - if object_type not in self._types: - internal_type = object_type.internal_type(self) - self._types[object_type] = internal_type - is_objecttype = inspect.isclass( - object_type) and issubclass(object_type, ClassType) - if is_objecttype: - self.register(object_type) - return self._types[object_type] + is_classtype = inspect.isclass(_type) and issubclass(_type, ClassType) + is_instancetype = isinstance(_type, InstanceType) + if is_classtype or is_instancetype: + if _type not in self._types: + internal_type = _type.internal_type(self) + self._types[_type] = internal_type + if is_classtype: + self.register(_type) + return self._types[_type] else: - return object_type + return _type @property def executor(self): @@ -76,9 +84,9 @@ class Schema(object): mutation=self.T(self.mutation), subscription=self.T(self.subscription)) - def register(self, object_type): + def register(self, object_type, force=False): type_name = object_type._meta.type_name - registered_object_type = self._types_names.get(type_name, None) + registered_object_type = not force and self._types_names.get(type_name, None) if registered_object_type: assert registered_object_type == object_type, 'Type {} already registered with other object type'.format( type_name) @@ -110,17 +118,10 @@ class Schema(object): def types(self): return self._types_names - def execute(self, request='', root=None, vars=None, - operation_name=None, **kwargs): - root = root or object() - return self.executor.execute( - self.schema, - request, - root=root, - args=vars, - operation_name=operation_name, - **kwargs - ) + def execute(self, request='', root=None, args=None, **kwargs): + kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema) + with self.plugins.context_execution(**kwargs) as execute_kwargs: + return self.executor.execute(**execute_kwargs) def introspect(self): return self.execute(introspection_query).data diff --git a/graphene/core/tests/test_old_fields.py b/graphene/core/tests/test_old_fields.py index 95bf9aab..3f24aedf 100644 --- a/graphene/core/tests/test_old_fields.py +++ b/graphene/core/tests/test_old_fields.py @@ -34,10 +34,11 @@ def test_field_type(): assert schema.T(f).type == GraphQLString -def test_field_name_automatic_camelcase(): +def test_field_name(): f = Field(GraphQLString) f.contribute_to_class(MyOt, 'field_name') - assert f.name == 'fieldName' + assert f.name is None + assert f.attname == 'field_name' def test_field_name_use_name_if_exists(): diff --git a/graphene/core/types/__init__.py b/graphene/core/types/__init__.py index 9260476c..51512ec4 100644 --- a/graphene/core/types/__init__.py +++ b/graphene/core/types/__init__.py @@ -1,14 +1,14 @@ -from .base import BaseType, LazyType, OrderedType +from .base import InstanceType, LazyType, OrderedType from .argument import Argument, ArgumentsGroup, to_arguments from .definitions import List, NonNull # Compatibility import from .objecttype import Interface, ObjectType, Mutation, InputObjectType -from .scalars import String, ID, Boolean, Int, Float, Scalar +from .scalars import String, ID, Boolean, Int, Float from .field import Field, InputField __all__ = [ - 'BaseType', + 'InstanceType', 'LazyType', 'OrderedType', 'Argument', @@ -26,5 +26,4 @@ __all__ = [ 'ID', 'Boolean', 'Int', - 'Float', - 'Scalar'] + 'Float'] diff --git a/graphene/core/types/argument.py b/graphene/core/types/argument.py index 0892c446..b10aff21 100644 --- a/graphene/core/types/argument.py +++ b/graphene/core/types/argument.py @@ -1,19 +1,17 @@ -from collections import OrderedDict from functools import wraps from itertools import chain from graphql.core.type import GraphQLArgument -from ...utils import ProxySnakeDict, to_camel_case -from .base import ArgumentType, BaseType, OrderedType +from ...utils import ProxySnakeDict +from .base import ArgumentType, GroupNamedType, NamedType, OrderedType -class Argument(OrderedType): +class Argument(NamedType, OrderedType): def __init__(self, type, description=None, default=None, name=None, _creation_counter=None): - super(Argument, self).__init__(_creation_counter=_creation_counter) - self.name = name + super(Argument, self).__init__(name=name, _creation_counter=_creation_counter) self.type = type self.description = description self.default = default @@ -27,47 +25,32 @@ class Argument(OrderedType): return self.name -class ArgumentsGroup(BaseType): +class ArgumentsGroup(GroupNamedType): def __init__(self, *args, **kwargs): arguments = to_arguments(*args, **kwargs) - self.arguments = OrderedDict([(arg.name, arg) for arg in arguments]) - - def internal_type(self, schema): - return OrderedDict([(arg.name, schema.T(arg)) - for arg in self.arguments.values()]) - - def __len__(self): - return len(self.arguments) - - def __iter__(self): - return iter(self.arguments) - - def __contains__(self, *args): - return self.arguments.__contains__(*args) - - def __getitem__(self, *args): - return self.arguments.__getitem__(*args) + super(ArgumentsGroup, self).__init__(*arguments) def to_arguments(*args, **kwargs): arguments = {} iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) - for name, arg in iter_arguments: + for default_name, arg in iter_arguments: if isinstance(arg, Argument): argument = arg elif isinstance(arg, ArgumentType): argument = arg.as_argument() else: - raise ValueError('Unknown argument %s=%r' % (name, arg)) + raise ValueError('Unknown argument %s=%r' % (default_name, arg)) - if name: - argument.name = to_camel_case(name) - assert argument.name, 'Argument in field must have a name' - assert argument.name not in arguments, 'Found more than one Argument with same name {}'.format( - argument.name) - arguments[argument.name] = argument + if default_name: + argument.default_name = default_name + + name = argument.name or argument.default_name + assert name, 'Argument in field must have a name' + assert name not in arguments, 'Found more than one Argument with same name {}'.format(name) + arguments[name] = argument return sorted(arguments.values()) diff --git a/graphene/core/types/base.py b/graphene/core/types/base.py index 2b4078e4..ec8c7b3b 100644 --- a/graphene/core/types/base.py +++ b/graphene/core/types/base.py @@ -1,16 +1,16 @@ -from functools import total_ordering +from collections import OrderedDict +from functools import partial, total_ordering import six -class BaseType(object): +class InstanceType(object): - @classmethod - def internal_type(cls, schema): - return getattr(cls, 'T', None) + def internal_type(self, schema): + raise NotImplementedError("internal_type for type {} is not implemented".format(self.__class__.__name__)) -class MountType(BaseType): +class MountType(InstanceType): parent = None def mount(self, cls): @@ -126,3 +126,39 @@ class FieldType(MirroredType): class MountedType(FieldType, ArgumentType): pass + + +class NamedType(InstanceType): + + def __init__(self, name=None, default_name=None, *args, **kwargs): + self.name = name + self.default_name = None + super(NamedType, self).__init__(*args, **kwargs) + + +class GroupNamedType(InstanceType): + + def __init__(self, *types): + self.types = types + + def get_named_type(self, schema, type): + name = type.name or schema.get_default_namedtype_name(type.default_name) + return name, schema.T(type) + + def iter_types(self, schema): + return map(partial(self.get_named_type, schema), self.types) + + def internal_type(self, schema): + return OrderedDict(self.iter_types(schema)) + + def __len__(self): + return len(self.types) + + def __iter__(self): + return iter(self.types) + + def __contains__(self, *args): + return self.types.__contains__(*args) + + def __getitem__(self, *args): + return self.types.__getitem__(*args) diff --git a/graphene/core/types/field.py b/graphene/core/types/field.py index d7643b5e..6cbfff96 100644 --- a/graphene/core/types/field.py +++ b/graphene/core/types/field.py @@ -4,23 +4,26 @@ from functools import wraps import six from graphql.core.type import GraphQLField, GraphQLInputObjectField -from ...utils import to_camel_case from ..classtypes.base import FieldsClassType from ..classtypes.inputobjecttype import InputObjectType from ..classtypes.mutation import Mutation -from .argument import ArgumentsGroup, snake_case_args -from .base import LazyType, MountType, OrderedType +from ..exceptions import SkipField +from .argument import Argument, ArgumentsGroup, snake_case_args +from .base import (ArgumentType, GroupNamedType, LazyType, MountType, + NamedType, OrderedType) from .definitions import NonNull -class Field(OrderedType): +class Field(NamedType, OrderedType): def __init__( self, type, description=None, args=None, name=None, resolver=None, required=False, default=None, *args_list, **kwargs): _creation_counter = kwargs.pop('_creation_counter', None) - super(Field, self).__init__(_creation_counter=_creation_counter) - self.name = name + if isinstance(name, (Argument, ArgumentType)): + kwargs['name'] = name + name = None + super(Field, self).__init__(name=name, _creation_counter=_creation_counter) if isinstance(type, six.string_types): type = LazyType(type) self.required = required @@ -36,9 +39,8 @@ class Field(OrderedType): assert issubclass( cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( self, cls) - if not self.name: - self.name = to_camel_case(attname) self.attname = attname + self.default_name = attname self.object_type = cls self.mount(cls) if isinstance(self.type, MountType): @@ -63,6 +65,9 @@ class Field(OrderedType): return NonNull(self.type) return self.type + def decorate_resolver(self, resolver): + return snake_case_args(resolver) + def internal_type(self, schema): resolver = self.resolver description = self.description @@ -85,9 +90,9 @@ class Field(OrderedType): return my_resolver(instance, args, info) resolver = wrapped_func - resolver = snake_case_args(resolver) assert type, 'Internal type for field %s is None' % str(self) - return GraphQLField(type, args=schema.T(arguments), resolver=resolver, + return GraphQLField(type, args=schema.T(arguments), + resolver=self.decorate_resolver(resolver), description=description,) def __repr__(self): @@ -114,12 +119,11 @@ class Field(OrderedType): return hash((self.creation_counter, self.object_type)) -class InputField(OrderedType): +class InputField(NamedType, OrderedType): def __init__(self, type, description=None, default=None, name=None, _creation_counter=None, required=False): super(InputField, self).__init__(_creation_counter=_creation_counter) - self.name = name if required: type = NonNull(type) self.type = type @@ -130,9 +134,8 @@ class InputField(OrderedType): assert issubclass( cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( self, cls) - if not self.name: - self.name = to_camel_case(attname) self.attname = attname + self.default_name = attname self.object_type = cls self.mount(cls) if isinstance(self.type, MountType): @@ -143,3 +146,13 @@ class InputField(OrderedType): return GraphQLInputObjectField( schema.T(self.type), default_value=self.default, description=self.description) + + +class FieldsGroupType(GroupNamedType): + + def iter_types(self, schema): + for field in sorted(self.types): + try: + yield self.get_named_type(schema, field) + except SkipField: + continue diff --git a/graphene/core/types/scalars.py b/graphene/core/types/scalars.py index 75cd70a3..9d7f5aeb 100644 --- a/graphene/core/types/scalars.py +++ b/graphene/core/types/scalars.py @@ -1,41 +1,30 @@ from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, - GraphQLInt, GraphQLScalarType, GraphQLString) + GraphQLInt, GraphQLString) from .base import MountedType -class String(MountedType): - T = GraphQLString +class ScalarType(MountedType): + + def internal_type(self, schema): + return self._internal_type -class Int(MountedType): - T = GraphQLInt +class String(ScalarType): + _internal_type = GraphQLString -class Boolean(MountedType): - T = GraphQLBoolean +class Int(ScalarType): + _internal_type = GraphQLInt -class ID(MountedType): - T = GraphQLID +class Boolean(ScalarType): + _internal_type = GraphQLBoolean -class Float(MountedType): - T = GraphQLFloat +class ID(ScalarType): + _internal_type = GraphQLID -class Scalar(MountedType): - - @classmethod - def internal_type(cls, schema): - serialize = getattr(cls, 'serialize') - parse_literal = getattr(cls, 'parse_literal') - parse_value = getattr(cls, 'parse_value') - - return GraphQLScalarType( - name=cls.__name__, - description=cls.__doc__, - serialize=serialize, - parse_value=parse_value, - parse_literal=parse_literal - ) +class Float(ScalarType): + _internal_type = GraphQLFloat diff --git a/graphene/core/types/tests/test_argument.py b/graphene/core/types/tests/test_argument.py index 26bbb310..b2f5e239 100644 --- a/graphene/core/types/tests/test_argument.py +++ b/graphene/core/types/tests/test_argument.py @@ -27,8 +27,8 @@ def test_to_arguments(): other_kwarg=String(), ) - assert [a.name for a in arguments] == [ - 'myArg', 'otherArg', 'myKwarg', 'otherKwarg'] + assert [a.name or a.default_name for a in arguments] == [ + 'myArg', 'otherArg', 'my_kwarg', 'other_kwarg'] def test_to_arguments_no_name(): diff --git a/graphene/core/types/tests/test_field.py b/graphene/core/types/tests/test_field.py index 8253ed20..706cbc59 100644 --- a/graphene/core/types/tests/test_field.py +++ b/graphene/core/types/tests/test_field.py @@ -13,14 +13,14 @@ from ..scalars import String def test_field_internal_type(): resolver = lambda *args: 'RESOLVED' - field = Field(String, description='My argument', resolver=resolver) + field = Field(String(), description='My argument', resolver=resolver) class Query(ObjectType): my_field = field schema = Schema(query=Query) type = schema.T(field) - assert field.name == 'myField' + assert field.name is None assert field.attname == 'my_field' assert isinstance(type, GraphQLField) assert type.description == 'My argument' @@ -98,9 +98,18 @@ def test_field_string_reference(): def test_field_custom_arguments(): field = Field(None, name='my_customName', p=String()) + schema = Schema() args = field.arguments - assert 'p' in args + assert 'p' in schema.T(args) + + +def test_field_name_as_argument(): + field = Field(None, name=String()) + schema = Schema() + + args = field.arguments + assert 'name' in schema.T(args) def test_inputfield_internal_type(): @@ -115,8 +124,43 @@ def test_inputfield_internal_type(): schema = Schema(query=MyObjectType) type = schema.T(field) - assert field.name == 'myField' + assert field.name is None assert field.attname == 'my_field' assert isinstance(type, GraphQLInputObjectField) assert type.description == 'My input field' assert type.default_value == '3' + + +def test_field_resolve_argument(): + resolver = lambda instance, args, info: args.get('first_name') + + field = Field(String(), first_name=String(), description='My argument', resolver=resolver) + + class Query(ObjectType): + my_field = field + schema = Schema(query=Query) + + type = schema.T(field) + assert type.resolver(None, {'firstName': 'Peter'}, None) == 'Peter' + + +def test_field_resolve_vars(): + class Query(ObjectType): + hello = String(first_name=String()) + + def resolve_hello(self, args, info): + return 'Hello ' + args.get('first_name') + + schema = Schema(query=Query) + + result = schema.execute(""" + query foo($firstName:String) + { + hello(firstName:$firstName) + } + """, args={"firstName": "Serkan"}) + + expected = { + 'hello': 'Hello Serkan' + } + assert result.data == expected diff --git a/graphene/core/types/tests/test_scalars.py b/graphene/core/types/tests/test_scalars.py index 8b8d930f..39fd6063 100644 --- a/graphene/core/types/tests/test_scalars.py +++ b/graphene/core/types/tests/test_scalars.py @@ -1,9 +1,9 @@ from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, - GraphQLInt, GraphQLScalarType, GraphQLString) + GraphQLInt, GraphQLString) from graphene.core.schema import Schema -from ..scalars import ID, Boolean, Float, Int, Scalar, String +from ..scalars import ID, Boolean, Float, Int, String schema = Schema() @@ -26,29 +26,3 @@ def test_id_scalar(): def test_float_scalar(): assert schema.T(Float()) == GraphQLFloat - - -def test_custom_scalar(): - import datetime - from graphql.core.language import ast - - class DateTimeScalar(Scalar): - '''DateTimeScalar Documentation''' - @staticmethod - def serialize(dt): - return dt.isoformat() - - @staticmethod - def parse_literal(node): - if isinstance(node, ast.StringValue): - return datetime.datetime.strptime( - node.value, "%Y-%m-%dT%H:%M:%S.%f") - - @staticmethod - def parse_value(value): - return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f") - - scalar_type = schema.T(DateTimeScalar) - assert isinstance(scalar_type, GraphQLScalarType) - assert scalar_type.name == 'DateTimeScalar' - assert scalar_type.description == 'DateTimeScalar Documentation' diff --git a/graphene/plugins/__init__.py b/graphene/plugins/__init__.py new file mode 100644 index 00000000..160bffba --- /dev/null +++ b/graphene/plugins/__init__.py @@ -0,0 +1,6 @@ +from .base import Plugin, PluginManager +from .camel_case import CamelCase + +__all__ = [ + 'Plugin', 'PluginManager', 'CamelCase' +] diff --git a/graphene/plugins/base.py b/graphene/plugins/base.py new file mode 100644 index 00000000..2347beba --- /dev/null +++ b/graphene/plugins/base.py @@ -0,0 +1,53 @@ +from contextlib import contextmanager +from functools import partial, reduce + + +class Plugin(object): + + def contribute_to_schema(self, schema): + self.schema = schema + + +def apply_function(a, b): + return b(a) + + +class PluginManager(object): + + PLUGIN_FUNCTIONS = ('get_default_namedtype_name', ) + + def __init__(self, schema, plugins=[]): + self.schema = schema + self.plugins = [] + for plugin in plugins: + self.add_plugin(plugin) + + def add_plugin(self, plugin): + if hasattr(plugin, 'contribute_to_schema'): + plugin.contribute_to_schema(self.schema) + self.plugins.append(plugin) + + def get_plugin_functions(self, function): + for plugin in self.plugins: + if not hasattr(plugin, function): + continue + yield getattr(plugin, function) + + def __getattr__(self, name): + functions = self.get_plugin_functions(name) + return partial(reduce, apply_function, functions) + + def __contains__(self, name): + return name in self.PLUGIN_FUNCTIONS + + @contextmanager + def context_execution(self, **executor): + contexts = [] + functions = self.get_plugin_functions('context_execution') + for f in functions: + context = f(executor) + executor = context.__enter__() + contexts.append((context, executor)) + yield executor + for context, value in contexts[::-1]: + context.__exit__(None, None, None) diff --git a/graphene/plugins/camel_case.py b/graphene/plugins/camel_case.py new file mode 100644 index 00000000..d9a9084f --- /dev/null +++ b/graphene/plugins/camel_case.py @@ -0,0 +1,7 @@ +from ..utils import to_camel_case + + +class CamelCase(object): + + def get_default_namedtype_name(self, value): + return to_camel_case(value) diff --git a/graphene/relay/tests/test_mutations.py b/graphene/relay/tests/test_mutations.py index 4356a1ec..02287725 100644 --- a/graphene/relay/tests/test_mutations.py +++ b/graphene/relay/tests/test_mutations.py @@ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation) def test_mutation_arguments(): assert ChangeNumber.arguments - assert list(ChangeNumber.arguments) == ['input'] - assert 'input' in ChangeNumber.arguments + assert 'input' in schema.T(ChangeNumber.arguments) inner_type = ChangeNumber.input_type client_mutation_id_field = inner_type._meta.fields_map[ 'client_mutation_id'] diff --git a/graphene/relay/types.py b/graphene/relay/types.py index 672042e7..425e3038 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -4,6 +4,7 @@ from collections import Iterable from functools import wraps import six + from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.node.node import to_global_id diff --git a/graphene/utils/__init__.py b/graphene/utils/__init__.py index 52fb6417..6adb1f73 100644 --- a/graphene/utils/__init__.py +++ b/graphene/utils/__init__.py @@ -3,8 +3,9 @@ from .proxy_snake_dict import ProxySnakeDict from .caching import cached_property, memoize from .misc import enum_to_graphql_enum from .resolve_only_args import resolve_only_args +from .lazylist import LazyList __all__ = ['to_camel_case', 'to_snake_case', 'ProxySnakeDict', 'cached_property', 'memoize', 'enum_to_graphql_enum', - 'resolve_only_args'] + 'resolve_only_args', 'LazyList'] diff --git a/graphene/utils/lazylist.py b/graphene/utils/lazylist.py new file mode 100644 index 00000000..434dcfd4 --- /dev/null +++ b/graphene/utils/lazylist.py @@ -0,0 +1,43 @@ +class LazyList(object): + + def __init__(self, origin, state=None): + self._origin = origin + self._state = state or [] + self._origin_iter = None + self._finished = False + + def __iter__(self): + return self if not self._finished else iter(self._state) + + def iter(self): + return self.__iter__() + + def __len__(self): + return self._origin.__len__() + + def __next__(self): + try: + if not self._origin_iter: + self._origin_iter = self._origin.__iter__() + n = next(self._origin_iter) + except StopIteration as e: + self._finished = True + raise e + else: + self._state.append(n) + return n + + def next(self): + return self.__next__() + + def __getitem__(self, key): + item = self._origin[key] + if isinstance(key, slice): + return self.__class__(item) + return item + + def __getattr__(self, name): + return getattr(self._origin, name) + + def __repr__(self): + return "<{} {}>".format(self.__class__.__name__, repr(self._origin)) diff --git a/graphene/utils/tests/test_lazylist.py b/graphene/utils/tests/test_lazylist.py new file mode 100644 index 00000000..972e2942 --- /dev/null +++ b/graphene/utils/tests/test_lazylist.py @@ -0,0 +1,23 @@ +from py.test import raises + +from ..lazylist import LazyList + + +def test_lazymap(): + data = list(range(10)) + lm = LazyList(data) + assert len(lm) == 10 + assert lm[1] == 1 + assert isinstance(lm[1:4], LazyList) + assert lm.append == data.append + assert repr(lm) == '' + + +def test_lazymap_iter(): + data = list(range(2)) + lm = LazyList(data) + iter_lm = iter(lm) + assert iter_lm.next() == 0 + assert iter_lm.next() == 1 + with raises(StopIteration): + iter_lm.next() diff --git a/setup.py b/setup.py index 8b2c324c..ac0e3d0f 100644 --- a/setup.py +++ b/setup.py @@ -24,9 +24,9 @@ class PyTest(TestCommand): setup( name='graphene', - version='0.4.2', + version='0.5.0', - description='Graphene: Python DSL for GraphQL', + description='GraphQL Framework for Python', long_description=open('README.rst').read(), url='https://github.com/graphql-python/graphene', @@ -66,9 +66,9 @@ setup( ], extras_require={ 'django': [ - 'Django>=1.6.0,<1.9', + 'Django>=1.6.0', 'singledispatch>=3.4.0.3', - 'graphql-django-view>=1.0.0', + 'graphql-django-view>=1.1.0', ], }, diff --git a/tests/django_settings.py b/tests/django_settings.py index 2af62bb3..998f68ab 100644 --- a/tests/django_settings.py +++ b/tests/django_settings.py @@ -1,6 +1,7 @@ SECRET_KEY = 1 INSTALLED_APPS = [ + 'graphene.contrib.django.tests', 'examples.starwars_django', ]