Merge branch 'feature/django' into feature/django-docs

This commit is contained in:
Adam Charnock 2015-12-29 13:31:31 +00:00
commit 90f58e786a
64 changed files with 1102 additions and 285 deletions

14
.editorconfig Normal file
View File

@ -0,0 +1,14 @@
# http://editorconfig.org
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[*.{py,rst,ini}]
indent_style = space
indent_size = 4

View File

@ -2,7 +2,6 @@ language: python
sudo: false sudo: false
python: python:
- 2.7 - 2.7
- 3.3
- 3.4 - 3.4
- 3.5 - 3.5
- pypy - pypy
@ -14,8 +13,9 @@ cache:
- $HOME/docs/node_modules - $HOME/docs/node_modules
before_install: before_install:
- | - |
if [ "$TEST_TYPE" != build_website ] && \ git_diff=$(git diff --name-only $TRAVIS_COMMIT_RANGE)
! git diff --name-only $TRAVIS_COMMIT_RANGE | grep -qvE '(\.md$)|(^(docs))/' if [ "$?" == 0 ] && [ "$TEST_TYPE" != build_website ] && \
! echo "$git_diff" | grep -qvE '(\.md$)|(^(docs))/'
then then
echo "Only docs were updated, stopping build process." echo "Only docs were updated, stopping build process."
exit exit
@ -25,6 +25,7 @@ install:
if [ "$TEST_TYPE" = build ]; then if [ "$TEST_TYPE" = build ]; then
pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter
pip install --download-cache $HOME/.cache/pip/ -e .[django] pip install --download-cache $HOME/.cache/pip/ -e .[django]
pip install django==$DJANGO_VERSION
python setup.py develop python setup.py develop
elif [ "$TEST_TYPE" = build_website ]; then elif [ "$TEST_TYPE" = build_website ]; then
pip install --download-cache $HOME/.cache/pip/ -e . pip install --download-cache $HOME/.cache/pip/ -e .
@ -78,6 +79,14 @@ env:
matrix: matrix:
fast_finish: true fast_finish: true
include: include:
- python: '2.7'
env: DJANGO_VERSION=1.6
- python: '2.7'
env: DJANGO_VERSION=1.7
- python: '2.7'
env: DJANGO_VERSION=1.8
- python: '2.7'
env: DJANGO_VERSION=1.9
- python: '2.7' - python: '2.7'
env: TEST_TYPE=build_website env: TEST_TYPE=build_website
- python: '2.7' - python: '2.7'

View File

@ -82,3 +82,21 @@ graphene.Field(graphene.String(), to=graphene.String())
# Is equivalent to: # Is equivalent to:
graphene.Field(graphene.String(), to=graphene.Argument(graphene.String())) graphene.Field(graphene.String(), to=graphene.Argument(graphene.String()))
``` ```
## Using custom object types as argument
To use a custom object type as an argument, you need to inherit `graphene.InputObjectType`, not `graphene.ObjectType`.
```python
class CustomArgumentObjectType(graphene.InputObjectType):
field1 = graphene.String()
field2 = graphene.String()
```
Then, when defining this in an argument, you need to wrap it in an `Argument` object.
```python
graphene.Field(graphene.String(), to=graphene.Argument(CustomArgumentObjectType))
```

View File

@ -63,15 +63,15 @@ class Query(graphene.ObjectType):
@resolve_only_args @resolve_only_args
def resolve_ships(self): def resolve_ships(self):
return [Ship(s) for s in get_ships()] return get_ships()
@resolve_only_args @resolve_only_args
def resolve_rebels(self): def resolve_rebels(self):
return Faction(get_rebels()) return get_rebels()
@resolve_only_args @resolve_only_args
def resolve_empire(self): def resolve_empire(self):
return Faction(get_empire()) return get_empire()
class Mutation(graphene.ObjectType): class Mutation(graphene.ObjectType):

View File

