Reformatted files using black

This commit is contained in:
Syrus Akbary 2018-07-19 16:51:33 -07:00
parent 96789b291f
commit 54ef52e1c6
41 changed files with 1478 additions and 1539 deletions

View File

@ -1,14 +1,6 @@
from .types import (
DjangoObjectType,
)
from .fields import (
DjangoConnectionField,
)
from .types import DjangoObjectType
from .fields import DjangoConnectionField
__version__ = '2.1rc1'
__version__ = "2.1rc1"
__all__ = [
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]
__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]

View File

@ -7,7 +7,7 @@ try:
# and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
except ImportError:
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
try:

View File

@ -1,8 +1,22 @@
from django.db import models
from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
NonNull, String, UUID, DateTime, Date, Time)
from graphene import (
ID,
Boolean,
Dynamic,
Enum,
Field,
Float,
Int,
List,
NonNull,
String,
UUID,
DateTime,
Date,
Time,
)
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name
@ -32,7 +46,7 @@ def get_choices(choices):
else:
name = convert_choice_name(value)
while name in converted_names:
name += '_' + str(len(converted_names))
name += "_" + str(len(converted_names))
converted_names.append(name)
description = help_text
yield name, value, description
@ -43,16 +57,15 @@ def convert_django_field_with_choices(field, registry=None):
converted = registry.get_converted_field(field)
if converted:
return converted
choices = getattr(field, 'choices', None)
choices = getattr(field, "choices", None)
if choices:
meta = field.model._meta
name = to_camel_case('{}_{}'.format(meta.object_name, field.name))
name = to_camel_case("{}_{}".format(meta.object_name, field.name))
choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object):
@property
def description(self):
return named_choices_descriptions[self.name]
@ -69,8 +82,8 @@ def convert_django_field_with_choices(field, registry=None):
@singledispatch
def convert_django_field(field, registry=None):
raise Exception(
"Don't know how to convert the Django field %s (%s)" %
(field, field.__class__))
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
)
@convert_django_field.register(models.CharField)
@ -147,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
# We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance
null = getattr(field, 'null', True)
null = getattr(field, "null", True)
return Field(_type, required=not null)
return Dynamic(dynamic_type)
@ -171,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)

View File

@ -1,4 +1,4 @@
from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]

View File

@ -7,7 +7,6 @@ from .types import DjangoDebug
class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
@ -38,20 +37,21 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object):
def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, 'django_debug', None)
django_debug = getattr(context, "django_debug", None)
if not django_debug:
if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts')
raise Exception("DjangoDebug cannot be executed in None contexts")
try:
context.django_debug = DjangoDebugContext()
except Exception:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
context.__class__.__name__
))
if info.schema.get_type('DjangoDebug') == info.return_type:
raise Exception(
"DjangoDebug need the context to be writable, context received: {}.".format(
context.__class__.__name__
)
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, info, **args)
context.django_debug.add_promise(promise)

View File

@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@ -35,7 +34,7 @@ recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, '_graphene_cursor'):
if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor
def cursor():
@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
def unwrap_cursor(connection):
if hasattr(connection, '_graphene_cursor'):
if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor
del connection._graphene_cursor
@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.items())
return dict((key, self._quote_expr(value)) for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
return '(encoded string)'
return "(encoded string)"
def _record(self, method, sql, params):
start_time = time()
@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
return method(sql, params)
finally:
stop_time = time()
duration = (stop_time - start_time)
_params = ''
duration = stop_time - start_time
_params = ""
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default')
alias = getattr(self.db, "alias", "default")
conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown')
vendor = getattr(conn, "vendor", "unknown")
params = {
'vendor': vendor,
'alias': alias,
'sql': self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'start_time': start_time,
'stop_time': stop_time,
'is_slow': duration > 10,
'is_select': sql.lower().strip().startswith('select'),
"vendor": vendor,
"alias": alias,
"sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
),
"duration": duration,
"raw_sql": sql,
"params": _params,
"start_time": start_time,
"stop_time": stop_time,
"is_slow": duration > 10,
"is_select": sql.lower().strip().startswith("select"),
}
if vendor == 'postgresql':
if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = 'unknown'
params.update({
'trans_id': self.logger.get_transaction_id(alias),
'trans_status': conn.get_transaction_status(),
'iso_level': iso_level,
'encoding': conn.encoding,
})
iso_level = "unknown"
params.update(
{
"trans_id": self.logger.get_transaction_id(alias),
"trans_status": conn.get_transaction_status(),
"iso_level": iso_level,
"encoding": conn.encoding,
}
)
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility

View File

@ -12,31 +12,31 @@ from ..types import DjangoDebug
class context(object):
pass
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_field():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
query = '''
query = """
query ReporterQuery {
reporter {
lastName
@ -47,43 +47,40 @@ def test_should_query_field():
}
}
}
'''
"""
expected = {
'reporter': {
'lastName': 'ABA',
"reporter": {"lastName": "ABA"},
"__debug": {
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
},
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
def test_should_query_list():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters {
lastName
@ -94,45 +91,38 @@ def test_should_query_list():
}
}
}
'''
"""
expected = {
'allReporters': [{
'lastName': 'ABA',
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.all().query)
}]
}
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
"__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
def test_should_query_connection():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@ -147,48 +137,41 @@ def test_should_query_connection():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query
assert result.data["__debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query
assert result.data["__debug"]["sql"][1]["rawSql"] == query

View File

@ -13,7 +13,6 @@ from .utils import maybe_queryset
class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs):
super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
@ -30,25 +29,28 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
self.on = kwargs.pop('on', False)
self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
'max_limit',
graphene_settings.RELAY_CONNECTION_MAX_LIMIT
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
)
self.enforce_first_or_last = kwargs.pop(
'enforce_first_or_last',
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
"enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
assert issubclass(
_type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
return _type._meta.connection
@property
@ -100,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
return connection
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, info, **args):
first = args.get('first')
last = args.get('last')
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
if enforce_first_or_last:
assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
"You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
"Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit)
args["first"] = min(first, max_limit)
if last:
assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
"Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
).format(last, info.field_name, max_limit)
args['last'] = min(last, max_limit)
args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
@ -138,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
self.type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last
self.enforce_first_or_last,
)

View File