@ -4,20 +4,14 @@ from graphql.core.type import (
from graphene import signals from graphene import signals
from graphene.core.schema import ( from .core import (
Schema Schema,
)
from graphene.core.classtypes import (
ObjectType, ObjectType,
InputObjectType, InputObjectType,
Interface, Interface,
Mutation, Mutation,
Scalar Scalar,
) InstanceType,
from graphene.core.types import (
BaseType,
LazyType, LazyType,
Argument, Argument,
Field, Field,
@ -57,7 +51,7 @@ __all__ = [
'NonNull', 'NonNull',
'signals', 'signals',
'Schema', 'Schema',
'BaseType', 'InstanceType',
'LazyType', 'LazyType',
'ObjectType', 'ObjectType',
'InputObjectType', 'InputObjectType',

View File

@ -0,0 +1,15 @@
from django.db import models
try:
UUIDField = models.UUIDField
except AttributeError:
# Improved compatibility for Django 1.6
class UUIDField(object):
pass
try:
from django.db.models.related import RelatedObject
except:
# Improved compatibility for Django 1.6
class RelatedObject(object):
pass

View File

@ -1,14 +1,10 @@
from django.db import models from django.db import models
from singledispatch import singledispatch
from ...core.types.scalars import ID, Boolean, Float, Int, String from ...core.types.scalars import ID, Boolean, Float, Int, String
from .compat import RelatedObject, UUIDField
from .utils import get_related_model, import_single_dispatch
try: singledispatch = import_single_dispatch()
UUIDField = models.UUIDField
except AttributeError:
# Improved compatibility for Django 1.6
class UUIDField(object):
pass
@singledispatch @singledispatch
@ -24,6 +20,7 @@ def convert_django_field(field):
@convert_django_field.register(models.EmailField) @convert_django_field.register(models.EmailField)
@convert_django_field.register(models.SlugField) @convert_django_field.register(models.SlugField)
@convert_django_field.register(models.URLField) @convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(UUIDField) @convert_django_field.register(UUIDField)
def convert_field_to_string(field): def convert_field_to_string(field):
return String(description=field.help_text) return String(description=field.help_text)
@ -63,7 +60,15 @@ def convert_field_to_float(field):
@convert_django_field.register(models.ManyToOneRel) @convert_django_field.register(models.ManyToOneRel)
def convert_field_to_list_or_connection(field): def convert_field_to_list_or_connection(field):
from .fields import DjangoModelField, ConnectionOrListField from .fields import DjangoModelField, ConnectionOrListField
model_field = DjangoModelField(field.related_model) model_field = DjangoModelField(get_related_model(field))
return ConnectionOrListField(model_field)
# For Django 1.6
@convert_django_field.register(RelatedObject)
def convert_relatedfield_to_djangomodel(field):
from .fields import DjangoModelField, ConnectionOrListField
model_field = DjangoModelField(field.model)
return ConnectionOrListField(model_field) return ConnectionOrListField(model_field)
@ -71,4 +76,4 @@ def convert_field_to_list_or_connection(field):
@convert_django_field.register(models.ForeignKey) @convert_django_field.register(models.ForeignKey)
def convert_field_to_djangomodel(field): def convert_field_to_djangomodel(field):
from .fields import DjangoModelField from .fields import DjangoModelField
return DjangoModelField(field.related_model, description=field.help_text) return DjangoModelField(get_related_model(field), description=field.help_text)

View File

@ -0,0 +1,4 @@
from .plugin import DjangoDebugPlugin
from .types import DjangoDebug
__all__ = ['DjangoDebugPlugin', 'DjangoDebug']

View File

@ -0,0 +1,79 @@
from contextlib import contextmanager
from django.db import connections
from ....core.schema import GraphQLSchema
from ....core.types import Field
from ....plugins import Plugin
from .sql.tracking import unwrap_cursor, wrap_cursor
from .sql.types import DjangoDebugSQL
from .types import DjangoDebug
class WrappedRoot(object):
def __init__(self, root):
self._recorded = []
self._root = root
def record(self, **log):
self._recorded.append(DjangoDebugSQL(**log))
def debug(self):
return DjangoDebug(sql=self._recorded)
class WrapRoot(object):
@property
def _root(self):
return self._wrapped_root.root
@_root.setter
def _root(self, value):
self._wrapped_root = value
def resolve_debug(self, args, info):
return self._wrapped_root.debug()
def debug_objecttype(objecttype):
return type(
'Debug{}'.format(objecttype._meta.type_name),
(WrapRoot, objecttype),
{'debug': Field(DjangoDebug, name='__debug')})
class DjangoDebugPlugin(Plugin):
def enable_instrumentation(self, wrapped_root):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, wrapped_root)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
def wrap_schema(self, schema_type):
query = schema_type._query
if query:
class_type = self.schema.objecttype(schema_type.get_query_type())
assert class_type, 'The query in schema is not constructed with graphene'
_type = debug_objecttype(class_type)
self.schema.register(_type, force=True)
return GraphQLSchema(
self.schema,
self.schema.T(_type),
schema_type.get_mutation_type(),
schema_type.get_subscription_type()
)
return schema_type
@contextmanager
def context_execution(self, executor):
executor['root'] = WrappedRoot(root=executor['root'])
executor['schema'] = self.wrap_schema(executor['schema'])
self.enable_instrumentation(executor['root'])
yield executor
self.disable_instrumentation()

View File

@ -0,0 +1,165 @@
# Code obtained from django-debug-toolbar sql panel tracking
from __future__ import absolute_import, unicode_literals
import json
from threading import local
from time import time
from django.utils import six
from django.utils.encoding import force_text
class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@property
def Wrapper(self):
if self.enabled:
return NormalCursorWrapper
return ExceptionCursorWrapper
def recording(self, v):
self.enabled = v
state = ThreadLocalState()
recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, '_djdt_cursor'):
connection._djdt_cursor = connection.cursor
def cursor():
return state.Wrapper(connection._djdt_cursor(), connection, panel)
connection.cursor = cursor
return cursor
def unwrap_cursor(connection):
if hasattr(connection, '_djdt_cursor'):
del connection._djdt_cursor
del connection.cursor
class ExceptionCursorWrapper(object):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
"""
def __init__(self, cursor, db, logger):
pass
def __getattr__(self, attr):
raise SQLQueryTriggered()
class NormalCursorWrapper(object):
"""
Wraps a cursor and logs queries.
"""
def __init__(self, cursor, db, logger):
self.cursor = cursor
# Instance of a BaseDatabaseWrapper subclass
self.db = db
# logger must implement a ``record`` method
self.logger = logger
def _quote_expr(self, element):
if isinstance(element, six.string_types):
return "'%s'" % force_text(element).replace("'", "''")
else:
return repr(element)
def _quote_params(self, params):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
return '(encoded string)'
def _record(self, method, sql, params):
start_time = time()
try:
return method(sql, params)
finally:
stop_time = time()
duration = (stop_time - start_time)
_params = ''
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default')
conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown')
params = {
'vendor': vendor,
'alias': alias,
'sql': self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'start_time': start_time,
'stop_time': stop_time,
'is_slow': duration > 10,
'is_select': sql.lower().strip().startswith('select'),
}
if vendor == 'postgresql':
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = 'unknown'
params.update({
'trans_id': self.logger.get_transaction_id(alias),
'trans_status': conn.get_transaction_status(),
'iso_level': iso_level,
'encoding': conn.encoding,
})
# We keep `sql` to maintain backwards compatibility
self.logger.record(**params)
def callproc(self, procname, params=()):
return self._record(self.cursor.callproc, procname, params)
def execute(self, sql, params=()):
return self._record(self.cursor.execute, sql, params)
def executemany(self, sql, param_list):
return self._record(self.cursor.executemany, sql, param_list)
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()

View File

@ -0,0 +1,19 @@
from .....core import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType):
vendor = String()
alias = String()
sql = String()
duration = Float()
raw_sql = String()
params = String()
start_time = Float()
stop_time = Float()
is_slow = Boolean()
is_select = Boolean()
trans_id = String()
trans_status = String()
iso_level = String()
encoding = String()

View File

@ -0,0 +1,70 @@
import pytest
import graphene
from graphene.contrib.django import DjangoObjectType
from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_well():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
query = '''
query ReporterQuery {
reporter {
lastName
}
allReporters {
lastName
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'reporter': {
'lastName': 'ABA',
},
'allReporters': [{
'lastName': 'ABA',
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}, {
'rawSql': str(Reporter.objects.all().query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -0,0 +1,7 @@
from ....core.classtypes.objecttype import ObjectType
from ....core.types import Field
from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType):
sql = Field(DjangoDebugSQL.List())

View File

@ -1,12 +1,13 @@
import warnings import warnings
from .utils import get_type_for_model
from ...core.exceptions import SkipField from ...core.exceptions import SkipField
from ...core.fields import Field from ...core.fields import Field
from ...core.types.base import FieldType from ...core.types.base import FieldType
from ...core.types.definitions import List from ...core.types.definitions import List
from ...relay import ConnectionField from ...relay import ConnectionField
from ...relay.utils import is_node from ...relay.utils import is_node
from .filter.fields import DjangoFilterConnectionField
from .utils import get_type_for_model
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
@ -27,7 +28,10 @@ class ConnectionOrListField(Field):
if not field_object_type: if not field_object_type:
raise SkipField() raise SkipField()
if is_node(field_object_type): if is_node(field_object_type):
field = DjangoConnectionField(field_object_type) if field_object_type._meta.filter_fields:
field = DjangoFilterConnectionField(field_object_type)
else:
field = ConnectionField(field_object_type)
else: else:
field = Field(List(field_object_type)) field = Field(List(field_object_type))
field.contribute_to_class(self.object_type, self.attname) field.contribute_to_class(self.object_type, self.attname)

View File

@ -1,3 +1,11 @@
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
raise Exception(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`"
)
from .fields import DjangoFilterConnectionField from .fields import DjangoFilterConnectionField
from .filterset import GrapheneFilterSet, GlobalIDFilter, GlobalIDMultipleChoiceFilter from .filterset import GrapheneFilterSet, GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .resolvers import FilterConnectionResolver from .resolvers import FilterConnectionResolver

View File

@ -1,9 +1,9 @@
from graphene.contrib.django import DjangoConnectionField
from graphene.contrib.django.filter.resolvers import FilterConnectionResolver from graphene.contrib.django.filter.resolvers import FilterConnectionResolver
from graphene.contrib.django.utils import get_filtering_args_from_filterset from graphene.contrib.django.utils import get_filtering_args_from_filterset
from graphene.relay import ConnectionField
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(ConnectionField):
def __init__(self, type, on=None, fields=None, order_by=None, def __init__(self, type, on=None, fields=None, order_by=None,
extra_filter_meta=None, filterset_class=None, resolver=None, extra_filter_meta=None, filterset_class=None, resolver=None,

View File

@ -2,11 +2,12 @@ import six
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
from django.utils.text import capfirst from django.utils.text import capfirst
from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import FilterSetMetaclass, FilterSet
from graphql_relay.node.node import from_global_id
from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import FilterSet, FilterSetMetaclass
from graphene.contrib.django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField)
from graphql_relay.node.node import from_global_id
class GlobalIDFilter(Filter): class GlobalIDFilter(Filter):
@ -25,7 +26,7 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids) return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
ORDER_BY_FIELD = getattr(settings, 'GRAPHENE_ORDER_BY_FIELD', 'order') ORDER_BY_FIELD = getattr(settings, 'GRAPHENE_ORDER_BY_FIELD', 'order_by')
GRAPHENE_FILTER_SET_OVERRIDES = { GRAPHENE_FILTER_SET_OVERRIDES = {
@ -45,6 +46,7 @@ GRAPHENE_FILTER_SET_OVERRIDES = {
class GrapheneFilterSetMetaclass(FilterSetMetaclass): class GrapheneFilterSetMetaclass(FilterSetMetaclass):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
new_class = super(GrapheneFilterSetMetaclass, cls).__new__(cls, name, bases, attrs) new_class = super(GrapheneFilterSetMetaclass, cls).__new__(cls, name, bases, attrs)
# Customise the filter_overrides for Graphene # Customise the filter_overrides for Graphene
@ -84,7 +86,6 @@ class GrapheneFilterSet(six.with_metaclass(GrapheneFilterSetMetaclass, GrapheneF
DjangoFilterConnectionField will wrap FilterSets with this class as DjangoFilterConnectionField will wrap FilterSets with this class as
necessary necessary
""" """
pass
def setup_filterset(filterset_class): def setup_filterset(filterset_class):

View File

@ -1,6 +1,7 @@
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from graphene.contrib.django.filter.filterset import setup_filterset, custom_filterset_factory from graphene.contrib.django.filter.filterset import (custom_filterset_factory,
setup_filterset)
from graphene.contrib.django.resolvers import BaseQuerySetConnectionResolver from graphene.contrib.django.resolvers import BaseQuerySetConnectionResolver
@ -10,8 +11,8 @@ class FilterConnectionResolver(BaseQuerySetConnectionResolver):
def __init__(self, node, on=None, filterset_class=None, def __init__(self, node, on=None, filterset_class=None,
fields=None, order_by=None, extra_filter_meta=None): fields=None, order_by=None, extra_filter_meta=None):
self.filterset_class = filterset_class self.filterset_class = filterset_class
self.fields = fields self.fields = fields or node._meta.filter_fields
self.order_by = order_by self.order_by = order_by or node._meta.filter_order_by
self.extra_filter_meta = extra_filter_meta or {} self.extra_filter_meta = extra_filter_meta or {}
self._filterset_class = None self._filterset_class = None
super(FilterConnectionResolver, self).__init__(node, on) super(FilterConnectionResolver, self).__init__(node, on)

View File

@ -1,7 +1,5 @@
import django_filters import django_filters
from graphene.contrib.django.tests.models import Article, Pet, Reporter
from graphene.contrib.django.tests.models import Reporter
from graphene.contrib.django.tests.models import Article, Pet
class ArticleFilter(django_filters.FilterSet): class ArticleFilter(django_filters.FilterSet):

View File

@ -1,39 +1,50 @@
import pytest import pytest
try: from graphene import ObjectType, Schema
from graphene.contrib.django import DjangoNode
from graphene.contrib.django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField)
from graphene.contrib.django.tests.models import Article, Pet, Reporter
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
from graphene.relay import NodeField
pytestmark = []
if DJANGO_FILTER_INSTALLED:
import django_filters import django_filters
except ImportError:
pytestmark = pytest.mark.skipif(True, reason='django_filters not installed')
else:
from graphene.contrib.django.filter import (GlobalIDFilter, DjangoFilterConnectionField, from graphene.contrib.django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter) GlobalIDMultipleChoiceFilter)
from graphene.contrib.django.tests.filter.filters import ArticleFilter, PetFilter from graphene.contrib.django.filter.tests.filters import ArticleFilter, PetFilter
else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed'))
from graphene.contrib.django import DjangoNode pytestmark.append(pytest.mark.django_db)
from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene.contrib.django.tests.models import Article, Pet, Reporter
class ArticleNode(DjangoNode): class ArticleNode(DjangoNode):
class Meta: class Meta:
model = Article model = Article
class ReporterNode(DjangoNode): class ReporterNode(DjangoNode):
class Meta: class Meta:
model = Reporter model = Reporter
class PetNode(DjangoNode): class PetNode(DjangoNode):
class Meta: class Meta:
model = Pet model = Pet
schema = Schema()
def assert_arguments(field, *arguments): def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'order') ignore = ('after', 'before', 'first', 'last', 'orderBy')
actual = [ actual = [
name name
for name in field.arguments.arguments.keys() for name in schema.T(field.arguments)
if name not in ignore and not name.startswith('_') if name not in ignore and not name.startswith('_')
] ]
assert set(arguments) == set(actual), \ assert set(arguments) == set(actual), \
@ -44,12 +55,12 @@ def assert_arguments(field, *arguments):
def assert_orderable(field): def assert_orderable(field):
assert 'order' in field.arguments.arguments.keys(), \ assert 'orderBy' in schema.T(field.arguments), \
'Field cannot be ordered' 'Field cannot be ordered'
def assert_not_orderable(field): def assert_not_orderable(field):
assert 'order' in field.arguments.arguments.keys(), \ assert 'orderBy' not in schema.T(field.arguments), \
'Field can be ordered' 'Field can be ordered'
@ -103,11 +114,52 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta(): def test_filter_shortcut_filterset_extra_meta():
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
'ordering': True 'order_by': True
}) })
assert_orderable(field) assert_orderable(field)
def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoNode):
class Meta:
model = Reporter
filter_fields = ['first_name', 'articles']
filter_order_by = True
field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'firstName', 'articles')
assert_orderable(field)
def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoNode):
class Meta:
model = Reporter
filter_fields = ['first_name', 'articles']
filter_order_by = True
class ArticleFilterNode(DjangoNode):
class Meta:
model = Article
filter_fields = ['headline', 'reporter']
filter_order_by = True
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
all_articles = DjangoFilterConnectionField(ArticleFilterNode)
reporter = NodeField(ReporterFilterNode)
article = NodeField(ArticleFilterNode)
schema = Schema(query=Query)
schema.schema # Trigger the schema loading
articles_field = schema.get_type('ReporterFilterNode')._meta.fields_map['articles']
assert_arguments(articles_field, 'headline', 'reporter')
assert_orderable(articles_field)
def test_global_id_field_implicit(): def test_global_id_field_implicit():
field = DjangoFilterConnectionField(ArticleNode, fields=['id']) field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
filterset_class = field.resolver_fn.get_filterset_class() filterset_class = field.resolver_fn.get_filterset_class()
@ -118,6 +170,7 @@ def test_global_id_field_implicit():
def test_global_id_field_explicit(): def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet): class ArticleIdFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = ['id'] fields = ['id']
@ -147,6 +200,7 @@ def test_global_id_multiple_field_implicit():
def test_global_id_multiple_field_explicit(): def test_global_id_multiple_field_explicit():
class ReporterPetsFilter(django_filters.FilterSet): class ReporterPetsFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['pets'] fields = ['pets']
@ -158,9 +212,6 @@ def test_global_id_multiple_field_explicit():
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
@pytest.mark.skipif(True, reason="Trying to test GrapheneFilterSetMixin.filter_for_reverse_field"
"but django has not loaded the models, so the test fails as "
"reverse relations are not ready yet")
def test_global_id_multiple_field_implicit_reverse(): def test_global_id_multiple_field_implicit_reverse():
field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) field = DjangoFilterConnectionField(ReporterNode, fields=['articles'])
filterset_class = field.resolver_fn.get_filterset_class() filterset_class = field.resolver_fn.get_filterset_class()
@ -169,11 +220,9 @@ def test_global_id_multiple_field_implicit_reverse():
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
@pytest.mark.skipif(True, reason="Trying to test GrapheneFilterSetMixin.filter_for_reverse_field"
"but django has not loaded the models, so the test fails as "
"reverse relations are not ready yet")
def test_global_id_multiple_field_explicit_reverse(): def test_global_id_multiple_field_explicit_reverse():
class ReporterPetsFilter(django_filters.FilterSet): class ReporterPetsFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['articles'] fields = ['articles']

View File

@ -1,16 +1,16 @@
import pytest import pytest
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
try: from graphene.contrib.django.tests.models import Article, Reporter
import django_filters # noqa from graphene.contrib.django.tests.test_resolvers import (ArticleNode,
except ImportError: ReporterNode)
pytestmark = pytest.mark.skipif(True, reason='django_filters not installed') from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
else:
from graphene.contrib.django.filter.resolvers import FilterConnectionResolver
from graphene.contrib.django.tests.filter.filters import ReporterFilter, ArticleFilter
from graphene.contrib.django.tests.models import Reporter, Article if DJANGO_FILTER_INSTALLED:
from graphene.contrib.django.tests.test_resolvers import ReporterNode, ArticleNode from graphene.contrib.django.filter.resolvers import FilterConnectionResolver
from graphene.contrib.django.filter.tests.filters import ArticleFilter, ReporterFilter
else:
pytestmark = pytest.mark.skipif(True, reason='django_filters not installed')
def test_filter_get_filterset_class_explicit(): def test_filter_get_filterset_class_explicit():
@ -64,7 +64,7 @@ def test_filter_order():
resolver = FilterConnectionResolver(ArticleNode, resolver = FilterConnectionResolver(ArticleNode,
filterset_class=ArticleFilter) filterset_class=ArticleFilter)
resolved = resolver(inst=article, args={ resolved = resolver(inst=article, args={
'order': 'headline' 'order_by': 'headline'
}, info=None) }, info=None)
assert 'WHERE' not in str(resolved.query) assert 'WHERE' not in str(resolved.query)
assert 'ORDER BY' in str(resolved.query) assert 'ORDER BY' in str(resolved.query)
@ -76,7 +76,7 @@ def test_filter_order_not_available():
resolver = FilterConnectionResolver(ReporterNode, resolver = FilterConnectionResolver(ReporterNode,
filterset_class=ReporterFilter) filterset_class=ReporterFilter)
resolved = resolver(inst=reporter, args={ resolved = resolver(inst=reporter, args={
'order': 'last_name' 'order_by': 'last_name'
}, info=None) }, info=None)
assert 'WHERE' not in str(resolved.query) assert 'WHERE' not in str(resolved.query)
assert 'ORDER BY' not in str(resolved.query) assert 'ORDER BY' not in str(resolved.query)

View File

@ -1,11 +1,14 @@
from django import forms from django import forms
from django.forms.fields import BaseTemporalField from django.forms.fields import BaseTemporalField
from singledispatch import singledispatch
from graphene import String, Int, Boolean, Float, ID from graphene import ID, Boolean, Float, Int, String
from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from graphene.contrib.django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField)
from graphene.contrib.django.utils import import_single_dispatch
from graphene.core.types.definitions import List from graphene.core.types.definitions import List
singledispatch = import_single_dispatch()
try: try:
UUIDField = forms.UUIDField UUIDField = forms.UUIDField
except AttributeError: except AttributeError:
@ -60,11 +63,11 @@ def convert_form_field_to_float(field):
@convert_form_field.register(forms.ModelMultipleChoiceField) @convert_form_field.register(forms.ModelMultipleChoiceField)
@convert_form_field.register(GlobalIDMultipleChoiceField) @convert_form_field.register(GlobalIDMultipleChoiceField)
def convert_form_field_to_list_or_connection(field): def convert_form_field_to_list(field):
return List(ID()) return List(ID())
@convert_form_field.register(forms.ModelChoiceField) @convert_form_field.register(forms.ModelChoiceField)
@convert_form_field.register(GlobalIDFormField) @convert_form_field.register(GlobalIDFormField)
def convert_form_field_to_djangomodel(field): def convert_form_field_to_id(field):
return ID() return ID()

View File

@ -1,7 +1,7 @@
import binascii import binascii
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.forms import Field, IntegerField, CharField, MultipleChoiceField from django.forms import CharField, Field, IntegerField, MultipleChoiceField
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from graphql_relay import from_global_id from graphql_relay import from_global_id

View File

@ -0,0 +1,38 @@
from django.core.management.base import BaseCommand, CommandError
import importlib
import json
class Command(BaseCommand):
help = 'Dump Graphene schema JSON to file'
can_import_settings = True
def add_arguments(self, parser):
from django.conf import settings
parser.add_argument(
'--schema',
type=str,
dest='schema',
default=getattr(settings, 'GRAPHENE_SCHEMA', ''),
help='Django app containing schema to dump, e.g. myproject.core.schema')
parser.add_argument(
'--out',
type=str,
dest='out',
default=getattr(settings, 'GRAPHENE_SCHEMA_OUTPUT', 'schema.json'),
help='Output file (default: schema.json)')
def handle(self, *args, **options):
schema_module = options['schema']
if schema_module == '':
raise CommandError('Specify schema on GRAPHENE_SCHEMA setting or by using --schema')
i = importlib.import_module(schema_module)
schema_dict = {'data': i.schema.introspect()}
with open(options['out'], 'w') as outfile:
json.dump(schema_dict, outfile)
self.stdout.write(self.style.SUCCESS('Successfully dumped GraphQL schema to %s' % options['out']))

View File

@ -1,9 +1,13 @@
from ...core.classtypes.objecttype import ObjectTypeOptions from ...core.classtypes.objecttype import ObjectTypeOptions
from ...relay.types import Node from ...relay.types import Node
from ...relay.utils import is_node from ...relay.utils import is_node
from .utils import DJANGO_FILTER_INSTALLED
VALID_ATTRS = ('model', 'only_fields', 'exclude_fields') VALID_ATTRS = ('model', 'only_fields', 'exclude_fields')
if DJANGO_FILTER_INSTALLED:
VALID_ATTRS += ('filter_fields', 'filter_order_by')
class DjangoOptions(ObjectTypeOptions): class DjangoOptions(ObjectTypeOptions):
@ -13,6 +17,8 @@ class DjangoOptions(ObjectTypeOptions):
self.valid_attrs += VALID_ATTRS self.valid_attrs += VALID_ATTRS
self.only_fields = None self.only_fields = None
self.exclude_fields = [] self.exclude_fields = []
self.filter_fields = None
self.filter_order_by = None
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
super(DjangoOptions, self).contribute_to_class(cls, name) super(DjangoOptions, self).contribute_to_class(cls, name)

View File

@ -36,8 +36,8 @@ class SimpleQuerySetConnectionResolver(BaseQuerySetConnectionResolver):
return query return query
def get_filter_kwargs(self): def get_filter_kwargs(self):
ignore = ['first', 'last', 'before', 'after', 'order'] ignore = ['first', 'last', 'before', 'after', 'order_by']
return {k: v for k, v in self.args.items() if k not in ignore} return {k: v for k, v in self.args.items() if k not in ignore}
def get_order(self): def get_order(self):
return self.args.get('order', None) return self.args.get('order_by', None)

View File

@ -16,9 +16,6 @@ class Reporter(models.Model):
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name) return "%s %s" % (self.first_name, self.last_name)
class Meta:
app_label = 'contrib_django'
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
@ -30,4 +27,3 @@ class Article(models.Model):
class Meta: class Meta:
ordering = ('headline',) ordering = ('headline',)
app_label = 'contrib_django'

View File

@ -9,8 +9,8 @@ from graphene.contrib.django.fields import (ConnectionOrListField,
from .models import Article, Reporter from .models import Article, Reporter
def assert_conversion(django_field, graphene_field, *args): def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(*args, help_text='Custom Help Text') field = django_field(help_text='Custom Help Text', *args, **kwargs)
graphene_type = convert_django_field(field) graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.as_field() field = graphene_type.as_field()
@ -48,8 +48,12 @@ def test_should_url_convert_string():
assert_conversion(models.URLField, graphene.String) assert_conversion(models.URLField, graphene.String)
def test_should_ipaddress_convert_string():
assert_conversion(models.GenericIPAddressField, graphene.String)
def test_should_auto_convert_id(): def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID) assert_conversion(models.AutoField, graphene.ID, primary_key=True)
def test_should_positive_integer_convert_int(): def test_should_positive_integer_convert_int():
@ -94,7 +98,10 @@ def test_should_manytomany_convert_connectionorlist():
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
graphene_type = convert_django_field(Reporter.articles.related) # Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Reporter.articles, 'rel', None) or \
getattr(Reporter.articles, 'related')
graphene_type = convert_django_field(related)
assert isinstance(graphene_type, ConnectionOrListField) assert isinstance(graphene_type, ConnectionOrListField)
assert isinstance(graphene_type.type, DjangoModelField) assert isinstance(graphene_type.type, DjangoModelField)
assert graphene_type.type.model == Article assert graphene_type.type.model == Article

View File

@ -1,10 +1,9 @@
from django import forms from django import forms
from graphene.core.types import List, ID
from py.test import raises from py.test import raises
import graphene import graphene
from graphene.contrib.django.form_converter import convert_form_field from graphene.contrib.django.form_converter import convert_form_field
from graphene.core.types import ID, List
from .models import Reporter from .models import Reporter

View File

@ -1,3 +1,4 @@
import pytest
from py.test import raises from py.test import raises
import graphene import graphene
@ -6,6 +7,8 @@ from graphene.contrib.django import DjangoNode, DjangoObjectType
from .models import Article, Reporter from .models import Article, Reporter
pytestmark = pytest.mark.django_db
def test_should_query_only_fields(): def test_should_query_only_fields():
with raises(Exception): with raises(Exception):

View File

@ -3,15 +3,17 @@ from django.db.models.query import QuerySet
from graphene.contrib.django import DjangoNode from graphene.contrib.django import DjangoNode
from graphene.contrib.django.resolvers import SimpleQuerySetConnectionResolver from graphene.contrib.django.resolvers import SimpleQuerySetConnectionResolver
from graphene.contrib.django.tests.models import Reporter, Article from graphene.contrib.django.tests.models import Article, Reporter
class ReporterNode(DjangoNode): class ReporterNode(DjangoNode):
class Meta: class Meta:
model = Reporter model = Reporter
class ArticleNode(DjangoNode): class ArticleNode(DjangoNode):
class Meta: class Meta:
model = Article model = Article
@ -34,7 +36,7 @@ def test_simple_get_manager_all():
reporter = Reporter(id=1, first_name='Cookie Monster') reporter = Reporter(id=1, first_name='Cookie Monster')
resolver = SimpleQuerySetConnectionResolver(ReporterNode) resolver = SimpleQuerySetConnectionResolver(ReporterNode)
resolver(inst=reporter, args={}, info=None) resolver(inst=reporter, args={}, info=None)
assert type(resolver.get_manager()) == Manager, 'Resolver did not return a Manager' assert isinstance(resolver.get_manager(), Manager), 'Resolver did not return a Manager'
def test_simple_filter(): def test_simple_filter():
@ -51,7 +53,7 @@ def test_simple_order():
reporter = Reporter(id=1, first_name='Cookie Monster') reporter = Reporter(id=1, first_name='Cookie Monster')
resolver = SimpleQuerySetConnectionResolver(ReporterNode) resolver = SimpleQuerySetConnectionResolver(ReporterNode)
resolved = resolver(inst=reporter, args={ resolved = resolver(inst=reporter, args={
'order': 'last_name' 'order_by': 'last_name'
}, info=None) }, info=None)
assert 'WHERE' not in str(resolved.query) assert 'WHERE' not in str(resolved.query)
assert 'ORDER BY' in str(resolved.query) assert 'ORDER BY' in str(resolved.query)

View File

@ -29,7 +29,15 @@ class Human(DjangoNode):
def get_node(self, id): def get_node(self, id):
pass pass
schema = Schema(query=Human)
class Query(graphene.ObjectType):
human = graphene.Field(Human)
def resolve_human(self, args, info):
return Human()
schema = Schema(query=Query)
urlpatterns = [ urlpatterns = [

View File

@ -7,45 +7,51 @@ def format_response(response):
def test_client_get_good_query(settings, client): def test_client_get_good_query(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.get('/graphql', {'query': '{ headline }'}) response = client.get('/graphql', {'query': '{ human { headline } }'})
json_response = format_response(response) json_response = format_response(response)
expected_json = { expected_json = {
'data': { 'data': {
'human': {
'headline': None 'headline': None
} }
} }
}
assert json_response == expected_json assert json_response == expected_json
def test_client_get_good_query_with_raise(settings, client): def test_client_get_good_query_with_raise(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.get('/graphql', {'query': '{ raises }'}) response = client.get('/graphql', {'query': '{ human { raises } }'})
json_response = format_response(response) json_response = format_response(response)
assert json_response['errors'][0]['message'] == 'This field should raise exception' assert json_response['errors'][0]['message'] == 'This field should raise exception'
assert json_response['data']['raises'] is None assert json_response['data']['human']['raises'] is None
def test_client_post_good_query_json(settings, client): def test_client_post_good_query_json(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.post( response = client.post(
'/graphql', json.dumps({'query': '{ headline }'}), 'application/json') '/graphql', json.dumps({'query': '{ human { headline } }'}), 'application/json')
json_response = format_response(response) json_response = format_response(response)
expected_json = { expected_json = {
'data': { 'data': {
'human': {
'headline': None 'headline': None
} }
} }
}
assert json_response == expected_json assert json_response == expected_json
def test_client_post_good_query_graphql(settings, client): def test_client_post_good_query_graphql(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.post( response = client.post(
'/graphql', '{ headline }', 'application/graphql') '/graphql', '{ human { headline } }', 'application/graphql')
json_response = format_response(response) json_response = format_response(response)
expected_json = { expected_json = {
'data': { 'data': {
'human': {
'headline': None 'headline': None
} }
} }
}
assert json_response == expected_json assert json_response == expected_json

View File

@ -1,9 +1,18 @@
import six import six
from django.db import models from django.db import models
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.db.models.query import QuerySet
from graphene import Argument, String from graphene import Argument, String
from graphene.contrib.django.form_converter import convert_form_field from graphene.utils import LazyList
from .compat import RelatedObject
try:
import django_filters # noqa
DJANGO_FILTER_INSTALLED = True
except ImportError:
DJANGO_FILTER_INSTALLED = False
def get_type_for_model(schema, model): def get_type_for_model(schema, model):
@ -18,14 +27,32 @@ def get_type_for_model(schema, model):
def get_reverse_fields(model): def get_reverse_fields(model):
for name, attr in model.__dict__.items(): for name, attr in model.__dict__.items():
related = getattr(attr, 'related', None) # Django =>1.9 uses 'rel', django <1.9 uses 'related'
if isinstance(related, models.ManyToOneRel): related = getattr(attr, 'rel', None) or \
getattr(attr, 'related', None)
if isinstance(related, RelatedObject):
# Hack for making it compatible with Django 1.6
new_related = RelatedObject(related.parent_model, related.model, related.field)
new_related.name = name
yield new_related
elif isinstance(related, models.ManyToOneRel):
yield related yield related
class WrappedQueryset(LazyList):
def __len__(self):
# Dont calculate the length using len(queryset), as this will
# evaluate the whole queryset and return it's length.
# Use .count() instead
return self._origin.count()
def maybe_queryset(value): def maybe_queryset(value):
if isinstance(value, Manager): if isinstance(value, Manager):
value = value.get_queryset() value = value.get_queryset()
if isinstance(value, QuerySet):
return WrappedQueryset(value)
return value return value
@ -34,6 +61,8 @@ def get_filtering_args_from_filterset(filterset_class, type):
a Graphene Field. These arguments will be available to a Graphene Field. These arguments will be available to
filter against in the GraphQL filter against in the GraphQL
""" """
from graphene.contrib.django.form_converter import convert_form_field
args = {} args = {}
for name, filter_field in six.iteritems(filterset_class.base_filters): for name, filter_field in six.iteritems(filterset_class.base_filters):
field_type = Argument(convert_form_field(filter_field.field)) field_type = Argument(convert_form_field(filter_field.field))
@ -42,5 +71,36 @@ def get_filtering_args_from_filterset(filterset_class, type):
args[name] = field_type args[name] = field_type
# Also add the 'order_by' field # Also add the 'order_by' field
args[filterset_class.order_by_field] = Argument(String) if filterset_class._meta.order_by:
args[filterset_class.order_by_field] = Argument(String())
return args return args
def get_related_model(field):
if hasattr(field, 'rel'):
# Django 1.6, 1.7
return field.rel.to
return field.related_model
def import_single_dispatch():
try:
from functools import singledispatch
except ImportError:
singledispatch = None
if not singledispatch:
try:
from singledispatch import singledispatch
except ImportError:
pass
if not singledispatch:
raise Exception(
"It seems your python version does not include "
"functools.singledispatch. Please install the 'singledispatch' "
"package. More information here: "
"https://pypi.python.org/pypi/singledispatch"
)
return singledispatch

View File

@ -12,5 +12,5 @@ class GraphQLView(BaseGraphQLView):
**kwargs **kwargs
) )
def get_root_value(self, request): def execute(self, *args, **kwargs):
return self.graphene_schema.query(super(GraphQLView, self).get_root_value(request)) return self.graphene_schema.execute(*args, **kwargs)

View File

@ -0,0 +1,46 @@
from .schema import (
Schema
)
from .classtypes import (
ObjectType,
InputObjectType,
Interface,
Mutation,
Scalar
)
from .types import (
InstanceType,
LazyType,
Argument,
Field,
InputField,
String,
Int,
Boolean,
ID,
Float,
List,
NonNull
)
__all__ = [
'Argument',
'String',
'Int',
'Boolean',
'Float',
'ID',
'List',
'NonNull',
'Schema',
'InstanceType',
'LazyType',
'ObjectType',
'InputObjectType',
'Interface',
'Mutation',
'Scalar',
'Field',
'InputField']

View File

@ -1,10 +1,10 @@
import copy import copy
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
from functools import partial
import six import six
from ..exceptions import SkipField
from .options import Options from .options import Options
@ -48,8 +48,8 @@ class ClassTypeMeta(type):
if not cls._meta.abstract: if not cls._meta.abstract:
from ..types import List, NonNull from ..types import List, NonNull
setattr(cls, 'NonNull', NonNull(cls)) setattr(cls, 'NonNull', partial(NonNull, cls))
setattr(cls, 'List', List(cls)) setattr(cls, 'List', partial(List, cls))
return cls return cls
@ -81,13 +81,18 @@ class FieldsOptions(Options):
def fields_map(self): def fields_map(self):
return OrderedDict([(f.attname, f) for f in self.fields]) return OrderedDict([(f.attname, f) for f in self.fields])
@property
def fields_group_type(self):
from ..types.field import FieldsGroupType
return FieldsGroupType(*self.local_fields)
class FieldsClassTypeMeta(ClassTypeMeta): class FieldsClassTypeMeta(ClassTypeMeta):
options_class = FieldsOptions options_class = FieldsOptions
def extend_fields(cls, bases): def extend_fields(cls, bases):
new_fields = cls._meta.local_fields new_fields = cls._meta.local_fields
field_names = {f.name: f for f in new_fields} field_names = {f.attname: f for f in new_fields}
for base in bases: for base in bases:
if not isinstance(base, FieldsClassTypeMeta): if not isinstance(base, FieldsClassTypeMeta):
@ -95,17 +100,17 @@ class FieldsClassTypeMeta(ClassTypeMeta):
parent_fields = base._meta.local_fields parent_fields = base._meta.local_fields
for field in parent_fields: for field in parent_fields:
if field.name in field_names and field.type.__class__ != field_names[ if field.attname in field_names and field.type.__class__ != field_names[
field.name].type.__class__: field.attname].type.__class__:
raise Exception( raise Exception(
'Local field %r in class %r (%r) clashes ' 'Local field %r in class %r (%r) clashes '
'with field with similar name from ' 'with field with similar name from '
'Interface %s (%r)' % ( 'Interface %s (%r)' % (
field.name, field.attname,
cls.__name__, cls.__name__,
field.__class__, field.__class__,
base.__name__, base.__name__,
field_names[field.name].__class__) field_names[field.attname].__class__)
) )
new_field = copy.copy(field) new_field = copy.copy(field)
cls.add_to_class(field.attname, new_field) cls.add_to_class(field.attname, new_field)
@ -123,11 +128,4 @@ class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)):
@classmethod @classmethod
def fields_internal_types(cls, schema): def fields_internal_types(cls, schema):
fields = [] return schema.T(cls._meta.fields_group_type)
for field in cls._meta.fields:
try:
fields.append((field.name, schema.T(field)))
except SkipField:
continue
return OrderedDict(fields)

View File

@ -23,15 +23,26 @@ def test_classtype_advanced():
def test_classtype_definition_list(): def test_classtype_definition_list():
class Character(ClassType): class Character(ClassType):
'''Character description''' '''Character description'''
assert isinstance(Character.List, List) assert isinstance(Character.List(), List)
assert Character.List.of_type == Character assert Character.List().of_type == Character
def test_classtype_definition_nonnull(): def test_classtype_definition_nonnull():
class Character(ClassType): class Character(ClassType):
'''Character description''' '''Character description'''
assert isinstance(Character.NonNull, NonNull) assert isinstance(Character.NonNull(), NonNull)
assert Character.NonNull.of_type == Character assert Character.NonNull().of_type == Character
def test_fieldsclasstype_definition_order():
class Character(ClassType):
'''Character description'''
class Query(FieldsClassType):
name = String()
char = Character.NonNull()
assert list(Query._meta.fields_map.keys()) == ['name', 'char']
def test_fieldsclasstype(): def test_fieldsclasstype():

View File

@ -24,4 +24,4 @@ def test_mutation():
assert list(object_type.get_fields().keys()) == ['name'] assert list(object_type.get_fields().keys()) == ['name']
assert MyMutation._meta.fields_map['name'].object_type == MyMutation assert MyMutation._meta.fields_map['name'].object_type == MyMutation
assert isinstance(MyMutation.arguments, ArgumentsGroup) assert isinstance(MyMutation.arguments, ArgumentsGroup)
assert 'argName' in MyMutation.arguments assert 'argName' in schema.T(MyMutation.arguments)

View File

@ -10,8 +10,9 @@ from graphql.core.utils.schema_printer import print_schema
from graphene import signals from graphene import signals
from ..plugins import CamelCase, PluginManager
from .classtypes.base import ClassType from .classtypes.base import ClassType
from .types.base import BaseType from .types.base import InstanceType
class GraphQLSchema(_GraphQLSchema): class GraphQLSchema(_GraphQLSchema):
@ -25,7 +26,7 @@ class Schema(object):
_executor = None _executor = None
def __init__(self, query=None, mutation=None, subscription=None, def __init__(self, query=None, mutation=None, subscription=None,
name='Schema', executor=None): name='Schema', executor=None, plugins=None, auto_camelcase=True):
self._types_names = {} self._types_names = {}
self._types = {} self._types = {}
self.mutation = mutation self.mutation = mutation
@ -33,27 +34,34 @@ class Schema(object):
self.subscription = subscription self.subscription = subscription
self.name = name self.name = name
self.executor = executor self.executor = executor
plugins = plugins or []
if auto_camelcase:
plugins.append(CamelCase())
self.plugins = PluginManager(self, plugins)
signals.init_schema.send(self) signals.init_schema.send(self)
def __repr__(self): def __repr__(self):
return '<Schema: %s (%s)>' % (str(self.name), hash(self)) return '<Schema: %s (%s)>' % (str(self.name), hash(self))
def T(self, object_type): def __getattr__(self, name):
if not object_type: if name in self.plugins:
return getattr(self.plugins, name)
return super(Schema, self).__getattr__(name)
def T(self, _type):
if not _type:
return return
if inspect.isclass(object_type) and issubclass( is_classtype = inspect.isclass(_type) and issubclass(_type, ClassType)
object_type, (BaseType, ClassType)) or isinstance( is_instancetype = isinstance(_type, InstanceType)
object_type, BaseType): if is_classtype or is_instancetype:
if object_type not in self._types: if _type not in self._types:
internal_type = object_type.internal_type(self) internal_type = _type.internal_type(self)
self._types[object_type] = internal_type self._types[_type] = internal_type
is_objecttype = inspect.isclass( if is_classtype:
object_type) and issubclass(object_type, ClassType) self.register(_type)
if is_objecttype: return self._types[_type]
self.register(object_type)
return self._types[object_type]
else: else:
return object_type return _type
@property @property
def executor(self): def executor(self):
@ -76,9 +84,9 @@ class Schema(object):
mutation=self.T(self.mutation), mutation=self.T(self.mutation),
subscription=self.T(self.subscription)) subscription=self.T(self.subscription))
def register(self, object_type): def register(self, object_type, force=False):
type_name = object_type._meta.type_name type_name = object_type._meta.type_name
registered_object_type = self._types_names.get(type_name, None) registered_object_type = not force and self._types_names.get(type_name, None)
if registered_object_type: if registered_object_type:
assert registered_object_type == object_type, 'Type {} already registered with other object type'.format( assert registered_object_type == object_type, 'Type {} already registered with other object type'.format(
type_name) type_name)
@ -110,17 +118,10 @@ class Schema(object):
def types(self): def types(self):
return self._types_names return self._types_names
def execute(self, request='', root=None, vars=None, def execute(self, request='', root=None, args=None, **kwargs):
operation_name=None, **kwargs): kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema)
root = root or object() with self.plugins.context_execution(**kwargs) as execute_kwargs:
return self.executor.execute( return self.executor.execute(**execute_kwargs)
self.schema,
request,
root=root,
args=vars,
operation_name=operation_name,
**kwargs
)
def introspect(self): def introspect(self):
return self.execute(introspection_query).data return self.execute(introspection_query).data

View File

@ -34,10 +34,11 @@ def test_field_type():
assert schema.T(f).type == GraphQLString assert schema.T(f).type == GraphQLString
def test_field_name_automatic_camelcase(): def test_field_name():
f = Field(GraphQLString) f = Field(GraphQLString)
f.contribute_to_class(MyOt, 'field_name') f.contribute_to_class(MyOt, 'field_name')
assert f.name == 'fieldName' assert f.name is None
assert f.attname == 'field_name'
def test_field_name_use_name_if_exists(): def test_field_name_use_name_if_exists():

View File

@ -1,14 +1,14 @@
from .base import BaseType, LazyType, OrderedType from .base import InstanceType, LazyType, OrderedType
from .argument import Argument, ArgumentsGroup, to_arguments from .argument import Argument, ArgumentsGroup, to_arguments
from .definitions import List, NonNull from .definitions import List, NonNull
# Compatibility import # Compatibility import
from .objecttype import Interface, ObjectType, Mutation, InputObjectType from .objecttype import Interface, ObjectType, Mutation, InputObjectType
from .scalars import String, ID, Boolean, Int, Float, Scalar from .scalars import String, ID, Boolean, Int, Float
from .field import Field, InputField from .field import Field, InputField
__all__ = [ __all__ = [
'BaseType', 'InstanceType',
'LazyType', 'LazyType',
'OrderedType', 'OrderedType',
'Argument', 'Argument',
@ -26,5 +26,4 @@ __all__ = [
'ID', 'ID',
'Boolean', 'Boolean',
'Int', 'Int',
'Float', 'Float']
'Scalar']

View File

@ -1,19 +1,17 @@
from collections import OrderedDict
from functools import wraps from functools import wraps
from itertools import chain from itertools import chain
from graphql.core.type import GraphQLArgument from graphql.core.type import GraphQLArgument
from ...utils import ProxySnakeDict, to_camel_case from ...utils import ProxySnakeDict
from .base import ArgumentType, BaseType, OrderedType from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
class Argument(OrderedType): class Argument(NamedType, OrderedType):
def __init__(self, type, description=None, default=None, def __init__(self, type, description=None, default=None,
name=None, _creation_counter=None): name=None, _creation_counter=None):
super(Argument, self).__init__(_creation_counter=_creation_counter) super(Argument, self).__init__(name=name, _creation_counter=_creation_counter)
self.name = name
self.type = type self.type = type
self.description = description self.description = description
self.default = default self.default = default
@ -27,47 +25,32 @@ class Argument(OrderedType):
return self.name return self.name
class ArgumentsGroup(BaseType): class ArgumentsGroup(GroupNamedType):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
arguments = to_arguments(*args, **kwargs) arguments = to_arguments(*args, **kwargs)
self.arguments = OrderedDict([(arg.name, arg) for arg in arguments]) super(ArgumentsGroup, self).__init__(*arguments)
def internal_type(self, schema):
return OrderedDict([(arg.name, schema.T(arg))
for arg in self.arguments.values()])
def __len__(self):
return len(self.arguments)
def __iter__(self):
return iter(self.arguments)
def __contains__(self, *args):
return self.arguments.__contains__(*args)
def __getitem__(self, *args):
return self.arguments.__getitem__(*args)
def to_arguments(*args, **kwargs): def to_arguments(*args, **kwargs):
arguments = {} arguments = {}
iter_arguments = chain(kwargs.items(), [(None, a) for a in args]) iter_arguments = chain(kwargs.items(), [(None, a) for a in args])
for name, arg in iter_arguments: for default_name, arg in iter_arguments:
if isinstance(arg, Argument): if isinstance(arg, Argument):
argument = arg argument = arg
elif isinstance(arg, ArgumentType): elif isinstance(arg, ArgumentType):
argument = arg.as_argument() argument = arg.as_argument()
else: else:
raise ValueError('Unknown argument %s=%r' % (name, arg)) raise ValueError('Unknown argument %s=%r' % (default_name, arg))
if name: if default_name:
argument.name = to_camel_case(name) argument.default_name = default_name
assert argument.name, 'Argument in field must have a name'
assert argument.name not in arguments, 'Found more than one Argument with same name {}'.format( name = argument.name or argument.default_name
argument.name) assert name, 'Argument in field must have a name'
arguments[argument.name] = argument assert name not in arguments, 'Found more than one Argument with same name {}'.format(name)
arguments[name] = argument
return sorted(arguments.values()) return sorted(arguments.values())

View File

@ -1,16 +1,16 @@
from functools import total_ordering from collections import OrderedDict
from functools import partial, total_ordering
import six import six
class BaseType(object): class InstanceType(object):
@classmethod def internal_type(self, schema):
def internal_type(cls, schema): raise NotImplementedError("internal_type for type {} is not implemented".format(self.__class__.__name__))
return getattr(cls, 'T', None)
class MountType(BaseType): class MountType(InstanceType):
parent = None parent = None
def mount(self, cls): def mount(self, cls):
@ -126,3 +126,39 @@ class FieldType(MirroredType):
class MountedType(FieldType, ArgumentType): class MountedType(FieldType, ArgumentType):
pass pass
class NamedType(InstanceType):
def __init__(self, name=None, default_name=None, *args, **kwargs):
self.name = name
self.default_name = None
super(NamedType, self).__init__(*args, **kwargs)
class GroupNamedType(InstanceType):
def __init__(self, *types):
self.types = types
def get_named_type(self, schema, type):
name = type.name or schema.get_default_namedtype_name(type.default_name)
return name, schema.T(type)
def iter_types(self, schema):
return map(partial(self.get_named_type, schema), self.types)
def internal_type(self, schema):
return OrderedDict(self.iter_types(schema))
def __len__(self):
return len(self.types)
def __iter__(self):
return iter(self.types)
def __contains__(self, *args):
return self.types.__contains__(*args)
def __getitem__(self, *args):
return self.types.__getitem__(*args)

View File

@ -4,23 +4,26 @@ from functools import wraps
import six import six
from graphql.core.type import GraphQLField, GraphQLInputObjectField from graphql.core.type import GraphQLField, GraphQLInputObjectField
from ...utils import to_camel_case
from ..classtypes.base import FieldsClassType from ..classtypes.base import FieldsClassType
from ..classtypes.inputobjecttype import InputObjectType from ..classtypes.inputobjecttype import InputObjectType
from ..classtypes.mutation import Mutation from ..classtypes.mutation import Mutation
from .argument import ArgumentsGroup, snake_case_args from ..exceptions import SkipField
from .base import LazyType, MountType, OrderedType from .argument import Argument, ArgumentsGroup, snake_case_args
from .base import (ArgumentType, GroupNamedType, LazyType, MountType,
NamedType, OrderedType)
from .definitions import NonNull from .definitions import NonNull
class Field(OrderedType): class Field(NamedType, OrderedType):
def __init__( def __init__(
self, type, description=None, args=None, name=None, resolver=None, self, type, description=None, args=None, name=None, resolver=None,
required=False, default=None, *args_list, **kwargs): required=False, default=None, *args_list, **kwargs):
_creation_counter = kwargs.pop('_creation_counter', None) _creation_counter = kwargs.pop('_creation_counter', None)
super(Field, self).__init__(_creation_counter=_creation_counter) if isinstance(name, (Argument, ArgumentType)):
self.name = name kwargs['name'] = name
name = None
super(Field, self).__init__(name=name, _creation_counter=_creation_counter)
if isinstance(type, six.string_types): if isinstance(type, six.string_types):
type = LazyType(type) type = LazyType(type)
self.required = required self.required = required
@ -36,9 +39,8 @@ class Field(OrderedType):
assert issubclass( assert issubclass(
cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format( cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format(
self, cls) self, cls)
if not self.name:
self.name = to_camel_case(attname)
self.attname = attname self.attname = attname
self.default_name = attname
self.object_type = cls self.object_type = cls
self.mount(cls) self.mount(cls)
if isinstance(self.type, MountType): if isinstance(self.type, MountType):
@ -63,6 +65,9 @@ class Field(OrderedType):
return NonNull(self.type) return NonNull(self.type)
return self.type return self.type
def decorate_resolver(self, resolver):
return snake_case_args(resolver)
def internal_type(self, schema): def internal_type(self, schema):
resolver = self.resolver resolver = self.resolver
description = self.description description = self.description
@ -85,9 +90,9 @@ class Field(OrderedType):
return my_resolver(instance, args, info) return my_resolver(instance, args, info)
resolver = wrapped_func resolver = wrapped_func
resolver = snake_case_args(resolver)
assert type, 'Internal type for field %s is None' % str(self) assert type, 'Internal type for field %s is None' % str(self)
return GraphQLField(type, args=schema.T(arguments), resolver=resolver, return GraphQLField(type, args=schema.T(arguments),
resolver=self.decorate_resolver(resolver),
description=description,) description=description,)
def __repr__(self): def __repr__(self):
@ -114,12 +119,11 @@ class Field(OrderedType):
return hash((self.creation_counter, self.object_type)) return hash((self.creation_counter, self.object_type))
class InputField(OrderedType): class InputField(NamedType, OrderedType):
def __init__(self, type, description=None, default=None, def __init__(self, type, description=None, default=None,
name=None, _creation_counter=None, required=False): name=None, _creation_counter=None, required=False):
super(InputField, self).__init__(_creation_counter=_creation_counter) super(InputField, self).__init__(_creation_counter=_creation_counter)
self.name = name
if required: if required:
type = NonNull(type) type = NonNull(type)
self.type = type self.type = type
@ -130,9 +134,8 @@ class InputField(OrderedType):
assert issubclass( assert issubclass(
cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format( cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format(
self, cls) self, cls)
if not self.name:
self.name = to_camel_case(attname)
self.attname = attname self.attname = attname
self.default_name = attname
self.object_type = cls self.object_type = cls
self.mount(cls) self.mount(cls)
if isinstance(self.type, MountType): if isinstance(self.type, MountType):
@ -143,3 +146,13 @@ class InputField(OrderedType):
return GraphQLInputObjectField( return GraphQLInputObjectField(
schema.T(self.type), schema.T(self.type),
default_value=self.default, description=self.description) default_value=self.default, description=self.description)
class FieldsGroupType(GroupNamedType):
def iter_types(self, schema):
for field in sorted(self.types):
try:
yield self.get_named_type(schema, field)
except SkipField:
continue

View File

@ -1,41 +1,30 @@
from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID,
GraphQLInt, GraphQLScalarType, GraphQLString) GraphQLInt, GraphQLString)
from .base import MountedType from .base import MountedType
class String(MountedType): class ScalarType(MountedType):
T = GraphQLString
def internal_type(self, schema):
return self._internal_type
class Int(MountedType): class String(ScalarType):
T = GraphQLInt _internal_type = GraphQLString
class Boolean(MountedType): class Int(ScalarType):
T = GraphQLBoolean _internal_type = GraphQLInt
class ID(MountedType): class Boolean(ScalarType):
T = GraphQLID _internal_type = GraphQLBoolean
class Float(MountedType): class ID(ScalarType):
T = GraphQLFloat _internal_type = GraphQLID
class Scalar(MountedType): class Float(ScalarType):
_internal_type = GraphQLFloat
@classmethod
def internal_type(cls, schema):
serialize = getattr(cls, 'serialize')
parse_literal = getattr(cls, 'parse_literal')
parse_value = getattr(cls, 'parse_value')
return GraphQLScalarType(
name=cls.__name__,
description=cls.__doc__,
serialize=serialize,
parse_value=parse_value,
parse_literal=parse_literal
)

View File

@ -27,8 +27,8 @@ def test_to_arguments():
other_kwarg=String(), other_kwarg=String(),
) )
assert [a.name for a in arguments] == [ assert [a.name or a.default_name for a in arguments] == [
'myArg', 'otherArg', 'myKwarg', 'otherKwarg'] 'myArg', 'otherArg', 'my_kwarg', 'other_kwarg']
def test_to_arguments_no_name(): def test_to_arguments_no_name():

View File

@ -13,14 +13,14 @@ from ..scalars import String
def test_field_internal_type(): def test_field_internal_type():
resolver = lambda *args: 'RESOLVED' resolver = lambda *args: 'RESOLVED'
field = Field(String, description='My argument', resolver=resolver) field = Field(String(), description='My argument', resolver=resolver)
class Query(ObjectType): class Query(ObjectType):
my_field = field my_field = field
schema = Schema(query=Query) schema = Schema(query=Query)
type = schema.T(field) type = schema.T(field)
assert field.name == 'myField' assert field.name is None
assert field.attname == 'my_field' assert field.attname == 'my_field'
assert isinstance(type, GraphQLField) assert isinstance(type, GraphQLField)
assert type.description == 'My argument' assert type.description == 'My argument'
@ -98,9 +98,18 @@ def test_field_string_reference():
def test_field_custom_arguments(): def test_field_custom_arguments():
field = Field(None, name='my_customName', p=String()) field = Field(None, name='my_customName', p=String())
schema = Schema()
args = field.arguments args = field.arguments
assert 'p' in args assert 'p' in schema.T(args)
def test_field_name_as_argument():
field = Field(None, name=String())
schema = Schema()
args = field.arguments
assert 'name' in schema.T(args)
def test_inputfield_internal_type(): def test_inputfield_internal_type():
@ -115,8 +124,43 @@ def test_inputfield_internal_type():
schema = Schema(query=MyObjectType) schema = Schema(query=MyObjectType)
type = schema.T(field) type = schema.T(field)
assert field.name == 'myField' assert field.name is None
assert field.attname == 'my_field' assert field.attname == 'my_field'
assert isinstance(type, GraphQLInputObjectField) assert isinstance(type, GraphQLInputObjectField)
assert type.description == 'My input field' assert type.description == 'My input field'
assert type.default_value == '3' assert type.default_value == '3'
def test_field_resolve_argument():
resolver = lambda instance, args, info: args.get('first_name')
field = Field(String(), first_name=String(), description='My argument', resolver=resolver)
class Query(ObjectType):
my_field = field
schema = Schema(query=Query)
type = schema.T(field)
assert type.resolver(None, {'firstName': 'Peter'}, None) == 'Peter'
def test_field_resolve_vars():
class Query(ObjectType):
hello = String(first_name=String())
def resolve_hello(self, args, info):
return 'Hello ' + args.get('first_name')
schema = Schema(query=Query)
result = schema.execute("""
query foo($firstName:String)
{
hello(firstName:$firstName)
}
""", args={"firstName": "Serkan"})
expected = {
'hello': 'Hello Serkan'
}
assert result.data == expected

View File

@ -1,9 +1,9 @@
from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID, from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLID,
GraphQLInt, GraphQLScalarType, GraphQLString) GraphQLInt, GraphQLString)
from graphene.core.schema import Schema from graphene.core.schema import Schema
from ..scalars import ID, Boolean, Float, Int, Scalar, String from ..scalars import ID, Boolean, Float, Int, String
schema = Schema() schema = Schema()
@ -26,29 +26,3 @@ def test_id_scalar():
def test_float_scalar(): def test_float_scalar():
assert schema.T(Float()) == GraphQLFloat assert schema.T(Float()) == GraphQLFloat
def test_custom_scalar():
import datetime
from graphql.core.language import ast
class DateTimeScalar(Scalar):
'''DateTimeScalar Documentation'''
@staticmethod
def serialize(dt):
return dt.isoformat()
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
return datetime.datetime.strptime(
node.value, "%Y-%m-%dT%H:%M:%S.%f")
@staticmethod
def parse_value(value):
return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f")
scalar_type = schema.T(DateTimeScalar)
assert isinstance(scalar_type, GraphQLScalarType)
assert scalar_type.name == 'DateTimeScalar'
assert scalar_type.description == 'DateTimeScalar Documentation'

View File

@ -0,0 +1,6 @@
from .base import Plugin, PluginManager
from .camel_case import CamelCase
__all__ = [
'Plugin', 'PluginManager', 'CamelCase'
]

53
graphene/plugins/base.py Normal file
View File

@ -0,0 +1,53 @@
from contextlib import contextmanager
from functools import partial, reduce
class Plugin(object):
def contribute_to_schema(self, schema):
self.schema = schema
def apply_function(a, b):
return b(a)
class PluginManager(object):
PLUGIN_FUNCTIONS = ('get_default_namedtype_name', )
def __init__(self, schema, plugins=[]):
self.schema = schema
self.plugins = []
for plugin in plugins:
self.add_plugin(plugin)
def add_plugin(self, plugin):
if hasattr(plugin, 'contribute_to_schema'):
plugin.contribute_to_schema(self.schema)
self.plugins.append(plugin)
def get_plugin_functions(self, function):
for plugin in self.plugins:
if not hasattr(plugin, function):
continue
yield getattr(plugin, function)
def __getattr__(self, name):
functions = self.get_plugin_functions(name)
return partial(reduce, apply_function, functions)
def __contains__(self, name):
return name in self.PLUGIN_FUNCTIONS
@contextmanager
def context_execution(self, **executor):
contexts = []
functions = self.get_plugin_functions('context_execution')
for f in functions:
context = f(executor)
executor = context.__enter__()
contexts.append((context, executor))
yield executor
for context, value in contexts[::-1]:
context.__exit__(None, None, None)

View File

@ -0,0 +1,7 @@
from ..utils import to_camel_case
class CamelCase(object):
def get_default_namedtype_name(self, value):
return to_camel_case(value)

View File

@ -34,8 +34,7 @@ schema = Schema(query=Query, mutation=MyResultMutation)
def test_mutation_arguments(): def test_mutation_arguments():
assert ChangeNumber.arguments assert ChangeNumber.arguments
assert list(ChangeNumber.arguments) == ['input'] assert 'input' in schema.T(ChangeNumber.arguments)
assert 'input' in ChangeNumber.arguments
inner_type = ChangeNumber.input_type inner_type = ChangeNumber.input_type
client_mutation_id_field = inner_type._meta.fields_map[ client_mutation_id_field = inner_type._meta.fields_map[
'client_mutation_id'] 'client_mutation_id']

View File

@ -4,6 +4,7 @@ from collections import Iterable
from functools import wraps from functools import wraps
import six import six
from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.node.node import to_global_id from graphql_relay.node.node import to_global_id

View File

@ -3,8 +3,9 @@ from .proxy_snake_dict import ProxySnakeDict
from .caching import cached_property, memoize from .caching import cached_property, memoize
from .misc import enum_to_graphql_enum from .misc import enum_to_graphql_enum
from .resolve_only_args import resolve_only_args from .resolve_only_args import resolve_only_args
from .lazylist import LazyList
__all__ = ['to_camel_case', 'to_snake_case', 'ProxySnakeDict', __all__ = ['to_camel_case', 'to_snake_case', 'ProxySnakeDict',
'cached_property', 'memoize', 'enum_to_graphql_enum', 'cached_property', 'memoize', 'enum_to_graphql_enum',
'resolve_only_args'] 'resolve_only_args', 'LazyList']

View File

@ -0,0 +1,43 @@
class LazyList(object):
def __init__(self, origin, state=None):
self._origin = origin
self._state = state or []
self._origin_iter = None
self._finished = False
def __iter__(self):
return self if not self._finished else iter(self._state)
def iter(self):
return self.__iter__()
def __len__(self):
return self._origin.__len__()
def __next__(self):
try:
if not self._origin_iter:
self._origin_iter = self._origin.__iter__()
n = next(self._origin_iter)
except StopIteration as e:
self._finished = True
raise e
else:
self._state.append(n)
return n
def next(self):
return self.__next__()
def __getitem__(self, key):
item = self._origin[key]
if isinstance(key, slice):
return self.__class__(item)
return item
def __getattr__(self, name):
return getattr(self._origin, name)
def __repr__(self):
return "<{} {}>".format(self.__class__.__name__, repr(self._origin))

View File

@ -0,0 +1,23 @@
from py.test import raises
from ..lazylist import LazyList
def test_lazymap():
data = list(range(10))
lm = LazyList(data)
assert len(lm) == 10
assert lm[1] == 1
assert isinstance(lm[1:4], LazyList)
assert lm.append == data.append
assert repr(lm) == '<LazyList [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]>'
def test_lazymap_iter():
data = list(range(2))
lm = LazyList(data)
iter_lm = iter(lm)
assert iter_lm.next() == 0
assert iter_lm.next() == 1
with raises(StopIteration):
iter_lm.next()

View File

@ -24,9 +24,9 @@ class PyTest(TestCommand):
setup( setup(
name='graphene', name='graphene',
version='0.4.2', version='0.5.0',
description='Graphene: Python DSL for GraphQL', description='GraphQL Framework for Python',
long_description=open('README.rst').read(), long_description=open('README.rst').read(),
url='https://github.com/graphql-python/graphene', url='https://github.com/graphql-python/graphene',
@ -66,9 +66,9 @@ setup(
], ],
extras_require={ extras_require={
'django': [ 'django': [
'Django>=1.6.0,<1.9', 'Django>=1.6.0',
'singledispatch>=3.4.0.3', 'singledispatch>=3.4.0.3',
'graphql-django-view>=1.0.0', 'graphql-django-view>=1.1.0',
], ],
}, },

View File

@ -1,6 +1,7 @@
SECRET_KEY = 1 SECRET_KEY = 1
INSTALLED_APPS = [ INSTALLED_APPS = [
'graphene.contrib.django.tests',
'examples.starwars_django', 'examples.starwars_django',
] ]