@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`", ImportWarning
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = ['DjangoFilterConnectionField',
'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter']
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
]

View File

@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(self, type, fields=None, order_by=None,
extra_filter_meta=None, filterset_class=None,
*args, **kwargs):
def __init__(
self,
type,
fields=None,
order_by=None,
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs
):
self._fields = fields
self._provided_filterset_class = filterset_class
self._filterset_class = None
@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self):
if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model,
fields=fields)
meta = dict(model=self.model, fields=fields)
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta)
self._filterset_class = get_filterset_class(
self._provided_filterset_class, **meta
)
return self._filterset_class
@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
# See related PR: https://github.com/graphql-python/graphene-django/pull/126
assert not (default_queryset.query.low_mark and queryset.query.low_mark), (
'Received two sliced querysets (low mark) in the connection, please slice only in one.'
)
assert not (default_queryset.query.high_mark and queryset.query.high_mark), (
'Received two sliced querysets (high mark) in the connection, please slice only in one.'
)
assert not (
default_queryset.query.low_mark and queryset.query.low_mark
), "Received two sliced querysets (low mark) in the connection, please slice only in one."
assert not (
default_queryset.query.high_mark and queryset.query.high_mark
), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits()
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset)
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
default_queryset, queryset
)
queryset.query.set_limits(low, high)
return queryset
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, filterset_class, filtering_args,
root, info, **args):
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
data=filter_kwargs,
queryset=default_manager.get_queryset(),
request=info.context
request=info.context,
).qs
return super(DjangoFilterConnectionField, cls).connection_resolver(
@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.max_limit,
self.enforce_first_or_last,
self.filterset_class,
self.filtering_args
self.filtering_args,
)

View File

@ -28,26 +28,19 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: {
'filter_class': GlobalIDFilter,
},
models.OneToOneField: {
'filter_class': GlobalIDFilter,
},
models.ForeignKey: {
'filter_class': GlobalIDFilter,
},
models.ManyToManyField: {
'filter_class': GlobalIDMultipleChoiceFilter,
}
models.AutoField: {"filter_class": GlobalIDFilter},
models.OneToOneField: {"filter_class": GlobalIDFilter},
models.ForeignKey: {"filter_class": GlobalIDFilter},
models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
}
class GrapheneFilterSetMixin(BaseFilterSet):
FILTER_DEFAULTS = dict(itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(),
GRAPHENE_FILTER_SET_OVERRIDES.items()
))
FILTER_DEFAULTS = dict(
itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
)
)
@classmethod
def filter_for_reverse_field(cls, f, name):
@ -62,10 +55,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
except AttributeError:
rel = f.field.rel
default = {
'name': name,
'label': capfirst(rel.related_name)
}
default = {"name": name, "label": capfirst(rel.related_name)}
if rel.multiple:
# For to-many relationships
return GlobalIDMultipleChoiceFilter(**default)
@ -78,25 +68,20 @@ def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality
"""
return type(
'Graphene{}'.format(filterset_class.__name__),
"Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin),
{},
)
def custom_filterset_factory(model, filterset_base_class=FilterSet,
**meta):
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data
"""
meta.update({
'model': model,
})
meta_class = type(str('Meta'), (object,), meta)
meta.update({"model": model})
meta_class = type(str("Meta"), (object,), meta)
filterset = type(
str('%sFilterSet' % model._meta.object_name),
str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin),
{
'Meta': meta_class
}
{"Meta": meta_class},
)
return filterset

View File

@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = {
'headline': ['exact', 'icontains'],
'pub_date': ['gt', 'lt', 'exact'],
'reporter': ['exact'],
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"],
}
order_by = OrderingFilter(fields=('pub_date',))
order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['first_name', 'last_name', 'email', 'pets']
fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=('pub_date',))
order_by = OrderingFilter(fields=("pub_date",))
class PetFilter(django_filters.FilterSet):
class Meta:
model = Pet
fields = ['name']
fields = ["name"]

View File

@ -5,8 +5,7 @@ import pytest
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField)
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED:
import django_filters
from django_filters import FilterSet, NumberFilter
from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter)
from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter
from graphene_django.filter import (
GlobalIDFilter,
DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter,
)
from graphene_django.filter.tests.filters import (
ArticleFilter,
PetFilter,
ReporterFilter,
)
else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible'))
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED:
class ArticleNode(DjangoObjectType):
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ('headline', )
interfaces = (Node,)
filter_fields = ("headline",)
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node, )
interfaces = (Node,)
# schema = Schema()
@ -59,58 +65,47 @@ def get_args(field):
def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'order_by')
ignore = ("after", "before", "first", "last", "order_by")
args = get_args(field)
actual = [
name
for name in args
if name not in ignore and not name.startswith('_')
]
assert set(arguments) == set(actual), \
'Expected arguments ({}) did not match actual ({})'.format(
arguments,
actual
)
actual = [name for name in args if name not in ignore and not name.startswith("_")]
assert set(arguments) == set(
actual
), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
def assert_orderable(field):
args = get_args(field)
assert 'order_by' in args, \
'Field cannot be ordered'
assert "order_by" in args, "Field cannot be ordered"
def assert_not_orderable(field):
args = get_args(field)
assert 'order_by' not in args, \
'Field can be ordered'
assert "order_by" not in args, "Field can be ordered"
def test_filter_explicit_filterset_arguments():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
assert_arguments(field,
'headline', 'headline__icontains',
'pub_date', 'pub_date__gt', 'pub_date__lt',
'reporter',
)
assert_arguments(
field,
"headline",
"headline__icontains",
"pub_date",
"pub_date__gt",
"pub_date__lt",
"reporter",
)
def test_filter_shortcut_filterset_arguments_list():
field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter'])
assert_arguments(field,
'pub_date',
'reporter',
)
field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
assert_arguments(field, "pub_date", "reporter")
def test_filter_shortcut_filterset_arguments_dict():
field = DjangoFilterConnectionField(ArticleNode, fields={
'headline': ['exact', 'icontains'],
'reporter': ['exact'],
})
assert_arguments(field,
'headline', 'headline__icontains',
'reporter',
)
field = DjangoFilterConnectionField(
ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
)
assert_arguments(field, "headline", "headline__icontains", "reporter")
def test_filter_explicit_filterset_orderable():
@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta():
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
'exclude': ('headline', )
})
assert 'headline' not in field.filterset_class.get_fields()
field = DjangoFilterConnectionField(
ArticleNode, extra_filter_meta={"exclude": ("headline",)}
)
assert "headline" not in field.filterset_class.get_fields()
def test_filter_shortcut_filterset_context():
class ArticleContextFilter(django_filters.FilterSet):
class Meta:
model = Article
exclude = set()
@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context():
return qs.filter(reporter=self.request.reporter)
class Query(ObjectType):
context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter)
context_articles = DjangoFilterConnectionField(
ArticleNode, filterset_class=ArticleContextFilter
)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
class context(object):
reporter = r2
query = '''
query = """
query {
contextArticles {
edges {
@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query, context_value=context())
assert not result.errors
assert len(result.data['contextArticles']['edges']) == 1
assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2'
assert len(result.data["contextArticles"]["edges"]) == 1
assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'first_name', 'articles')
assert_arguments(field, "first_name", "articles")
assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ['headline', 'reporter']
interfaces = (Node,)
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
article = Field(ArticleFilterNode)
schema = Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields['articles'].get_type()
assert_arguments(articles_field, 'headline', 'reporter')
articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field)
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
interfaces = (Node, )
interfaces = (Node,)
model = Article
filter_fields = ['headline', 'reporter']
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -243,12 +246,22 @@ def test_filter_filterset_related_results():
reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
)
query = '''
query = """
query {
allReporters {
edges {
@ -264,123 +277,134 @@ def test_filter_filterset_related_results():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get back a single article for each reporter
assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1
assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1
assert (
len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
)
assert (
len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
)
def test_global_id_field_implicit():
field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id']
id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = ['id']
fields = ["id"]
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id']
id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_filterset_descriptions():
class ArticleIdFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = ['id']
fields = ["id"]
max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time")
max_time = django_filters.NumberFilter(
method="filter_max_time", label="The maximum time"
)
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
max_time = field.args['max_time']
max_time = field.args["max_time"]
assert isinstance(max_time, Argument)
assert max_time.type == Float
assert max_time.description == 'The maximum time'
assert max_time.description == "The maximum time"
def test_global_id_field_relation():
field = DjangoFilterConnectionField(ArticleNode, fields=['reporter'])
field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['reporter']
id_filter = filterset_class.base_filters["reporter"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_multiple_field_implicit():
field = DjangoFilterConnectionField(ReporterNode, fields=['pets'])
field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets']
multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit():
class ReporterPetsFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['pets']
fields = ["pets"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets']
multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_implicit_reverse():
field = DjangoFilterConnectionField(ReporterNode, fields=['articles'])
field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles']
multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit_reverse():
class ReporterPetsFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['articles']
fields = ["articles"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles']
multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = {
'first_name': ['icontains']
}
interfaces = (Node,)
filter_fields = {"first_name": ["icontains"]}
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com')
r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com')
r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com')
r1 = Reporter.objects.create(
first_name="A test user", last_name="Last Name", email="test1@test.com"
)
r2 = Reporter.objects.create(
first_name="Other test user",
last_name="Other Last Name",
email="test2@test.com",
)
r3 = Reporter.objects.create(
first_name="Random", last_name="RandomLast", email="random@test.com"
)
query = '''
query = """
query {
allReporters(firstName_Icontains: "test") {
edges {
@ -390,12 +414,12 @@ def test_filter_filterset_related_results():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get two reporters
assert len(result.data['allReporters']['edges']) == 2
assert len(result.data["allReporters"]["edges"]) == 2
def test_recursive_filter_connection():
@ -407,79 +431,73 @@ def test_recursive_filter_connection():
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode
assert (
ReporterFilterNode._meta.fields["child_reporters"].node_type
== ReporterFilterNode
)
def test_should_query_filter_node_limit():
class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit')
limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value):
return queryset[:value]
class Meta:
model = Reporter
fields = ['first_name', ]
fields = ["first_name"]
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ('lang', )
interfaces = (Node,)
filter_fields = ("lang",)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
ReporterType,
filterset_class=ReporterFilter
ReporterType, filterset_class=ReporterFilter
)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')
return Reporter.objects.order_by("a_choice")
Reporter.objects.create(
first_name='Bob',
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
Article.objects.create(
headline='Article Node 1',
headline="Article Node 1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r,
editor=r,
lang='es'
lang="es",
)
Article.objects.create(
headline='Article Node 2',
headline="Article Node 2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r,
editor=r,
lang='en'
lang="en",
)
schema = Schema(query=Query)
query = '''
query = """
query NodeFilteringQuery {
allReporters(limit: 1) {
edges {
@ -498,24 +516,23 @@ def test_should_query_filter_node_limit():
}
}
}
'''
"""
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjI=',
'firstName': 'John',
'articles': {
'edges': [{
'node': {
'id': 'QXJ0aWNsZVR5cGU6MQ==',
'lang': 'ES'
}
}]
"allReporters": {
"edges": [
{
"node": {
"id": "UmVwb3J0ZXJUeXBlOjI=",
"firstName": "John",
"articles": {
"edges": [
{"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}}
]
},
}
}
}]
]
}
}
@ -526,45 +543,37 @@ def test_should_query_filter_node_limit():
def test_should_query_filter_node_double_limit_raises():
class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit')
limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value):
return queryset[:value]
class Meta:
model = Reporter
fields = ['first_name', ]
fields = ["first_name"]
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
ReporterType,
filterset_class=ReporterFilter
ReporterType, filterset_class=ReporterFilter
)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2]
return Reporter.objects.order_by("a_choice")[:2]
Reporter.objects.create(
first_name='Bob',
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
schema = Schema(query=Query)
query = '''
query = """
query NodeFilteringQuery {
allReporters(limit: 1) {
edges {
@ -575,41 +584,40 @@ def test_should_query_filter_node_double_limit_raises():
}
}
}
'''
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert str(result.errors[0]) == (
'Received two sliced querysets (high mark) in the connection, please slice only in one.'
"Received two sliced querysets (high mark) in the connection, please slice only in one."
)
def test_order_by_is_perserved():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
filter_fields = ()
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean())
all_reporters = DjangoFilterConnectionField(
ReporterType, reverse_order=Boolean()
)
def resolve_all_reporters(self, info, reverse_order=False, **args):
reporters = Reporter.objects.order_by('first_name')
reporters = Reporter.objects.order_by("first_name")
if reverse_order:
return reporters.reverse()
return reporters
Reporter.objects.create(
first_name='b',
)
r = Reporter.objects.create(
first_name='a',
)
Reporter.objects.create(first_name="b")
r = Reporter.objects.create(first_name="a")
schema = Schema(query=Query)
query = '''
query = """
query NodeFilteringQuery {
allReporters(first: 1) {
edges {
@ -619,23 +627,14 @@ def test_order_by_is_perserved():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'firstName': 'a',
}
}]
}
}
"""
expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
reverse_query = '''
reverse_query = """
query NodeFilteringQuery {
allReporters(first: 1, reverseOrder: true) {
edges {
@ -645,33 +644,26 @@ def test_order_by_is_perserved():
}
}
}
'''
"""
reverse_expected = {
'allReporters': {
'edges': [{
'node': {
'firstName': 'b',
}
}]
}
}
reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}}
reverse_result = schema.execute(reverse_query)
assert not reverse_result.errors
assert reverse_result.data == reverse_expected
def test_annotation_is_perserved():
class ReporterType(DjangoObjectType):
full_name = String()
def resolve_full_name(instance, info, **args):
return instance.full_name
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
filter_fields = ()
class Query(ObjectType):
@ -679,17 +671,16 @@ def test_annotation_is_perserved():
def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate(
full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField())
full_name=Concat(
"first_name", Value(" "), "last_name", output_field=TextField()
)
)
Reporter.objects.create(
first_name='John',
last_name='Doe',
)
Reporter.objects.create(first_name="John", last_name="Doe")
schema = Schema(query=Query)
query = '''
query = """
query NodeFilteringQuery {
allReporters(first: 1) {
edges {
@ -699,16 +690,8 @@ def test_annotation_is_perserved():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'fullName': 'John Doe',
}
}]
}
}
"""
expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
result = schema.execute(query)

View File

@ -14,8 +14,7 @@ singledispatch = import_single_dispatch()
def convert_form_field(field):
raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) "
"to Graphene type" %
(field, field.__class__)
"to Graphene type" % (field, field.__class__)
)

View File

@ -8,9 +8,7 @@ from graphql_relay import from_global_id
class GlobalIDFormField(Field):
default_error_messages = {
'invalid': _('Invalid ID specified.'),
}
default_error_messages = {"invalid": _("Invalid ID specified.")}
def clean(self, value):
if not value and not self.required:
@ -19,21 +17,21 @@ class GlobalIDFormField(Field):
try:
_type, _id = from_global_id(value)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
raise ValidationError(self.error_messages['invalid'])
raise ValidationError(self.error_messages["invalid"])
try:
CharField().clean(_id)
CharField().clean(_type)
except ValidationError:
raise ValidationError(self.error_messages['invalid'])
raise ValidationError(self.error_messages["invalid"])
return value
class GlobalIDMultipleChoiceField(MultipleChoiceField):
default_error_messages = {
'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'),
'invalid_list': _('Enter a list of values.'),
"invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
"invalid_list": _("Enter a list of values."),
}
def valid_value(self, value):

View File

@ -5,6 +5,7 @@ import graphene
from graphene import Field, InputField
from graphene.relay.mutation import ClientIDMutation
from graphene.types.mutation import MutationOptions
# from graphene.types.inputobjecttype import (
# InputObjectTypeOptions,
# InputObjectType,
@ -21,7 +22,8 @@ def fields_for_form(form, only_fields, exclude_fields):
for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in exclude_fields # or
name
in exclude_fields # or
# name in already_created_fields
)
@ -57,12 +59,12 @@ class BaseDjangoFormMutation(ClientIDMutation):
@classmethod
def get_form_kwargs(cls, root, info, **input):
kwargs = {'data': input}
kwargs = {"data": input}
pk = input.pop('id', None)
pk = input.pop("id", None)
if pk:
instance = cls._meta.model._default_manager.get(pk=pk)
kwargs['instance'] = instance
kwargs["instance"] = instance
return kwargs
@ -100,11 +102,12 @@ class DjangoFormMutation(BaseDjangoFormMutation):
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(cls, form_class=None,
only_fields=(), exclude_fields=(), **options):
def __init_subclass_with_meta__(
cls, form_class=None, only_fields=(), exclude_fields=(), **options
):
if not form_class:
raise Exception('form_class is required for DjangoFormMutation')
raise Exception("form_class is required for DjangoFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
@ -112,16 +115,12 @@ class DjangoFormMutation(BaseDjangoFormMutation):
_meta = DjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.fields = yank_fields_from_attrs(
output_fields,
_as=Field,
)
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(
input_fields,
_as=InputField,
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
super(DjangoFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod
def perform_mutate(cls, form, info):
@ -141,21 +140,28 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(cls, form_class=None, model=None, return_field_name=None,
only_fields=(), exclude_fields=(), **options):
def __init_subclass_with_meta__(
cls,
form_class=None,
model=None,
return_field_name=None,
only_fields=(),
exclude_fields=(),
**options
):
if not form_class:
raise Exception('form_class is required for DjangoModelFormMutation')
raise Exception("form_class is required for DjangoModelFormMutation")
if not model:
model = form_class._meta.model
if not model:
raise Exception('model is required for DjangoModelFormMutation')
raise Exception("model is required for DjangoModelFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
input_fields['id'] = graphene.ID()
input_fields["id"] = graphene.ID()
registry = get_global_registry()
model_type = registry.get_type_for_model(model)
@ -171,19 +177,11 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
_meta.form_class = form_class
_meta.model = model
_meta.return_field_name = return_field_name
_meta.fields = yank_fields_from_attrs(
output_fields,
_as=Field,
)
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(
input_fields,
_as=InputField,
)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta,
input_fields=input_fields,
**options
_meta=_meta, input_fields=input_fields, **options
)
@classmethod

View File

@ -2,24 +2,36 @@ from django import forms
from py.test import raises
import graphene
from graphene import String, Int, Boolean, Float, ID, UUID, List, NonNull, DateTime, Date, Time
from graphene import (
String,
Int,
Boolean,
Float,
ID,
UUID,
List,
NonNull,
DateTime,
Date,
Time,
)
from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text='Custom Help Text')
field = django_field(*args, help_text="Custom Help Text")
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == 'Custom Help Text'
assert field.description == "Custom Help Text"
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_form_field(None)
assert 'Don\'t know how to convert the Django form field' in str(excinfo.value)
assert "Don't know how to convert the Django form field" in str(excinfo.value)
def test_should_date_convert_date():
@ -59,11 +71,11 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string():
assert_conversion(forms.RegexField, String, '[0-9]+')
assert_conversion(forms.RegexField, String, "[0-9]+")
def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'):
if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID)

View File

@ -11,18 +11,18 @@ class MyForm(forms.Form):
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = ('name',)
fields = ("name",)
def test_needs_form_class():
with raises(Exception) as exc:
class MyMutation(DjangoFormMutation):
pass
assert exc.value.args[0] == 'form_class is required for DjangoFormMutation'
assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
def test_has_output_fields():
@ -30,7 +30,7 @@ def test_has_output_fields():
class Meta:
form_class = MyForm
assert 'errors' in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
@ -38,19 +38,18 @@ def test_has_input_fields():
class Meta:
form_class = MyForm
assert 'text' in MyMutation.Input._meta.fields
assert "text" in MyMutation.Input._meta.fields
class ModelFormMutationTests(TestCase):
def test_default_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, 'pet')
self.assertIn('pet', PetMutation._meta.fields)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("pet", PetMutation._meta.fields)
def test_return_field_name_is_camelcased(self):
class PetMutation(DjangoModelFormMutation):
@ -59,31 +58,31 @@ class ModelFormMutationTests(TestCase):
model = FilmDetails
self.assertEqual(PetMutation._meta.model, FilmDetails)
self.assertEqual(PetMutation._meta.return_field_name, 'filmDetails')
self.assertEqual(PetMutation._meta.return_field_name, "filmDetails")
def test_custom_return_field_name(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
model = Film
return_field_name = 'animal'
return_field_name = "animal"
self.assertEqual(PetMutation._meta.model, Film)
self.assertEqual(PetMutation._meta.return_field_name, 'animal')
self.assertIn('animal', PetMutation._meta.fields)
self.assertEqual(PetMutation._meta.return_field_name, "animal")
self.assertIn("animal", PetMutation._meta.fields)
def test_model_form_mutation_mutate(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
pet = Pet.objects.create(name='Axel')
pet = Pet.objects.create(name="Axel")
result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name='Mia')
result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia")
self.assertEqual(Pet.objects.count(), 1)
pet.refresh_from_db()
self.assertEqual(pet.name, 'Mia')
self.assertEqual(pet.name, "Mia")
self.assertEqual(result.errors, [])
def test_model_form_mutation_updates_existing_(self):
@ -91,11 +90,11 @@ class ModelFormMutationTests(TestCase):
class Meta:
form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None, name='Mia')
result = PetMutation.mutate_and_get_payload(None, None, name="Mia")
self.assertEqual(Pet.objects.count(), 1)
pet = Pet.objects.get()
self.assertEqual(pet.name, 'Mia')
self.assertEqual(pet.name, "Mia")
self.assertEqual(result.errors, [])
def test_model_form_mutation_mutate_invalid_form(self):
@ -109,5 +108,5 @@ class ModelFormMutationTests(TestCase):
self.assertEqual(Pet.objects.count(), 0)
self.assertEqual(len(result.errors), 1)
self.assertEqual(result.errors[0].field, 'name')
self.assertEqual(result.errors[0].messages, ['This field is required.'])
self.assertEqual(result.errors[0].field, "name")
self.assertEqual(result.errors[0].messages, ["This field is required."])

View File

@ -7,43 +7,45 @@ from graphene_django.settings import graphene_settings
class CommandArguments(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
'--schema',
"--schema",
type=str,
dest='schema',
dest="schema",
default=graphene_settings.SCHEMA,
help='Django app containing schema to dump, e.g. myproject.core.schema.schema')
help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
)
parser.add_argument(
'--out',
"--out",
type=str,
dest='out',
dest="out",
default=graphene_settings.SCHEMA_OUTPUT,
help='Output file (default: schema.json)')
help="Output file (default: schema.json)",
)
parser.add_argument(
'--indent',
"--indent",
type=int,
dest='indent',
dest="indent",
default=graphene_settings.SCHEMA_INDENT,
help='Output file indent (default: None)')
help="Output file indent (default: None)",
)
class Command(CommandArguments):
help = 'Dump Graphene schema JSON to file'
help = "Dump Graphene schema JSON to file"
can_import_settings = True
def save_file(self, out, schema_dict, indent):
with open(out, 'w') as outfile:
with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent)
def handle(self, *args, **options):
options_schema = options.get('schema')
options_schema = options.get("schema")
if options_schema and type(options_schema) is str:
module_str, schema_name = options_schema.rsplit('.', 1)
module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name)
@ -53,16 +55,18 @@ class Command(CommandArguments):
else:
schema = graphene_settings.SCHEMA
out = options.get('out') or graphene_settings.SCHEMA_OUTPUT
out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
if not schema:
raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema')
raise CommandError(
"Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
)
indent = options.get('indent')
schema_dict = {'data': schema.introspect()}
indent = options.get("indent")
schema_dict = {"data": schema.introspect()}
self.save_file(out, schema_dict, indent)
style = getattr(self, 'style', None)
success = getattr(style, 'SUCCESS', lambda x: x)
style = getattr(self, "style", None)
success = getattr(style, "SUCCESS", lambda x: x)
self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out))
self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out))

View File

@ -1,20 +1,21 @@
class Registry(object):
def __init__(self):
self._registry = {}
self._field_registry = {}
def register(self, cls):
from .types import DjangoObjectType
assert issubclass(
cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
cls.__name__)
assert cls._meta.registry == self, 'Registry for a Model have to match.'
cls, DjangoObjectType
), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# )
if not getattr(cls._meta, 'skip_registry', False):
if not getattr(cls._meta, "skip_registry", False):
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model):

View File

@ -6,20 +6,16 @@ import graphene
from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation
from graphene.types.objecttype import (
yank_fields_from_attrs
)
from graphene.types.objecttype import yank_fields_from_attrs
from .serializer_converter import (
convert_serializer_field
)
from .serializer_converter import convert_serializer_field
from .types import ErrorType
class SerializerMutationOptions(MutationOptions):
lookup_field = None
model_class = None
model_operations = ['create', 'update']
model_operations = ["create", "update"]
serializer_class = None
@ -28,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in exclude_fields # or
name
in exclude_fields # or
# name in already_created_fields
)
@ -44,49 +41,54 @@ class SerializerMutation(ClientIDMutation):
abstract = True
errors = graphene.List(
ErrorType,
description='May contain more than one error for same field.'
ErrorType, description="May contain more than one error for same field."
)
@classmethod
def __init_subclass_with_meta__(cls, lookup_field=None,
serializer_class=None, model_class=None,
model_operations=['create', 'update'],
only_fields=(), exclude_fields=(), **options):
def __init_subclass_with_meta__(
cls,
lookup_field=None,
serializer_class=None,
model_class=None,
model_operations=["create", "update"],
only_fields=(),
exclude_fields=(),
**options
):
if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation')
raise Exception("serializer_class is required for the SerializerMutation")
if 'update' not in model_operations and 'create' not in model_operations:
if "update" not in model_operations and "create" not in model_operations:
raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class()
if model_class is None:
serializer_meta = getattr(serializer_class, 'Meta', None)
serializer_meta = getattr(serializer_class, "Meta", None)
if serializer_meta:
model_class = getattr(serializer_meta, 'model', None)
model_class = getattr(serializer_meta, "model", None)
if lookup_field is None and model_class:
lookup_field = model_class._meta.pk.name
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
input_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=True
)
output_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=False
)
_meta = SerializerMutationOptions(cls)
_meta.lookup_field = lookup_field
_meta.model_operations = model_operations
_meta.serializer_class = serializer_class
_meta.model_class = model_class
_meta.fields = yank_fields_from_attrs(
output_fields,
_as=Field,
)
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(
input_fields,
_as=InputField,
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(SerializerMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
@ -94,24 +96,26 @@ class SerializerMutation(ClientIDMutation):
model_class = cls._meta.model_class
if model_class:
if 'update' in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(model_class, **{
lookup_field: input[lookup_field]})
elif 'create' in cls._meta.model_operations:
if "update" in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(
model_class, **{lookup_field: input[lookup_field]}
)
elif "create" in cls._meta.model_operations:
instance = None
else:
raise Exception(
'Invalid update operation. Input parameter "{}" required.'.format(
lookup_field
))
)
)
return {
'instance': instance,
'data': input,
'context': {'request': info.context}
"instance": instance,
"data": input,
"context": {"request": info.context},
}
return {'data': input, 'context': {'request': info.context}}
return {"data": input, "context": {"request": info.context}}
@classmethod
def mutate_and_get_payload(cls, root, info, **input):

View File

@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True):
graphql_type = get_graphene_type_from_serializer_field(field)
args = []
kwargs = {
'description': field.help_text,
'required': is_input and field.required,
}
kwargs = {"description": field.help_text, "required": is_input and field.required}
# if it is a tuple or a list it means that we are returning
# the graphql type and the child type
if isinstance(graphql_type, (list, tuple)):
kwargs['of_type'] = graphql_type[1]
kwargs["of_type"] = graphql_type[1]
graphql_type = graphql_type[0]
if isinstance(field, serializers.ModelSerializer):
@ -49,9 +46,9 @@ def convert_serializer_field(field, is_input=True):
elif isinstance(field, serializers.ListSerializer):
field = field.child
if is_input:
kwargs['of_type'] = convert_serializer_to_input_type(field.__class__)
kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
else:
del kwargs['of_type']
del kwargs["of_type"]
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
@ -68,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class):
}
return type(
'{}Input'.format(serializer.__class__.__name__),
"{}Input".format(serializer.__class__.__name__),
(graphene.InputObjectType,),
items
items,
)

View File

@ -16,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
# Remove `source=` from the field declaration.
# since we are reusing the same child in when testing the required attribute
if 'child' in kwargs:
kwargs['child'] = copy.deepcopy(kwargs['child'])
if "child" in kwargs:
kwargs["child"] = copy.deepcopy(kwargs["child"])
field = rest_framework_field(**kwargs)
@ -25,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
def assert_conversion(rest_framework_field, graphene_field, **kwargs):
graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs)
graphene_type = _get_type(
rest_framework_field, help_text="Custom Help Text", **kwargs
)
assert isinstance(graphene_type, graphene_field)
graphene_type_required = _get_type(
rest_framework_field, help_text='Custom Help Text', required=True, **kwargs
rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
)
assert isinstance(graphene_type_required, graphene_field)
@ -39,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs):
def test_should_unknown_rest_framework_field_raise_exception():
with raises(Exception) as excinfo:
convert_serializer_field(None)
assert 'Don\'t know how to convert the serializer field' in str(excinfo.value)
assert "Don't know how to convert the serializer field" in str(excinfo.value)
def test_should_char_convert_string():
@ -67,11 +69,11 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string():
assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+')
assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
def test_should_uuid_convert_string():
if hasattr(serializers, 'UUIDField'):
if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String)
@ -79,7 +81,7 @@ def test_should_model_convert_field():
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = None
fields = '__all__'
fields = "__all__"
assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
@ -109,7 +111,9 @@ def test_should_float_convert_float():
def test_should_decimal_convert_float():
assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2)
assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
)
def test_should_list_convert_to_list():
@ -119,7 +123,7 @@ def test_should_list_convert_to_list():
field_a = assert_conversion(
serializers.ListField,
graphene.List,
child=serializers.IntegerField(min_value=0, max_value=100)
child=serializers.IntegerField(min_value=0, max_value=100),
)
assert field_a.of_type == graphene.Int
@ -136,19 +140,23 @@ def test_should_list_serializer_convert_to_list():
class ChildSerializer(serializers.ModelSerializer):
class Meta:
model = FooModel
fields = '__all__'
fields = "__all__"
class ParentSerializer(serializers.ModelSerializer):
child = ChildSerializer(many=True)
class Meta:
model = FooModel
fields = '__all__'
fields = "__all__"
converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=True)
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=True
)
assert isinstance(converted_type, graphene.List)
converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=False)
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=False
)
assert isinstance(converted_type, graphene.List)
assert converted_type.of_type is None
@ -166,7 +174,7 @@ def test_should_file_convert_string():
def test_should_filepath_convert_string():
assert_conversion(serializers.FilePathField, graphene.String, path='/')
assert_conversion(serializers.FilePathField, graphene.String, path="/")
def test_should_ip_convert_string():
@ -182,6 +190,8 @@ def test_should_json_convert_jsonstring():
def test_should_multiplechoicefield_convert_to_list_of_string():
field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3])
field = assert_conversion(
serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
)
assert field.of_type == graphene.String

View File

@ -10,30 +10,33 @@ from ...types import DjangoObjectType
from ..models import MyFakeModel
from ..mutation import SerializerMutation
def mock_info():
return ResolveInfo(
None,
None,
None,
None,
schema=None,
fragments=None,
root_value=None,
operation=None,
variable_values=None,
context=None
)
return ResolveInfo(
None,
None,
None,
None,
schema=None,
fragments=None,
root_value=None,
operation=None,
variable_values=None,
context=None,
)
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeModel
fields = '__all__'
fields = "__all__"
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer):
text = serializers.CharField()
model = MyModelSerializer()
@ -44,10 +47,11 @@ class MySerializer(serializers.Serializer):
def test_needs_serializer_class():
with raises(Exception) as exc:
class MyMutation(SerializerMutation):
pass
assert str(exc.value) == 'serializer_class is required for the SerializerMutation'
assert str(exc.value) == "serializer_class is required for the SerializerMutation"
def test_has_fields():
@ -55,9 +59,9 @@ def test_has_fields():
class Meta:
serializer_class = MySerializer
assert 'text' in MyMutation._meta.fields
assert 'model' in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields
assert "text" in MyMutation._meta.fields
assert "model" in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
@ -65,25 +69,24 @@ def test_has_input_fields():
class Meta:
serializer_class = MySerializer
assert 'text' in MyMutation.Input._meta.fields
assert 'model' in MyMutation.Input._meta.fields
assert "text" in MyMutation.Input._meta.fields
assert "model" in MyMutation.Input._meta.fields
def test_exclude_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
exclude_fields = ['created']
exclude_fields = ["created"]
assert 'cool_name' in MyMutation._meta.fields
assert 'created' not in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields
assert 'cool_name' in MyMutation.Input._meta.fields
assert 'created' not in MyMutation.Input._meta.fields
assert "cool_name" in MyMutation._meta.fields
assert "created" not in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
assert "cool_name" in MyMutation.Input._meta.fields
assert "created" not in MyMutation.Input._meta.fields
def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType):
class Meta:
model = MyFakeModel
@ -92,67 +95,64 @@ def test_nested_model():
class Meta:
serializer_class = MySerializer
model_field = MyMutation._meta.fields['model']
model_field = MyMutation._meta.fields["model"]
assert isinstance(model_field, Field)
assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields['model']
model_input = MyMutation.Input._meta.fields["model"]
model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType)
assert 'cool_name' in model_input_type._meta.fields
assert 'created' in model_input_type._meta.fields
assert "cool_name" in model_input_type._meta.fields
assert "created" in model_input_type._meta.fields
def test_mutate_and_get_payload_success():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{
'text': 'value',
'model': {
'cool_name': 'other_value'
}
})
result = MyMutation.mutate_and_get_payload(
None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
)
assert result.errors is None
@mark.django_db
def test_model_add_mutate_and_get_payload_success():
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{
'cool_name': 'Narf',
})
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert result.errors is None
assert result.cool_name == 'Narf'
assert result.cool_name == "Narf"
assert isinstance(result.created, datetime.datetime)
@mark.django_db
def test_model_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{
'id': instance.id,
'cool_name': 'New Narf',
})
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
)
assert result.errors is None
assert result.cool_name == 'New Narf'
assert result.cool_name == "New Narf"
@mark.django_db
def test_model_invalid_update_mutate_and_get_payload_success():
class InvalidModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['update']
model_operations = ["update"]
with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload(None, mock_info(), **{
'cool_name': 'Narf',
})
result = InvalidModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert '"id" required' in str(exc.value)
def test_mutate_and_get_payload_error():
def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
@ -161,16 +161,19 @@ def test_mutate_and_get_payload_error():
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_model_mutate_and_get_payload_error():
# missing required fields
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_invalid_serializer_operations():
with raises(Exception) as exc:
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['Add']
model_operations = ["Add"]
assert 'model_operations' in str(exc.value)
assert "model_operations" in str(exc.value)

View File

@ -26,27 +26,22 @@ except ImportError:
# Copied shamelessly from Django REST Framework
DEFAULTS = {
'SCHEMA': None,
'SCHEMA_OUTPUT': 'schema.json',
'SCHEMA_INDENT': None,
'MIDDLEWARE': (),
"SCHEMA": None,
"SCHEMA_OUTPUT": "schema.json",
"SCHEMA_INDENT": None,
"MIDDLEWARE": (),
# Set to True if the connection fields must have
# either the first or last argument
'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False,
"RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
# Max items returned in ConnectionFields / FilterConnectionFields
'RELAY_CONNECTION_MAX_LIMIT': 100,
"RELAY_CONNECTION_MAX_LIMIT": 100,
}
if settings.DEBUG:
DEFAULTS['MIDDLEWARE'] += (
'graphene_django.debug.DjangoDebugMiddleware',
)
DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
# List of settings that may be in string import notation.
IMPORT_STRINGS = (
'MIDDLEWARE',
'SCHEMA',
)
IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
def perform_import(val, setting_name):
@ -69,12 +64,17 @@ def import_from_string(val, setting_name):
"""
try:
# Nod to tastypie's use of importlib.
parts = val.split('.')
module_path, class_name = '.'.join(parts[:-1]), parts[-1]
parts = val.split(".")
module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg)
@ -96,8 +96,8 @@ class GrapheneSettings(object):
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'GRAPHENE', {})
if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, "GRAPHENE", {})
return self._user_settings
def __getattr__(self, attr):
@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_graphene_settings(*args, **kwargs):
global graphene_settings
setting, value = kwargs['setting'], kwargs['value']
if setting == 'GRAPHENE':
setting, value = kwargs["setting"], kwargs["value"]
if setting == "GRAPHENE":
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)

View File

@ -3,10 +3,7 @@ from __future__ import absolute_import
from django.db import models
from django.utils.translation import ugettext_lazy as _
CHOICES = (
(1, 'this'),
(2, _('that'))
)
CHOICES = ((1, "this"), (2, _("that")))
class Pet(models.Model):
@ -15,38 +12,43 @@ class Pet(models.Model):
class FilmDetails(models.Model):
location = models.CharField(max_length=30)
film = models.OneToOneField('Film', on_delete=models.CASCADE, related_name='details')
film = models.OneToOneField(
"Film", on_delete=models.CASCADE, related_name="details"
)
class Film(models.Model):
genre = models.CharField(max_length=2, help_text='Genre', choices=[
('do', 'Documentary'),
('ot', 'Other')
], default='ot')
reporters = models.ManyToManyField('Reporter',
related_name='films')
genre = models.CharField(
max_length=2,
help_text="Genre",
choices=[("do", "Documentary"), ("ot", "Other")],
default="ot",
)
reporters = models.ManyToManyField("Reporter", related_name="films")
class DoeReporterManager(models.Manager):
def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
class Reporter(models.Model):
first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30)
email = models.EmailField()
pets = models.ManyToManyField('self')
pets = models.ManyToManyField("self")
a_choice = models.CharField(max_length=30, choices=CHOICES)
objects = models.Manager()
doe_objects = DoeReporterManager()
reporter_type = models.IntegerField(
'Reporter Type',
"Reporter Type",
null=True,
blank=True,
choices=[(1, u'Regular'), (2, u'CNN Reporter')]
choices=[(1, u"Regular"), (2, u"CNN Reporter")],
)
def __str__(self): # __unicode__ on Python 2
def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name)
def __init__(self, *args, **kwargs):
@ -61,11 +63,13 @@ class Reporter(models.Model):
if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter
class CNNReporter(Reporter):
"""
This class is a proxy model for Reporter, used for testing
proxy model support
"""
class Meta:
proxy = True
@ -74,17 +78,27 @@ class Article(models.Model):
headline = models.CharField(max_length=100)
pub_date = models.DateField()
pub_date_time = models.DateTimeField()
reporter = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='articles')
editor = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='edited_articles_+')
lang = models.CharField(max_length=2, help_text='Language', choices=[
('es', 'Spanish'),
('en', 'English')
], default='es')
importance = models.IntegerField('Importance', null=True, blank=True,
choices=[(1, u'Very important'), (2, u'Not as important')])
reporter = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="articles"
)
editor = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
)
lang = models.CharField(
max_length=2,
help_text="Language",
choices=[("es", "Spanish"), ("en", "English")],
default="es",
)
importance = models.IntegerField(
"Importance",
null=True,
blank=True,
choices=[(1, u"Very important"), (2, u"Not as important")],
)
def __str__(self): # __unicode__ on Python 2
def __str__(self): # __unicode__ on Python 2
return self.headline
class Meta:
ordering = ('headline',)
ordering = ("headline",)

View File

@ -6,10 +6,9 @@ from .models import Article, Reporter
class Character(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node, )
interfaces = (relay.Node,)
def get_node(self, info, id):
pass
@ -20,7 +19,7 @@ class Human(DjangoObjectType):
class Meta:
model = Article
interfaces = (relay.Node, )
interfaces = (relay.Node,)
def resolve_raises(self, info):
raise Exception("This field should raise exception")

View File

@ -12,10 +12,10 @@ class QueryRoot(ObjectType):
raise Exception("Throws!")
def resolve_request(self, info):
return info.context.GET.get('q')
return info.context.GET.get("q")
def resolve_test(self, info, who=None):
return 'Hello %s' % (who or 'World')
return "Hello %s" % (who or "World")
class MutationRoot(ObjectType):

View File

@ -3,8 +3,8 @@ from mock import patch
from six import StringIO
@patch('graphene_django.management.commands.graphql_schema.Command.save_file')
@patch("graphene_django.management.commands.graphql_schema.Command.save_file")
def test_generate_file_on_call_graphql_schema(savefile_mock, settings):
out = StringIO()
management.call_command('graphql_schema', schema='', stdout=out)
management.call_command("graphql_schema", schema="", stdout=out)
assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()

View File

@ -19,11 +19,11 @@ from .models import Article, Film, FilmDetails, Reporter
def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs)
field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs)
graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == 'Custom Help Text'
assert field.description == "Custom Help Text"
nonnull_field = django_field(null=False, *args, **kwargs)
if not nonnull_field.null:
nonnull_graphene_type = convert_django_field(nonnull_field)
@ -36,7 +36,8 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_django_field(None)
assert 'Don\'t know how to convert the Django field' in str(excinfo.value)
assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string():
assert_conversion(models.DateTimeField, DateTime)
@ -128,70 +129,69 @@ def test_should_nullboolean_convert_boolean():
def test_field_with_choices_convert_enum():
field = models.CharField(help_text='Language', choices=(
('es', 'Spanish'),
('en', 'English')
))
field = models.CharField(
help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
)
class TranslatedModel(models.Model):
language = field
class Meta:
app_label = 'test'
app_label = "test"
graphene_type = convert_django_field_with_choices(field)
assert isinstance(graphene_type, graphene.Enum)
assert graphene_type._meta.name == 'TranslatedModelLanguage'
assert graphene_type._meta.enum.__members__['ES'].value == 'es'
assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish'
assert graphene_type._meta.enum.__members__['EN'].value == 'en'
assert graphene_type._meta.enum.__members__['EN'].description == 'English'
assert graphene_type._meta.name == "TranslatedModelLanguage"
assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_grouped_choices():
field = models.CharField(help_text='Language', choices=(
('Europe', (
('es', 'Spanish'),
('en', 'English'),
)),
))
field = models.CharField(
help_text="Language",
choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
)
class GroupedChoicesModel(models.Model):
language = field
class Meta:
app_label = 'test'
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_gettext():
field = models.CharField(help_text='Language', choices=(
('es', _('Spanish')),
('en', _('English'))
))
field = models.CharField(
help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
)
class TranslatedChoicesModel(models.Model):
language = field
class Meta:
app_label = 'test'
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_collision():
field = models.CharField(help_text='Timezone', choices=(
('Etc/GMT+1+2', 'Fake choice to produce double collision'),
('Etc/GMT+1', 'Greenwich Mean Time +1'),
('Etc/GMT-1', 'Greenwich Mean Time -1'),
))
field = models.CharField(
help_text="Timezone",
choices=(
("Etc/GMT+1+2", "Fake choice to produce double collision"),
("Etc/GMT+1", "Greenwich Mean Time +1"),
("Etc/GMT-1", "Greenwich Mean Time -1"),
),
)
class CollisionChoicesModel(models.Model):
timezone = field
class Meta:
app_label = 'test'
app_label = "test"
convert_django_field_with_choices(field)
@ -208,11 +208,12 @@ def test_should_manytomany_convert_connectionorlist():
def test_should_manytomany_convert_connectionorlist_list():
class A(DjangoObjectType):
class Meta:
model = Reporter
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry)
graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
@ -222,12 +223,13 @@ def test_should_manytomany_convert_connectionorlist_list():
def test_should_manytomany_convert_connectionorlist_connection():
class A(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry)
graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField)
@ -236,11 +238,11 @@ def test_should_manytomany_convert_connectionorlist_connection():
def test_should_manytoone_convert_connectionorlist():
# Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Reporter.articles, 'rel', None) or \
getattr(Reporter.articles, 'related')
related = getattr(Reporter.articles, "rel", None) or getattr(
Reporter.articles, "related"
)
class A(DjangoObjectType):
class Meta:
model = Article
@ -254,11 +256,9 @@ def test_should_manytoone_convert_connectionorlist():
def test_should_onetoone_reverse_convert_model():
# Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Film.details, 'rel', None) or \
getattr(Film.details, 'related')
related = getattr(Film.details, "rel", None) or getattr(Film.details, "related")
class A(DjangoObjectType):
class Meta:
model = FilmDetails
@ -269,41 +269,41 @@ def test_should_onetoone_reverse_convert_model():
assert dynamic_field.type == A
@pytest.mark.skipif(ArrayField is MissingType,
reason="ArrayField should exist")
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_should_postgres_array_convert_list():
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100))
field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100)
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.String
@pytest.mark.skipif(ArrayField is MissingType,
reason="ArrayField should exist")
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)))
field = assert_conversion(
ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List)
assert field.type.of_type.of_type.of_type == graphene.String
@pytest.mark.skipif(HStoreField is MissingType,
reason="HStoreField should exist")
@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(JSONField is MissingType,
reason="JSONField should exist")
@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
def test_should_postgres_json_convert_string():
assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType,
reason="RangeField should exist")
@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ValidationError
from py.test import raises
from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
@ -9,24 +9,24 @@ from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
def test_global_id_valid():
field = GlobalIDFormField()
field.clean('TXlUeXBlOmFiYw==')
field.clean("TXlUeXBlOmFiYw==")
def test_global_id_invalid():
field = GlobalIDFormField()
with raises(ValidationError):
field.clean('badvalue')
field.clean("badvalue")
def test_global_id_multiple_valid():
field = GlobalIDMultipleChoiceField()
field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw=='])
field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
def test_global_id_multiple_invalid():
field = GlobalIDMultipleChoiceField()
with raises(ValidationError):
field.clean(['badvalue', 'another bad avue'])
field.clean(["badvalue", "another bad avue"])
def test_global_id_none():

File diff suppressed because it is too large Load Diff

View File

@ -7,48 +7,47 @@ from .models import Reporter
def test_should_raise_if_no_model():
with raises(Exception) as excinfo:
class Character1(DjangoObjectType):
pass
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo:
class Character2(DjangoObjectType):
class Character2(DjangoObjectType):
class Meta:
model = 1
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType):
class Meta:
model = Reporter
registry = Registry()
fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [
'id',
'first_name',
'last_name',
'email',
'pets',
'a_choice',
'reporter_type'
"id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == [
'articles',
'films',
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType):
class Meta:
model = Reporter
only_fields = ('id', 'email')
only_fields = ("id", "email")
assert list(Reporter2._meta.fields.keys()) == ['id', 'email']
assert list(Reporter2._meta.fields.keys()) == ["id", "email"]

View File

@ -12,27 +12,30 @@ registry.reset_global_registry()
class Reporter(DjangoObjectType):
'''Reporter description'''
"""Reporter description"""
class Meta:
model = ReporterModel
class ArticleConnection(Connection):
'''Article Connection'''
"""Article Connection"""
test = String()
def resolve_test():
return 'test'
return "test"
class Meta:
abstract = True
class Article(DjangoObjectType):
'''Article description'''
"""Article description"""
class Meta:
model = ArticleModel
interfaces = (Node, )
interfaces = (Node,)
connection_class = ArticleConnection
@ -48,7 +51,7 @@ def test_django_interface():
assert issubclass(Node, Node)
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
def test_django_get_node(get):
article = Article.get_node(None, 1)
get.assert_called_with(pk=1)
@ -58,18 +61,35 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields
fields = list(fields.keys())
assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice', 'reporter_type']
assert sorted(fields[-2:]) == ['articles', 'films']
assert fields[:-2] == [
"id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_django_objecttype_with_node_have_correct_fields():
fields = Article._meta.fields
assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'pub_date_time', 'reporter', 'editor', 'lang', 'importance']
assert list(fields.keys()) == [
"id",
"headline",
"pub_date",
"pub_date_time",
"reporter",
"editor",
"lang",
"importance",
]
def test_django_objecttype_with_custom_meta():
class ArticleTypeOptions(DjangoObjectTypeOptions):
'''Article Type Options'''
"""Article Type Options"""
class ArticleType(DjangoObjectType):
class Meta:
@ -77,7 +97,7 @@ def test_django_objecttype_with_custom_meta():
@classmethod
def __init_subclass_with_meta__(cls, **options):
options.setdefault('_meta', ArticleTypeOptions(cls))
options.setdefault("_meta", ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(**options)
class Article(ArticleType):
@ -180,6 +200,7 @@ def with_local_registry(func):
else:
registry.registry = old
return retval
return inner
@ -188,11 +209,10 @@ def test_django_objecttype_only_fields():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
only_fields = ('id', 'email', 'films')
only_fields = ("id", "email", "films")
fields = list(Reporter._meta.fields.keys())
assert fields == ['id', 'email', 'films']
assert fields == ["id", "email", "films"]
@with_local_registry
@ -200,8 +220,7 @@ def test_django_objecttype_exclude_fields():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude_fields = ('email')
exclude_fields = "email"
fields = list(Reporter._meta.fields.keys())
assert 'email' not in fields
assert "email" not in fields

View File

@ -8,15 +8,15 @@ except ImportError:
from urllib.parse import urlencode
def url_string(string='/graphql', **url_params):
def url_string(string="/graphql", **url_params):
if url_params:
string += '?' + urlencode(url_params)
string += "?" + urlencode(url_params)
return string
def batch_url_string(**url_params):
return url_string('/graphql/batch', **url_params)
return url_string("/graphql/batch", **url_params)
def response_json(response):
@ -28,441 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs])
def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT='text/html')
response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'text/html'
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_graphiql(client):
response = client.get(url_string(query='{test}'), HTTP_ACCEPT='application/json;q=0.8, text/html;q=0.9')
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
)
assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'text/html'
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_json(client):
response = client.get(url_string(query='{test}'), HTTP_ACCEPT='text/html;q=0.8, application/json;q=0.9')
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
)
assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'application/json'
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_query_param(client):
response = client.get(url_string(query='{test}'))
response = client.get(url_string(query="{test}"))
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_variable_values(client):
response = client.get(url_string(
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
))
response = client.get(
url_string(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_get_with_operation_name(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
))
""",
operationName="helloWorld",
)
)
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_reports_validation_errors(client):
response = client.get(url_string(
query='{ test, unknownOne, unknownTwo }'
))
response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
assert response.status_code == 400
assert response_json(response) == {
'errors': [
"errors": [
{
'message': 'Cannot query field "unknownOne" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 9}]
"message": 'Cannot query field "unknownOne" on type "QueryRoot".',
"locations": [{"line": 1, "column": 9}],
},
{
'message': 'Cannot query field "unknownTwo" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 21}]
}
"message": 'Cannot query field "unknownTwo" on type "QueryRoot".',
"locations": [{"line": 1, "column": 21}],
},
]
}
def test_errors_when_missing_operation_name(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
'''
))
"""
)
)
assert response.status_code == 400
assert response_json(response) == {
'errors': [
"errors": [
{
'message': 'Must provide operation name if query contains multiple operations.'
"message": "Must provide operation name if query contains multiple operations."
}
]
}
def test_errors_when_sending_a_mutation_via_get(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
mutation TestMutation { writeTest { test } }
'''
))
"""
)
)
assert response.status_code == 405
assert response_json(response) == {
'errors': [
{
'message': 'Can only perform a mutation operation from a POST request.'
}
"errors": [
{"message": "Can only perform a mutation operation from a POST request."}
]
}
def test_errors_when_selecting_a_mutation_within_a_get(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
''',
operationName='TestMutation'
))
""",
operationName="TestMutation",
)
)
assert response.status_code == 405
assert response_json(response) == {
'errors': [
{
'message': 'Can only perform a mutation operation from a POST request.'
}
"errors": [
{"message": "Can only perform a mutation operation from a POST request."}
]
}
def test_allows_mutation_to_exist_within_a_get(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
''',
operationName='TestQuery'
))
""",
operationName="TestQuery",
)
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query='{test}'), 'application/json')
response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_batch_allows_post_with_json_encoding(client):
response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json')
response = client.post(
batch_url_string(), jl(id=1, query="{test}"), "application/json"
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {'test': "Hello World"},
'status': 200,
}]
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello World"}, "status": 200}
]
def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), '[]', 'application/json')
response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Received an empty list in the batch request.'}]
"errors": [{"message": "Received an empty list in the batch request."}]
}
def test_allows_sending_a_mutation_via_post(client):
response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')
response = client.post(
url_string(),
j(query="mutation TestMutation { writeTest { test } }"),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'writeTest': {'test': 'Hello World'}}
}
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
def test_allows_post_with_url_encoding(client):
response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded')
response = client.post(
url_string(),
urlencode(dict(query="{test}")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_supports_post_json_query_with_string_variables(client):
response = client.post(url_string(), j(
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
), 'application/json')
response = client.post(
url_string(),
j(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_batch_supports_post_json_query_with_string_variables(client):
response = client.post(batch_url_string(), jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
), 'application/json')
response = client.post(
batch_url_string(),
jl(
id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {'test': "Hello Dolly"},
'status': 200,
}]
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
]
def test_supports_post_json_query_with_json_variables(client):
response = client.post(url_string(), j(
query='query helloWho($who: String){ test(who: $who) }',
variables={'who': "Dolly"}
), 'application/json')
response = client.post(
url_string(),
j(
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(batch_url_string(), jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables={'who': "Dolly"}
), 'application/json')
response = client.post(
batch_url_string(),
jl(
id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {'test': "Hello Dolly"},
'status': 200,
}]
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
]
def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post(url_string(), urlencode(dict(
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
)), 'application/x-www-form-urlencoded')
response = client.post(
url_string(),
urlencode(
dict(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_supports_post_json_quey_with_get_variable_values(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
), j(
query='query helloWho($who: String){ test(who: $who) }',
), 'application/json')
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
j(query="query helloWho($who: String){ test(who: $who) }"),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_post_url_encoded_query_with_get_variable_values(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
), urlencode(dict(
query='query helloWho($who: String){ test(who: $who) }',
)), 'application/x-www-form-urlencoded')
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_supports_post_raw_text_query_with_get_variable_values(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
),
'query helloWho($who: String){ test(who: $who) }',
'application/graphql'
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
"query helloWho($who: String){ test(who: $who) }",
"application/graphql",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_post_with_operation_name(client):
response = client.post(
url_string(),
j(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
def test_allows_post_with_operation_name(client):
response = client.post(url_string(), j(
query='''
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_batch_allows_post_with_operation_name(client):
response = client.post(batch_url_string(), jl(
id=1,
query='''
response = client.post(
batch_url_string(),
jl(
id=1,
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')
""",
operationName="helloWorld",
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
},
'status': 200,
}]
assert response_json(response) == [
{
"id": 1,
"data": {"test": "Hello World", "shared": "Hello Everyone"},
"status": 200,
}
]
def test_allows_post_with_get_operation_name(client):
response = client.post(url_string(
operationName='helloWorld'
), '''
response = client.post(
url_string(operationName="helloWorld"),
"""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
'application/graphql')
""",
"application/graphql",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
@pytest.mark.urls('graphene_django.tests.urls_inherited')
@pytest.mark.urls("graphene_django.tests.urls_inherited")
def test_inherited_class_with_attributes_works(client):
inherited_url = '/graphql/inherited/'
inherited_url = "/graphql/inherited/"
# Check schema and pretty attributes work
response = client.post(url_string(inherited_url, query='{test}'))
response = client.post(url_string(inherited_url, query="{test}"))
assert response.content.decode() == (
'{\n'
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
# Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT='text/html')
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == 200
@pytest.mark.urls('graphene_django.tests.urls_pretty')
@pytest.mark.urls("graphene_django.tests.urls_pretty")
def test_supports_pretty_printing(client):
response = client.get(url_string(query='{test}'))
response = client.get(url_string(query="{test}"))
assert response.content.decode() == (
'{\n'
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
def test_supports_pretty_printing_by_request(client):
response = client.get(url_string(query='{test}', pretty='1'))
response = client.get(url_string(query="{test}", pretty="1"))
assert response.content.decode() == (
'{\n'
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query='{thrower}'))
response = client.get(url_string(query="{thrower}"))
assert response.status_code == 200
assert response_json(response) == {
'data': None,
'errors': [{
'locations': [{'column': 2, 'line': 1}],
'path': ['thrower'],
'message': 'Throws!',
}]
"data": None,
"errors": [
{
"locations": [{"column": 2, "line": 1}],
"path": ["thrower"],
"message": "Throws!",
}
],
}
def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query='syntaxerror'))
response = client.get(url_string(query="syntaxerror"))
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'locations': [{'column': 1, 'line': 1}],
'message': 'Syntax Error GraphQL (1:1) '
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}]
"errors": [
{
"locations": [{"column": 1, "line": 1}],
"message": "Syntax Error GraphQL (1:1) "
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n',
}
]
}
@ -471,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}]
"errors": [{"message": "Must provide query string."}]
}
def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), '[]', 'application/json')
response = client.post(url_string(), "[]", "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'The received data is not a valid JSON query.'}]
"errors": [{"message": "The received data is not a valid JSON query."}]
}
def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), '[oh}', 'application/json')
response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}]
"errors": [{"message": "POST body sent invalid JSON."}]
}
@ -499,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch):
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
valid_json = json.dumps(dict(foo='bar'))
response = client.post(url_string(), valid_json, 'application/json')
valid_json = json.dumps(dict(foo="bar"))
response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'foo-bar'}]
}
assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', 'application/json')
response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}]
"errors": [{"message": "POST body sent invalid JSON."}]
}
def test_handles_plain_post_text(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
),
'query helloWho($who: String){ test(who: $who) }',
'text/plain'
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
"query helloWho($who: String){ test(who: $who) }",
"text/plain",
)
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}]
"errors": [{"message": "Must provide query string."}]
}
def test_handles_poorly_formed_variables(client):
response = client.get(url_string(
query='query helloWho($who: String){ test(who: $who) }',
variables='who:You'
))
response = client.get(
url_string(
query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
)
)
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Variables are invalid JSON.'}]
"errors": [{"message": "Variables are invalid JSON."}]
}
def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query='{test}'))
response = client.put(url_string(query="{test}"))
assert response.status_code == 405
assert response['Allow'] == 'GET, POST'
assert response["Allow"] == "GET, POST"
assert response_json(response) == {
'errors': [{'message': 'GraphQL only supports GET and POST requests.'}]
"errors": [{"message": "GraphQL only supports GET and POST requests."}]
}
def test_passes_request_into_context_request(client):
response = client.get(url_string(query='{request}', q='testing'))
response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == 200
assert response_json(response) == {
'data': {
'request': 'testing'
}
}
assert response_json(response) == {"data": {"request": "testing"}}

View File

@ -3,6 +3,6 @@ from django.conf.urls import url
from ..views import GraphQLView
urlpatterns = [
url(r'^graphql/batch', GraphQLView.as_view(batch=True)),
url(r'^graphql', GraphQLView.as_view(graphiql=True)),
url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
url(r"^graphql", GraphQLView.as_view(graphiql=True)),
]

View File

@ -3,12 +3,11 @@ from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
class CustomGraphQLView(GraphQLView):
schema = schema
graphiql = True
pretty = True
urlpatterns = [
url(r'^graphql/inherited/$', CustomGraphQLView.as_view()),
]
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]

View File

@ -3,6 +3,4 @@ from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
urlpatterns = [
url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)),
]
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]

View File

@ -8,8 +8,7 @@ from graphene.types.utils import yank_fields_from_attrs
from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry
from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields,
is_valid_django_model)
from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model
def construct_fields(model, registry, only_fields, exclude_fields):
@ -21,7 +20,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
# is_already_created = name in options.fields
is_excluded = name in exclude_fields # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith('+')
is_no_backref = str(name).endswith("+")
if is_not_in_only or is_excluded or is_no_backref:
# We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields.
@ -43,9 +42,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions):
class DjangoObjectType(ObjectType):
@classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
connection_class=None, use_connection=None, interfaces=(), _meta=None, **options):
def __init_subclass_with_meta__(
cls,
model=None,
registry=None,
skip_registry=False,
only_fields=(),
exclude_fields=(),
filter_fields=None,
connection=None,
connection_class=None,
use_connection=None,
interfaces=(),
_meta=None,
**options
):
assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(cls.__name__, model)
@ -54,7 +65,7 @@ class DjangoObjectType(ObjectType):
registry = get_global_registry()
assert isinstance(registry, Registry), (
'The attribute registry in {} needs to be an instance of '
"The attribute registry in {} needs to be an instance of "
'Registry, received "{}".'
).format(cls.__name__, registry)
@ -62,12 +73,13 @@ class DjangoObjectType(ObjectType):
raise Exception("Can only set filter_fields if Django-Filter is installed")
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields),
_as=Field,
construct_fields(model, registry, only_fields, exclude_fields), _as=Field
)
if use_connection is None and interfaces:
use_connection = any((issubclass(interface, Node) for interface in interfaces))
use_connection = any(
(issubclass(interface, Node) for interface in interfaces)
)
if use_connection and not connection:
# We create the connection automatically
@ -75,7 +87,8 @@ class DjangoObjectType(ObjectType):
connection_class = Connection
connection = connection_class.create_type(
'{}Connection'.format(cls.__name__), node=cls)
"{}Connection".format(cls.__name__), node=cls
)
if connection is not None:
assert issubclass(connection, Connection), (
@ -91,7 +104,9 @@ class DjangoObjectType(ObjectType):
_meta.fields = django_fields
_meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options)
super(DjangoObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options
)
if not skip_registry:
registry.register(cls)
@ -107,9 +122,7 @@ class DjangoObjectType(ObjectType):
if isinstance(root, cls):
return True
if not is_valid_django_model(type(root)):
raise Exception((
'Received incompatible instance "{}".'
).format(root))
raise Exception(('Received incompatible instance "{}".').format(root))
model = root._meta.model._meta.concrete_model
return model == cls._meta.model

View File

@ -13,6 +13,7 @@ class LazyList(object):
try:
import django_filters # noqa
DJANGO_FILTER_INSTALLED = True
except ImportError:
DJANGO_FILTER_INSTALLED = False
@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names):
continue
# Django =>1.9 uses 'rel', django <1.9 uses 'related'
related = getattr(attr, 'rel', None) or \
getattr(attr, 'related', None)
related = getattr(attr, "rel", None) or getattr(attr, "related", None)
if isinstance(related, models.ManyToOneRel):
yield (name, related)
elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
@ -42,9 +42,9 @@ def maybe_queryset(value):
def get_model_fields(model):
local_fields = [
(field.name, field)
for field
in sorted(list(model._meta.fields) +
list(model._meta.local_many_to_many))
for field in sorted(
list(model._meta.fields) + list(model._meta.local_many_to_many)
)
]
# Make sure we don't duplicate local fields with "reverse" version

View File

@ -20,7 +20,6 @@ from .settings import graphene_settings
class HttpError(Exception):
def __init__(self, response, message=None, *args, **kwargs):
self.response = response
self.message = message = message or response.content.decode()
@ -29,18 +28,18 @@ class HttpError(Exception):
def get_accepted_content_types(request):
def qualify(x):
parts = x.split(';', 1)
parts = x.split(";", 1)
if len(parts) == 2:
match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)',
parts[1])
match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
if match:
return parts[0].strip(), float(match.group(2))
return parts[0].strip(), 1
raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',')
raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
qualified_content_types = map(qualify, raw_content_types)
return list(x[0] for x in sorted(qualified_content_types,
key=lambda x: x[1], reverse=True))
return list(
x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
)
def instantiate_middleware(middlewares):
@ -52,8 +51,8 @@ def instantiate_middleware(middlewares):
class GraphQLView(View):
graphiql_version = '0.11.10'
graphiql_template = 'graphene/graphiql.html'
graphiql_version = "0.11.10"
graphiql_template = "graphene/graphiql.html"
schema = None
graphiql = False
@ -64,8 +63,17 @@ class GraphQLView(View):
pretty = False
batch = False
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
batch=False, backend=None):
def __init__(
self,
schema=None,
executor=None,
middleware=None,
root_value=None,
graphiql=False,
pretty=False,
batch=False,
backend=None,
):
if not schema:
schema = graphene_settings.SCHEMA
@ -86,9 +94,9 @@ class GraphQLView(View):
self.backend = backend
assert isinstance(
self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
assert not all((graphiql, batch)
), 'Use either graphiql or batch processing'
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
# noinspection PyUnusedLocal
def get_root_value(self, request):
@ -106,59 +114,59 @@ class GraphQLView(View):
@method_decorator(ensure_csrf_cookie)
def dispatch(self, request, *args, **kwargs):
try:
if request.method.lower() not in ('get', 'post'):
raise HttpError(HttpResponseNotAllowed(
['GET', 'POST'], 'GraphQL only supports GET and POST requests.'))
if request.method.lower() not in ("get", "post"):
raise HttpError(
HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(
request, data)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
if self.batch:
responses = [self.get_response(request, entry) for entry in data]
result = '[{}]'.format(','.join([response[0] for response in responses]))
status_code = responses and max(responses, key=lambda response: response[1])[1] or 200
result = "[{}]".format(
",".join([response[0] for response in responses])
)
status_code = (
responses
and max(responses, key=lambda response: response[1])[1]
or 200
)
else:
result, status_code = self.get_response(
request, data, show_graphiql)
result, status_code = self.get_response(request, data, show_graphiql)
if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params(
request, data)
request, data
)
return self.render_graphiql(
request,
graphiql_version=self.graphiql_version,
query=query or '',
variables=json.dumps(variables) or '',
operation_name=operation_name or '',
result=result or ''
query=query or "",
variables=json.dumps(variables) or "",
operation_name=operation_name or "",
result=result or "",
)
return HttpResponse(
status=status_code,
content=result,
content_type='application/json'
status=status_code, content=result, content_type="application/json"
)
except HttpError as e:
response = e.response
response['Content-Type'] = 'application/json'
response.content = self.json_encode(request, {
'errors': [self.format_error(e)]
})
response["Content-Type"] = "application/json"
response.content = self.json_encode(
request, {"errors": [self.format_error(e)]}
)
return response
def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(
request, data)
query, variables, operation_name, id = self.get_graphql_params(request, data)
execution_result = self.execute_graphql_request(
request,
data,
query,
variables,
operation_name,
show_graphiql
request, data, query, variables, operation_name, show_graphiql
)
status_code = 200
@ -166,17 +174,18 @@ class GraphQLView(View):
response = {}
if execution_result.errors:
response['errors'] = [self.format_error(
e) for e in execution_result.errors]
response["errors"] = [
self.format_error(e) for e in execution_result.errors
]
if execution_result.invalid:
status_code = 400
else:
response['data'] = execution_result.data
response["data"] = execution_result.data
if self.batch:
response['id'] = id
response['status'] = status_code
response["id"] = id
response["status"] = status_code
result = self.json_encode(request, response, pretty=show_graphiql)
else:
@ -188,22 +197,21 @@ class GraphQLView(View):
return render(request, self.graphiql_template, data)
def json_encode(self, request, d, pretty=False):
if not (self.pretty or pretty) and not request.GET.get('pretty'):
return json.dumps(d, separators=(',', ':'))
if not (self.pretty or pretty) and not request.GET.get("pretty"):
return json.dumps(d, separators=(",", ":"))
return json.dumps(d, sort_keys=True,
indent=2, separators=(',', ': '))
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
def parse_body(self, request):
content_type = self.get_content_type(request)
if content_type == 'application/graphql':
return {'query': request.body.decode()}
if content_type == "application/graphql":
return {"query": request.body.decode()}
elif content_type == 'application/json':
elif content_type == "application/json":
# noinspection PyBroadException
try:
body = request.body.decode('utf-8')
body = request.body.decode("utf-8")
except Exception as e:
raise HttpError(HttpResponseBadRequest(str(e)))
@ -211,33 +219,36 @@ class GraphQLView(View):
request_json = json.loads(body)
if self.batch:
assert isinstance(request_json, list), (
'Batch requests should receive a list, but received {}.'
"Batch requests should receive a list, but received {}."
).format(repr(request_json))
assert len(request_json) > 0, (
'Received an empty list in the batch request.'
)
assert (
len(request_json) > 0
), "Received an empty list in the batch request."
else:
assert isinstance(request_json, dict), (
'The received data is not a valid JSON query.'
)
assert isinstance(
request_json, dict
), "The received data is not a valid JSON query."
return request_json
except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest(
'POST body sent invalid JSON.'))
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']:
elif content_type in [
"application/x-www-form-urlencoded",
"multipart/form-data",
]:
return request.POST
return {}
def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False):
def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False
):
if not query:
if show_graphiql:
return None
raise HttpError(HttpResponseBadRequest(
'Must provide query string.'))
raise HttpError(HttpResponseBadRequest("Must provide query string."))
try:
backend = self.get_backend(request)
@ -245,23 +256,27 @@ class GraphQLView(View):
except Exception as e:
return ExecutionResult(errors=[e], invalid=True)
if request.method.lower() == 'get':
if request.method.lower() == "get":
operation_type = document.get_operation_type(operation_name)
if operation_type and operation_type != 'query':
if operation_type and operation_type != "query":
if show_graphiql:
return None
raise HttpError(HttpResponseNotAllowed(
['POST'], 'Can only perform a {} operation from a POST request.'.format(
operation_type)
))
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_type
),
)
)
try:
extra_options = {}
if self.executor:
# We only include it optionally since
# executor is not a valid argument in all backends
extra_options['executor'] = self.executor
extra_options["executor"] = self.executor
return document.execute(
root=self.get_root_value(request),
@ -276,7 +291,7 @@ class GraphQLView(View):
@classmethod
def can_display_graphiql(cls, request, data):
raw = 'raw' in request.GET or 'raw' in data
raw = "raw" in request.GET or "raw" in data
return not raw and cls.request_wants_html(request)
@classmethod
@ -285,26 +300,32 @@ class GraphQLView(View):
accepted_length = len(accepted)
# the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number
html_priority = accepted_length - accepted.index('text/html') if 'text/html' in accepted else 0
json_priority = accepted_length - accepted.index('application/json') if 'application/json' in accepted else 0
html_priority = (
accepted_length - accepted.index("text/html")
if "text/html" in accepted
else 0
)
json_priority = (
accepted_length - accepted.index("application/json")
if "application/json" in accepted
else 0
)
return html_priority > json_priority
@staticmethod
def get_graphql_params(request, data):
query = request.GET.get('query') or data.get('query')
variables = request.GET.get('variables') or data.get('variables')
id = request.GET.get('id') or data.get('id')
query = request.GET.get("query") or data.get("query")
variables = request.GET.get("variables") or data.get("variables")
id = request.GET.get("id") or data.get("id")
if variables and isinstance(variables, six.text_type):
try:
variables = json.loads(variables)
except Exception:
raise HttpError(HttpResponseBadRequest(
'Variables are invalid JSON.'))
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
operation_name = request.GET.get(
'operationName') or data.get('operationName')
operation_name = request.GET.get("operationName") or data.get("operationName")
if operation_name == "null":
operation_name = None
@ -315,11 +336,10 @@ class GraphQLView(View):
if isinstance(error, GraphQLError):
return format_graphql_error(error)
return {'message': six.text_type(error)}
return {"message": six.text_type(error)}
@staticmethod
def get_content_type(request):
meta = request.META
content_type = meta.get(
'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
return content_type.split(';', 1)[0].lower()
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower()