Reformatted files using black

This commit is contained in:
Syrus Akbary 2018-07-19 16:51:33 -07:00
parent 96789b291f
commit 54ef52e1c6
41 changed files with 1478 additions and 1539 deletions

View File

@ -1,14 +1,6 @@
from .types import ( from .types import DjangoObjectType
DjangoObjectType, from .fields import DjangoConnectionField
)
from .fields import (
DjangoConnectionField,
)
__version__ = '2.1rc1' __version__ = "2.1rc1"
__all__ = [ __all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]

View File

@ -7,7 +7,7 @@ try:
# and we cannot have psycopg2 on PyPy # and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
except ImportError: except ImportError:
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4 ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
try: try:

View File

@ -1,8 +1,22 @@
from django.db import models from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, from graphene import (
NonNull, String, UUID, DateTime, Date, Time) ID,
Boolean,
Dynamic,
Enum,
Field,
Float,
Int,
List,
NonNull,
String,
UUID,
DateTime,
Date,
Time,
)
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name from graphql import assert_valid_name
@ -32,7 +46,7 @@ def get_choices(choices):
else: else:
name = convert_choice_name(value) name = convert_choice_name(value)
while name in converted_names: while name in converted_names:
name += '_' + str(len(converted_names)) name += "_" + str(len(converted_names))
converted_names.append(name) converted_names.append(name)
description = help_text description = help_text
yield name, value, description yield name, value, description
@ -43,16 +57,15 @@ def convert_django_field_with_choices(field, registry=None):
converted = registry.get_converted_field(field) converted = registry.get_converted_field(field)
if converted: if converted:
return converted return converted
choices = getattr(field, 'choices', None) choices = getattr(field, "choices", None)
if choices: if choices:
meta = field.model._meta meta = field.model._meta
name = to_camel_case('{}_{}'.format(meta.object_name, field.name)) name = to_camel_case("{}_{}".format(meta.object_name, field.name))
choices = list(get_choices(choices)) choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices] named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices} named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object): class EnumWithDescriptionsType(object):
@property @property
def description(self): def description(self):
return named_choices_descriptions[self.name] return named_choices_descriptions[self.name]
@ -69,8 +82,8 @@ def convert_django_field_with_choices(field, registry=None):
@singledispatch @singledispatch
def convert_django_field(field, registry=None): def convert_django_field(field, registry=None):
raise Exception( raise Exception(
"Don't know how to convert the Django field %s (%s)" % "Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
(field, field.__class__)) )
@convert_django_field.register(models.CharField) @convert_django_field.register(models.CharField)
@ -147,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
# We do this for a bug in Django 1.8, where null attr # We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance # is not available in the OneToOneRel instance
null = getattr(field, 'null', True) null = getattr(field, "null", True)
return Field(_type, required=not null) return Field(_type, required=not null)
return Dynamic(dynamic_type) return Dynamic(dynamic_type)
@ -171,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
# defined filter_fields in the DjangoObjectType Meta # defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields: if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type) return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type) return DjangoConnectionField(_type)

View File

@ -1,4 +1,4 @@
from .middleware import DjangoDebugMiddleware from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug from .types import DjangoDebug
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug'] __all__ = ["DjangoDebugMiddleware", "DjangoDebug"]

View File

@ -7,7 +7,6 @@ from .types import DjangoDebug
class DjangoDebugContext(object): class DjangoDebugContext(object):
def __init__(self): def __init__(self):
self.debug_promise = None self.debug_promise = None
self.promises = [] self.promises = []
@ -38,20 +37,21 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object): class DjangoDebugMiddleware(object):
def resolve(self, next, root, info, **args): def resolve(self, next, root, info, **args):
context = info.context context = info.context
django_debug = getattr(context, 'django_debug', None) django_debug = getattr(context, "django_debug", None)
if not django_debug: if not django_debug:
if context is None: if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts') raise Exception("DjangoDebug cannot be executed in None contexts")
try: try:
context.django_debug = DjangoDebugContext() context.django_debug = DjangoDebugContext()
except Exception: except Exception:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format( raise Exception(
context.__class__.__name__ "DjangoDebug need the context to be writable, context received: {}.".format(
)) context.__class__.__name__
if info.schema.get_type('DjangoDebug') == info.return_type: )
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise() return context.django_debug.get_debug_promise()
promise = next(root, info, **args) promise = next(root, info, **args)
context.django_debug.add_promise(promise) context.django_debug.add_promise(promise)

View File

@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
class ThreadLocalState(local): class ThreadLocalState(local):
def __init__(self): def __init__(self):
self.enabled = True self.enabled = True
@ -35,7 +34,7 @@ recording = state.recording # export function
def wrap_cursor(connection, panel): def wrap_cursor(connection, panel):
if not hasattr(connection, '_graphene_cursor'): if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor connection._graphene_cursor = connection.cursor
def cursor(): def cursor():
@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
def unwrap_cursor(connection): def unwrap_cursor(connection):
if hasattr(connection, '_graphene_cursor'): if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor connection.cursor = previous_cursor
del connection._graphene_cursor del connection._graphene_cursor
@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
if not params: if not params:
return params return params
if isinstance(params, dict): if isinstance(params, dict):
return dict((key, self._quote_expr(value)) return dict((key, self._quote_expr(value)) for key, value in params.items())
for key, value in params.items())
return list(map(self._quote_expr, params)) return list(map(self._quote_expr, params))
def _decode(self, param): def _decode(self, param):
try: try:
return force_text(param, strings_only=True) return force_text(param, strings_only=True)
except UnicodeDecodeError: except UnicodeDecodeError:
return '(encoded string)' return "(encoded string)"
def _record(self, method, sql, params): def _record(self, method, sql, params):
start_time = time() start_time = time()
@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
return method(sql, params) return method(sql, params)
finally: finally:
stop_time = time() stop_time = time()
duration = (stop_time - start_time) duration = stop_time - start_time
_params = '' _params = ""
try: try:
_params = json.dumps(list(map(self._decode, params))) _params = json.dumps(list(map(self._decode, params)))
except Exception: except Exception:
pass # object not JSON serializable pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default') alias = getattr(self.db, "alias", "default")
conn = self.db.connection conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown') vendor = getattr(conn, "vendor", "unknown")
params = { params = {
'vendor': vendor, "vendor": vendor,
'alias': alias, "alias": alias,
'sql': self.db.ops.last_executed_query( "sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)), self.cursor, sql, self._quote_params(params)
'duration': duration, ),
'raw_sql': sql, "duration": duration,
'params': _params, "raw_sql": sql,
'start_time': start_time, "params": _params,
'stop_time': stop_time, "start_time": start_time,
'is_slow': duration > 10, "stop_time": stop_time,
'is_select': sql.lower().strip().startswith('select'), "is_slow": duration > 10,
"is_select": sql.lower().strip().startswith("select"),
} }
if vendor == 'postgresql': if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might # If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an # be in a state where checking isolation_level raises an
# exception. # exception.
try: try:
iso_level = conn.isolation_level iso_level = conn.isolation_level
except conn.InternalError: except conn.InternalError:
iso_level = 'unknown' iso_level = "unknown"
params.update({ params.update(
'trans_id': self.logger.get_transaction_id(alias), {
'trans_status': conn.get_transaction_status(), "trans_id": self.logger.get_transaction_id(alias),
'iso_level': iso_level, "trans_status": conn.get_transaction_status(),
'encoding': conn.encoding, "iso_level": iso_level,
}) "encoding": conn.encoding,
}
)
_sql = DjangoDebugSQL(**params) _sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility # We keep `sql` to maintain backwards compatibility

View File

@ -12,31 +12,31 @@ from ..types import DjangoDebug
class context(object): class context(object):
pass pass
# from examples.starwars_django.models import Character # from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
def test_should_query_field(): def test_should_query_field():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_reporter(self, info, **args): def resolve_reporter(self, info, **args):
return Reporter.objects.first() return Reporter.objects.first()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
reporter { reporter {
lastName lastName
@ -47,43 +47,40 @@ def test_should_query_field():
} }
} }
} }
''' """
expected = { expected = {
'reporter': { "reporter": {"lastName": "ABA"},
'lastName': 'ABA', "__debug": {
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
}, },
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
} }
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_query_list(): def test_should_query_list():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType) all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
allReporters { allReporters {
lastName lastName
@ -94,45 +91,38 @@ def test_should_query_list():
} }
} }
} }
''' """
expected = { expected = {
'allReporters': [{ "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
'lastName': 'ABA', "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.all().query)
}]
}
} }
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_query_connection(): def test_should_query_connection():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
allReporters(first:1) { allReporters(first:1) {
edges { edges {
@ -147,48 +137,41 @@ def test_should_query_connection():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data["allReporters"] == expected["allReporters"]
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query assert result.data["__debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter(): def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name']) all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S") s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
allReporters(first:1) { allReporters(first:1) {
edges { edges {
@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data["allReporters"] == expected["allReporters"]
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query assert result.data["__debug"]["sql"][1]["rawSql"] == query

View File

@ -13,7 +13,6 @@ from .utils import maybe_queryset
class DjangoListField(Field): class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs): def __init__(self, _type, *args, **kwargs):
super(DjangoListField, self).__init__(List(_type), *args, **kwargs) super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
@ -30,25 +29,28 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.on = kwargs.pop('on', False) self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop( self.max_limit = kwargs.pop(
'max_limit', "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
graphene_settings.RELAY_CONNECTION_MAX_LIMIT
) )
self.enforce_first_or_last = kwargs.pop( self.enforce_first_or_last = kwargs.pop(
'enforce_first_or_last', "enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
) )
super(DjangoConnectionField, self).__init__(*args, **kwargs) super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property @property
def type(self): def type(self):
from .types import DjangoObjectType from .types import DjangoObjectType
_type = super(ConnectionField, self).type _type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types" assert issubclass(
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) _type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
return _type._meta.connection return _type._meta.connection
@property @property
@ -100,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
return connection return connection
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(
enforce_first_or_last, root, info, **args): cls,
first = args.get('first') resolver,
last = args.get('last') connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
if enforce_first_or_last: if enforce_first_or_last:
assert first or last, ( assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.' "You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name) ).format(info.field_name)
if max_limit: if max_limit:
if first: if first:
assert first <= max_limit, ( assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit) ).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit) args["first"] = min(first, max_limit)
if last: if last:
assert last <= max_limit, ( assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
).format(last, info.field_name, max_limit) ).format(last, info.field_name, max_limit)
args['last'] = min(last, max_limit) args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args) iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
@ -138,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
self.type, self.type,
self.get_manager(), self.get_manager(),
self.max_limit, self.max_limit,
self.enforce_first_or_last self.enforce_first_or_last,
) )

View File

@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED: if not DJANGO_FILTER_INSTALLED:
warnings.warn( warnings.warn(
"Use of django filtering requires the django-filter package " "Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`", ImportWarning "be installed. You can do so using `pip install django-filter`",
ImportWarning,
) )
else: else:
from .fields import DjangoFilterConnectionField from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = ['DjangoFilterConnectionField', __all__ = [
'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter'] "DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
]

View File

@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
def __init__(self, type, fields=None, order_by=None, self,
extra_filter_meta=None, filterset_class=None, type,
*args, **kwargs): fields=None,
order_by=None,
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs
):
self._fields = fields self._fields = fields
self._provided_filterset_class = filterset_class self._provided_filterset_class = filterset_class
self._filterset_class = None self._filterset_class = None
@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self): def filterset_class(self):
if not self._filterset_class: if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model, meta = dict(model=self.model, fields=fields)
fields=fields)
if self._extra_filter_meta: if self._extra_filter_meta:
meta.update(self._extra_filter_meta) meta.update(self._extra_filter_meta)
self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta) self._filterset_class = get_filterset_class(
self._provided_filterset_class, **meta
)
return self._filterset_class return self._filterset_class
@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
# See related PR: https://github.com/graphql-python/graphene-django/pull/126 # See related PR: https://github.com/graphql-python/graphene-django/pull/126
assert not (default_queryset.query.low_mark and queryset.query.low_mark), ( assert not (
'Received two sliced querysets (low mark) in the connection, please slice only in one.' default_queryset.query.low_mark and queryset.query.low_mark
) ), "Received two sliced querysets (low mark) in the connection, please slice only in one."
assert not (default_queryset.query.high_mark and queryset.query.high_mark), ( assert not (
'Received two sliced querysets (high mark) in the connection, please slice only in one.' default_queryset.query.high_mark and queryset.query.high_mark
) ), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits() default_queryset.query.clear_limits()
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset) queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
default_queryset, queryset
)
queryset.query.set_limits(low, high) queryset.query.set_limits(low, high)
return queryset return queryset
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(
enforce_first_or_last, filterset_class, filtering_args, cls,
root, info, **args): resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class( qs = filterset_class(
data=filter_kwargs, data=filter_kwargs,
queryset=default_manager.get_queryset(), queryset=default_manager.get_queryset(),
request=info.context request=info.context,
).qs ).qs
return super(DjangoFilterConnectionField, cls).connection_resolver( return super(DjangoFilterConnectionField, cls).connection_resolver(
@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.max_limit, self.max_limit,
self.enforce_first_or_last, self.enforce_first_or_last,
self.filterset_class, self.filterset_class,
self.filtering_args self.filtering_args,
) )

View File

@ -28,26 +28,19 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
GRAPHENE_FILTER_SET_OVERRIDES = { GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: { models.AutoField: {"filter_class": GlobalIDFilter},
'filter_class': GlobalIDFilter, models.OneToOneField: {"filter_class": GlobalIDFilter},
}, models.ForeignKey: {"filter_class": GlobalIDFilter},
models.OneToOneField: { models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
'filter_class': GlobalIDFilter,
},
models.ForeignKey: {
'filter_class': GlobalIDFilter,
},
models.ManyToManyField: {
'filter_class': GlobalIDMultipleChoiceFilter,
}
} }
class GrapheneFilterSetMixin(BaseFilterSet): class GrapheneFilterSetMixin(BaseFilterSet):
FILTER_DEFAULTS = dict(itertools.chain( FILTER_DEFAULTS = dict(
FILTER_FOR_DBFIELD_DEFAULTS.items(), itertools.chain(
GRAPHENE_FILTER_SET_OVERRIDES.items() FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
)) )
)
@classmethod @classmethod
def filter_for_reverse_field(cls, f, name): def filter_for_reverse_field(cls, f, name):
@ -62,10 +55,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
except AttributeError: except AttributeError:
rel = f.field.rel rel = f.field.rel
default = { default = {"name": name, "label": capfirst(rel.related_name)}
'name': name,
'label': capfirst(rel.related_name)
}
if rel.multiple: if rel.multiple:
# For to-many relationships # For to-many relationships
return GlobalIDMultipleChoiceFilter(**default) return GlobalIDMultipleChoiceFilter(**default)
@ -78,25 +68,20 @@ def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality """ Wrap a provided filterset in Graphene-specific functionality
""" """
return type( return type(
'Graphene{}'.format(filterset_class.__name__), "Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin), (filterset_class, GrapheneFilterSetMixin),
{}, {},
) )
def custom_filterset_factory(model, filterset_base_class=FilterSet, def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
**meta):
""" Create a filterset for the given model using the provided meta data """ Create a filterset for the given model using the provided meta data
""" """
meta.update({ meta.update({"model": model})
'model': model, meta_class = type(str("Meta"), (object,), meta)
})
meta_class = type(str('Meta'), (object,), meta)
filterset = type( filterset = type(
str('%sFilterSet' % model._meta.object_name), str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin), (filterset_base_class, GrapheneFilterSetMixin),
{ {"Meta": meta_class},
'Meta': meta_class
}
) )
return filterset return filterset

View File

@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet): class ArticleFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = { fields = {
'headline': ['exact', 'icontains'], "headline": ["exact", "icontains"],
'pub_date': ['gt', 'lt', 'exact'], "pub_date": ["gt", "lt", "exact"],
'reporter': ['exact'], "reporter": ["exact"],
} }
order_by = OrderingFilter(fields=('pub_date',)) order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet): class ReporterFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['first_name', 'last_name', 'email', 'pets'] fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=('pub_date',)) order_by = OrderingFilter(fields=("pub_date",))
class PetFilter(django_filters.FilterSet): class PetFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Pet model = Pet
fields = ['name'] fields = ["name"]

View File

@ -5,8 +5,7 @@ import pytest
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.forms import (GlobalIDFormField, from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
GlobalIDMultipleChoiceField)
from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED:
import django_filters import django_filters
from django_filters import FilterSet, NumberFilter from django_filters import FilterSet, NumberFilter
from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField, from graphene_django.filter import (
GlobalIDMultipleChoiceFilter) GlobalIDFilter,
from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter,
)
from graphene_django.filter.tests.filters import (
ArticleFilter,
PetFilter,
ReporterFilter,
)
else: else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible')) pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
pytestmark.append(pytest.mark.django_db) pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
class ArticleNode(DjangoObjectType):
class ArticleNode(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ('headline', ) filter_fields = ("headline",)
class ReporterNode(DjangoObjectType): class ReporterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class PetNode(DjangoObjectType): class PetNode(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node, ) interfaces = (Node,)
# schema = Schema() # schema = Schema()
@ -59,58 +65,47 @@ def get_args(field):
def assert_arguments(field, *arguments): def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'order_by') ignore = ("after", "before", "first", "last", "order_by")
args = get_args(field) args = get_args(field)
actual = [ actual = [name for name in args if name not in ignore and not name.startswith("_")]
name assert set(arguments) == set(
for name in args actual
if name not in ignore and not name.startswith('_') ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
]
assert set(arguments) == set(actual), \
'Expected arguments ({}) did not match actual ({})'.format(
arguments,
actual
)
def assert_orderable(field): def assert_orderable(field):
args = get_args(field) args = get_args(field)
assert 'order_by' in args, \ assert "order_by" in args, "Field cannot be ordered"
'Field cannot be ordered'
def assert_not_orderable(field): def assert_not_orderable(field):
args = get_args(field) args = get_args(field)
assert 'order_by' not in args, \ assert "order_by" not in args, "Field can be ordered"
'Field can be ordered'
def test_filter_explicit_filterset_arguments(): def test_filter_explicit_filterset_arguments():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
assert_arguments(field, assert_arguments(
'headline', 'headline__icontains', field,
'pub_date', 'pub_date__gt', 'pub_date__lt', "headline",
'reporter', "headline__icontains",
) "pub_date",
"pub_date__gt",
"pub_date__lt",
"reporter",
)
def test_filter_shortcut_filterset_arguments_list(): def test_filter_shortcut_filterset_arguments_list():
field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter']) field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
assert_arguments(field, assert_arguments(field, "pub_date", "reporter")
'pub_date',
'reporter',
)
def test_filter_shortcut_filterset_arguments_dict(): def test_filter_shortcut_filterset_arguments_dict():
field = DjangoFilterConnectionField(ArticleNode, fields={ field = DjangoFilterConnectionField(
'headline': ['exact', 'icontains'], ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
'reporter': ['exact'], )
}) assert_arguments(field, "headline", "headline__icontains", "reporter")
assert_arguments(field,
'headline', 'headline__icontains',
'reporter',
)
def test_filter_explicit_filterset_orderable(): def test_filter_explicit_filterset_orderable():
@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta(): def test_filter_shortcut_filterset_extra_meta():
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ field = DjangoFilterConnectionField(
'exclude': ('headline', ) ArticleNode, extra_filter_meta={"exclude": ("headline",)}
}) )
assert 'headline' not in field.filterset_class.get_fields() assert "headline" not in field.filterset_class.get_fields()
def test_filter_shortcut_filterset_context(): def test_filter_shortcut_filterset_context():
class ArticleContextFilter(django_filters.FilterSet): class ArticleContextFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
exclude = set() exclude = set()
@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context():
return qs.filter(reporter=self.request.reporter) return qs.filter(reporter=self.request.reporter)
class Query(ObjectType): class Query(ObjectType):
context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter) context_articles = DjangoFilterConnectionField(
ArticleNode, filterset_class=ArticleContextFilter
)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1) Article.objects.create(
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2) headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
class context(object): class context(object):
reporter = r2 reporter = r2
query = ''' query = """
query { query {
contextArticles { contextArticles {
edges { edges {
@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context():
} }
} }
} }
''' """
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(query, context_value=context()) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert len(result.data['contextArticles']['edges']) == 1 assert len(result.data["contextArticles"]["edges"]) == 1
assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2' assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
def test_filter_filterset_information_on_meta(): def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['first_name', 'articles'] filter_fields = ["first_name", "articles"]
field = DjangoFilterConnectionField(ReporterFilterNode) field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'first_name', 'articles') assert_arguments(field, "first_name", "articles")
assert_not_orderable(field) assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related(): def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['first_name', 'articles'] filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType): class ArticleFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['headline', 'reporter'] filter_fields = ["headline", "reporter"]
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
article = Field(ArticleFilterNode) article = Field(ArticleFilterNode)
schema = Schema(query=Query) schema = Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields['articles'].get_type() articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, 'headline', 'reporter') assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field) assert_not_orderable(articles_field)
def test_filter_filterset_related_results(): def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['first_name', 'articles'] filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType): class ArticleFilterNode(DjangoObjectType):
class Meta: class Meta:
interfaces = (Node, ) interfaces = (Node,)
model = Article model = Article
filter_fields = ['headline', 'reporter'] filter_fields = ["headline", "reporter"]
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -243,12 +246,22 @@ def test_filter_filterset_related_results():
reporter = Field(ReporterFilterNode) reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode) article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1) Article.objects.create(
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2) headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
)
query = ''' query = """
query { query {
allReporters { allReporters {
edges { edges {
@ -264,123 +277,134 @@ def test_filter_filterset_related_results():
} }
} }
} }
''' """
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
# We should only get back a single article for each reporter # We should only get back a single article for each reporter
assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1 assert (
assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1 len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
)
assert (
len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
)
def test_global_id_field_implicit(): def test_global_id_field_implicit():
field = DjangoFilterConnectionField(ArticleNode, fields=['id']) field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id'] id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter) assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_explicit(): def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet): class ArticleIdFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = ['id'] fields = ["id"]
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
filterset_class = field.filterset_class filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id'] id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter) assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_filterset_descriptions(): def test_filterset_descriptions():
class ArticleIdFilter(django_filters.FilterSet): class ArticleIdFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = ['id'] fields = ["id"]
max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time") max_time = django_filters.NumberFilter(
method="filter_max_time", label="The maximum time"
)
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
max_time = field.args['max_time'] max_time = field.args["max_time"]
assert isinstance(max_time, Argument) assert isinstance(max_time, Argument)
assert max_time.type == Float assert max_time.type == Float
assert max_time.description == 'The maximum time' assert max_time.description == "The maximum time"
def test_global_id_field_relation(): def test_global_id_field_relation():
field = DjangoFilterConnectionField(ArticleNode, fields=['reporter']) field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['reporter'] id_filter = filterset_class.base_filters["reporter"]
assert isinstance(id_filter, GlobalIDFilter) assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_global_id_multiple_field_implicit(): def test_global_id_multiple_field_implicit():
field = DjangoFilterConnectionField(ReporterNode, fields=['pets']) field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets'] multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit(): def test_global_id_multiple_field_explicit():
class ReporterPetsFilter(django_filters.FilterSet): class ReporterPetsFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['pets'] fields = ["pets"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets'] multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_implicit_reverse(): def test_global_id_multiple_field_implicit_reverse():
field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles'] multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit_reverse(): def test_global_id_multiple_field_explicit_reverse():
class ReporterPetsFilter(django_filters.FilterSet): class ReporterPetsFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['articles'] fields = ["articles"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles'] multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_filter_filterset_related_results(): def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = { filter_fields = {"first_name": ["icontains"]}
'first_name': ['icontains']
}
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com') r1 = Reporter.objects.create(
r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com') first_name="A test user", last_name="Last Name", email="test1@test.com"
r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com') )
r2 = Reporter.objects.create(
first_name="Other test user",
last_name="Other Last Name",
email="test2@test.com",
)
r3 = Reporter.objects.create(
first_name="Random", last_name="RandomLast", email="random@test.com"
)
query = ''' query = """
query { query {
allReporters(firstName_Icontains: "test") { allReporters(firstName_Icontains: "test") {
edges { edges {
@ -390,12 +414,12 @@ def test_filter_filterset_related_results():
} }
} }
} }
''' """
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
# We should only get two reporters # We should only get two reporters
assert len(result.data['allReporters']['edges']) == 2 assert len(result.data["allReporters"]["edges"]) == 2
def test_recursive_filter_connection(): def test_recursive_filter_connection():
@ -407,79 +431,73 @@ def test_recursive_filter_connection():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode assert (
ReporterFilterNode._meta.fields["child_reporters"].node_type
== ReporterFilterNode
)
def test_should_query_filter_node_limit(): def test_should_query_filter_node_limit():
class ReporterFilter(FilterSet): class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit') limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value): def filter_limit(self, queryset, name, value):
return queryset[:value] return queryset[:value]
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['first_name', ] fields = ["first_name"]
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ('lang', ) filter_fields = ("lang",)
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField( all_reporters = DjangoFilterConnectionField(
ReporterType, ReporterType, filterset_class=ReporterFilter
filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice') return Reporter.objects.order_by("a_choice")
Reporter.objects.create( Reporter.objects.create(
first_name='Bob', first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
) )
r = Reporter.objects.create( r = Reporter.objects.create(
first_name='John', first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
last_name='Doe',
email='johndoe@example.com',
a_choice=1
) )
Article.objects.create( Article.objects.create(
headline='Article Node 1', headline="Article Node 1",
pub_date=datetime.now(), pub_date=datetime.now(),
pub_date_time=datetime.now(), pub_date_time=datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang="es",
) )
Article.objects.create( Article.objects.create(
headline='Article Node 2', headline="Article Node 2",
pub_date=datetime.now(), pub_date=datetime.now(),
pub_date_time=datetime.now(), pub_date_time=datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='en' lang="en",
) )
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(limit: 1) { allReporters(limit: 1) {
edges { edges {
@ -498,24 +516,23 @@ def test_should_query_filter_node_limit():
} }
} }
} }
''' """
expected = { expected = {
'allReporters': { "allReporters": {
'edges': [{ "edges": [
'node': { {
'id': 'UmVwb3J0ZXJUeXBlOjI=', "node": {
'firstName': 'John', "id": "UmVwb3J0ZXJUeXBlOjI=",
'articles': { "firstName": "John",
'edges': [{ "articles": {
'node': { "edges": [
'id': 'QXJ0aWNsZVR5cGU6MQ==', {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}}
'lang': 'ES' ]
} },
}]
} }
} }
}] ]
} }
} }
@ -526,45 +543,37 @@ def test_should_query_filter_node_limit():
def test_should_query_filter_node_double_limit_raises(): def test_should_query_filter_node_double_limit_raises():
class ReporterFilter(FilterSet): class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit') limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value): def filter_limit(self, queryset, name, value):
return queryset[:value] return queryset[:value]
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['first_name', ] fields = ["first_name"]
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField( all_reporters = DjangoFilterConnectionField(
ReporterType, ReporterType, filterset_class=ReporterFilter
filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2] return Reporter.objects.order_by("a_choice")[:2]
Reporter.objects.create( Reporter.objects.create(
first_name='Bob', first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
) )
r = Reporter.objects.create( r = Reporter.objects.create(
first_name='John', first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
last_name='Doe',
email='johndoe@example.com',
a_choice=1
) )
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(limit: 1) { allReporters(limit: 1) {
edges { edges {
@ -575,41 +584,40 @@ def test_should_query_filter_node_double_limit_raises():
} }
} }
} }
''' """
result = schema.execute(query) result = schema.execute(query)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert str(result.errors[0]) == ( assert str(result.errors[0]) == (
'Received two sliced querysets (high mark) in the connection, please slice only in one.' "Received two sliced querysets (high mark) in the connection, please slice only in one."
) )
def test_order_by_is_perserved(): def test_order_by_is_perserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = () filter_fields = ()
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean()) all_reporters = DjangoFilterConnectionField(
ReporterType, reverse_order=Boolean()
)
def resolve_all_reporters(self, info, reverse_order=False, **args): def resolve_all_reporters(self, info, reverse_order=False, **args):
reporters = Reporter.objects.order_by('first_name') reporters = Reporter.objects.order_by("first_name")
if reverse_order: if reverse_order:
return reporters.reverse() return reporters.reverse()
return reporters return reporters
Reporter.objects.create( Reporter.objects.create(first_name="b")
first_name='b', r = Reporter.objects.create(first_name="a")
)
r = Reporter.objects.create(
first_name='a',
)
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(first: 1) { allReporters(first: 1) {
edges { edges {
@ -619,23 +627,14 @@ def test_order_by_is_perserved():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}}
'allReporters': {
'edges': [{
'node': {
'firstName': 'a',
}
}]
}
}
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
reverse_query = """
reverse_query = '''
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(first: 1, reverseOrder: true) { allReporters(first: 1, reverseOrder: true) {
edges { edges {
@ -645,33 +644,26 @@ def test_order_by_is_perserved():
} }
} }
} }
''' """
reverse_expected = { reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}}
'allReporters': {
'edges': [{
'node': {
'firstName': 'b',
}
}]
}
}
reverse_result = schema.execute(reverse_query) reverse_result = schema.execute(reverse_query)
assert not reverse_result.errors assert not reverse_result.errors
assert reverse_result.data == reverse_expected assert reverse_result.data == reverse_expected
def test_annotation_is_perserved(): def test_annotation_is_perserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
full_name = String() full_name = String()
def resolve_full_name(instance, info, **args): def resolve_full_name(instance, info, **args):
return instance.full_name return instance.full_name
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = () filter_fields = ()
class Query(ObjectType): class Query(ObjectType):
@ -679,17 +671,16 @@ def test_annotation_is_perserved():
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate( return Reporter.objects.annotate(
full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField()) full_name=Concat(
"first_name", Value(" "), "last_name", output_field=TextField()
)
) )
Reporter.objects.create( Reporter.objects.create(first_name="John", last_name="Doe")
first_name='John',
last_name='Doe',
)
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(first: 1) { allReporters(first: 1) {
edges { edges {
@ -699,16 +690,8 @@ def test_annotation_is_perserved():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
'allReporters': {
'edges': [{
'node': {
'fullName': 'John Doe',
}
}]
}
}
result = schema.execute(query) result = schema.execute(query)

View File

@ -14,8 +14,7 @@ singledispatch = import_single_dispatch()
def convert_form_field(field): def convert_form_field(field):
raise ImproperlyConfigured( raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) " "Don't know how to convert the Django form field %s (%s) "
"to Graphene type" % "to Graphene type" % (field, field.__class__)
(field, field.__class__)
) )

View File

@ -8,9 +8,7 @@ from graphql_relay import from_global_id
class GlobalIDFormField(Field): class GlobalIDFormField(Field):
default_error_messages = { default_error_messages = {"invalid": _("Invalid ID specified.")}
'invalid': _('Invalid ID specified.'),
}
def clean(self, value): def clean(self, value):
if not value and not self.required: if not value and not self.required:
@ -19,21 +17,21 @@ class GlobalIDFormField(Field):
try: try:
_type, _id = from_global_id(value) _type, _id = from_global_id(value)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error): except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
raise ValidationError(self.error_messages['invalid']) raise ValidationError(self.error_messages["invalid"])
try: try:
CharField().clean(_id) CharField().clean(_id)
CharField().clean(_type) CharField().clean(_type)
except ValidationError: except ValidationError:
raise ValidationError(self.error_messages['invalid']) raise ValidationError(self.error_messages["invalid"])
return value return value
class GlobalIDMultipleChoiceField(MultipleChoiceField): class GlobalIDMultipleChoiceField(MultipleChoiceField):
default_error_messages = { default_error_messages = {
'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'), "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
'invalid_list': _('Enter a list of values.'), "invalid_list": _("Enter a list of values."),
} }
def valid_value(self, value): def valid_value(self, value):

View File

@ -5,6 +5,7 @@ import graphene
from graphene import Field, InputField from graphene import Field, InputField
from graphene.relay.mutation import ClientIDMutation from graphene.relay.mutation import ClientIDMutation
from graphene.types.mutation import MutationOptions from graphene.types.mutation import MutationOptions
# from graphene.types.inputobjecttype import ( # from graphene.types.inputobjecttype import (
# InputObjectTypeOptions, # InputObjectTypeOptions,
# InputObjectType, # InputObjectType,
@ -21,7 +22,8 @@ def fields_for_form(form, only_fields, exclude_fields):
for name, field in form.fields.items(): for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_excluded = ( is_excluded = (
name in exclude_fields # or name
in exclude_fields # or
# name in already_created_fields # name in already_created_fields
) )
@ -57,12 +59,12 @@ class BaseDjangoFormMutation(ClientIDMutation):
@classmethod @classmethod
def get_form_kwargs(cls, root, info, **input): def get_form_kwargs(cls, root, info, **input):
kwargs = {'data': input} kwargs = {"data": input}
pk = input.pop('id', None) pk = input.pop("id", None)
if pk: if pk:
instance = cls._meta.model._default_manager.get(pk=pk) instance = cls._meta.model._default_manager.get(pk=pk)
kwargs['instance'] = instance kwargs["instance"] = instance
return kwargs return kwargs
@ -100,11 +102,12 @@ class DjangoFormMutation(BaseDjangoFormMutation):
errors = graphene.List(ErrorType) errors = graphene.List(ErrorType)
@classmethod @classmethod
def __init_subclass_with_meta__(cls, form_class=None, def __init_subclass_with_meta__(
only_fields=(), exclude_fields=(), **options): cls, form_class=None, only_fields=(), exclude_fields=(), **options
):
if not form_class: if not form_class:
raise Exception('form_class is required for DjangoFormMutation') raise Exception("form_class is required for DjangoFormMutation")
form = form_class() form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields) input_fields = fields_for_form(form, only_fields, exclude_fields)
@ -112,16 +115,12 @@ class DjangoFormMutation(BaseDjangoFormMutation):
_meta = DjangoFormMutationOptions(cls) _meta = DjangoFormMutationOptions(cls)
_meta.form_class = form_class _meta.form_class = form_class
_meta.fields = yank_fields_from_attrs( _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
output_fields,
_as=Field,
)
input_fields = yank_fields_from_attrs( input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
input_fields, super(DjangoFormMutation, cls).__init_subclass_with_meta__(
_as=InputField, _meta=_meta, input_fields=input_fields, **options
) )
super(DjangoFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod @classmethod
def perform_mutate(cls, form, info): def perform_mutate(cls, form, info):
@ -141,21 +140,28 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
errors = graphene.List(ErrorType) errors = graphene.List(ErrorType)
@classmethod @classmethod
def __init_subclass_with_meta__(cls, form_class=None, model=None, return_field_name=None, def __init_subclass_with_meta__(
only_fields=(), exclude_fields=(), **options): cls,
form_class=None,
model=None,
return_field_name=None,
only_fields=(),
exclude_fields=(),
**options
):
if not form_class: if not form_class:
raise Exception('form_class is required for DjangoModelFormMutation') raise Exception("form_class is required for DjangoModelFormMutation")
if not model: if not model:
model = form_class._meta.model model = form_class._meta.model
if not model: if not model:
raise Exception('model is required for DjangoModelFormMutation') raise Exception("model is required for DjangoModelFormMutation")
form = form_class() form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields) input_fields = fields_for_form(form, only_fields, exclude_fields)
input_fields['id'] = graphene.ID() input_fields["id"] = graphene.ID()
registry = get_global_registry() registry = get_global_registry()
model_type = registry.get_type_for_model(model) model_type = registry.get_type_for_model(model)
@ -171,19 +177,11 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
_meta.form_class = form_class _meta.form_class = form_class
_meta.model = model _meta.model = model
_meta.return_field_name = return_field_name _meta.return_field_name = return_field_name
_meta.fields = yank_fields_from_attrs( _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
output_fields,
_as=Field,
)
input_fields = yank_fields_from_attrs( input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
input_fields,
_as=InputField,
)
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, _meta=_meta, input_fields=input_fields, **options
input_fields=input_fields,
**options
) )
@classmethod @classmethod

View File

@ -2,24 +2,36 @@ from django import forms
from py.test import raises from py.test import raises
import graphene import graphene
from graphene import String, Int, Boolean, Float, ID, UUID, List, NonNull, DateTime, Date, Time from graphene import (
String,
Int,
Boolean,
Float,
ID,
UUID,
List,
NonNull,
DateTime,
Date,
Time,
)
from ..converter import convert_form_field from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args): def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text='Custom Help Text') field = django_field(*args, help_text="Custom Help Text")
graphene_type = convert_form_field(field) graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field() field = graphene_type.Field()
assert field.description == 'Custom Help Text' assert field.description == "Custom Help Text"
return field return field
def test_should_unknown_django_field_raise_exception(): def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
convert_form_field(None) convert_form_field(None)
assert 'Don\'t know how to convert the Django form field' in str(excinfo.value) assert "Don't know how to convert the Django form field" in str(excinfo.value)
def test_should_date_convert_date(): def test_should_date_convert_date():
@ -59,11 +71,11 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string(): def test_should_regex_convert_string():
assert_conversion(forms.RegexField, String, '[0-9]+') assert_conversion(forms.RegexField, String, "[0-9]+")
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'): if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID) assert_conversion(forms.UUIDField, UUID)

View File

@ -11,18 +11,18 @@ class MyForm(forms.Form):
class PetForm(forms.ModelForm): class PetForm(forms.ModelForm):
class Meta: class Meta:
model = Pet model = Pet
fields = ('name',) fields = ("name",)
def test_needs_form_class(): def test_needs_form_class():
with raises(Exception) as exc: with raises(Exception) as exc:
class MyMutation(DjangoFormMutation): class MyMutation(DjangoFormMutation):
pass pass
assert exc.value.args[0] == 'form_class is required for DjangoFormMutation' assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
def test_has_output_fields(): def test_has_output_fields():
@ -30,7 +30,7 @@ def test_has_output_fields():
class Meta: class Meta:
form_class = MyForm form_class = MyForm
assert 'errors' in MyMutation._meta.fields assert "errors" in MyMutation._meta.fields
def test_has_input_fields(): def test_has_input_fields():
@ -38,19 +38,18 @@ def test_has_input_fields():
class Meta: class Meta:
form_class = MyForm form_class = MyForm
assert 'text' in MyMutation.Input._meta.fields assert "text" in MyMutation.Input._meta.fields
class ModelFormMutationTests(TestCase): class ModelFormMutationTests(TestCase):
def test_default_meta_fields(self): def test_default_meta_fields(self):
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
class Meta: class Meta:
form_class = PetForm form_class = PetForm
self.assertEqual(PetMutation._meta.model, Pet) self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, 'pet') self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn('pet', PetMutation._meta.fields) self.assertIn("pet", PetMutation._meta.fields)
def test_return_field_name_is_camelcased(self): def test_return_field_name_is_camelcased(self):
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
@ -59,31 +58,31 @@ class ModelFormMutationTests(TestCase):
model = FilmDetails model = FilmDetails
self.assertEqual(PetMutation._meta.model, FilmDetails) self.assertEqual(PetMutation._meta.model, FilmDetails)
self.assertEqual(PetMutation._meta.return_field_name, 'filmDetails') self.assertEqual(PetMutation._meta.return_field_name, "filmDetails")
def test_custom_return_field_name(self): def test_custom_return_field_name(self):
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
class Meta: class Meta:
form_class = PetForm form_class = PetForm
model = Film model = Film
return_field_name = 'animal' return_field_name = "animal"
self.assertEqual(PetMutation._meta.model, Film) self.assertEqual(PetMutation._meta.model, Film)
self.assertEqual(PetMutation._meta.return_field_name, 'animal') self.assertEqual(PetMutation._meta.return_field_name, "animal")
self.assertIn('animal', PetMutation._meta.fields) self.assertIn("animal", PetMutation._meta.fields)
def test_model_form_mutation_mutate(self): def test_model_form_mutation_mutate(self):
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
class Meta: class Meta:
form_class = PetForm form_class = PetForm
pet = Pet.objects.create(name='Axel') pet = Pet.objects.create(name="Axel")
result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name='Mia') result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia")
self.assertEqual(Pet.objects.count(), 1) self.assertEqual(Pet.objects.count(), 1)
pet.refresh_from_db() pet.refresh_from_db()
self.assertEqual(pet.name, 'Mia') self.assertEqual(pet.name, "Mia")
self.assertEqual(result.errors, []) self.assertEqual(result.errors, [])
def test_model_form_mutation_updates_existing_(self): def test_model_form_mutation_updates_existing_(self):
@ -91,11 +90,11 @@ class ModelFormMutationTests(TestCase):
class Meta: class Meta:
form_class = PetForm form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None, name='Mia') result = PetMutation.mutate_and_get_payload(None, None, name="Mia")
self.assertEqual(Pet.objects.count(), 1) self.assertEqual(Pet.objects.count(), 1)
pet = Pet.objects.get() pet = Pet.objects.get()
self.assertEqual(pet.name, 'Mia') self.assertEqual(pet.name, "Mia")
self.assertEqual(result.errors, []) self.assertEqual(result.errors, [])
def test_model_form_mutation_mutate_invalid_form(self): def test_model_form_mutation_mutate_invalid_form(self):
@ -109,5 +108,5 @@ class ModelFormMutationTests(TestCase):
self.assertEqual(Pet.objects.count(), 0) self.assertEqual(Pet.objects.count(), 0)
self.assertEqual(len(result.errors), 1) self.assertEqual(len(result.errors), 1)
self.assertEqual(result.errors[0].field, 'name') self.assertEqual(result.errors[0].field, "name")
self.assertEqual(result.errors[0].messages, ['This field is required.']) self.assertEqual(result.errors[0].messages, ["This field is required."])

View File

@ -7,43 +7,45 @@ from graphene_django.settings import graphene_settings
class CommandArguments(BaseCommand): class CommandArguments(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--schema', "--schema",
type=str, type=str,
dest='schema', dest="schema",
default=graphene_settings.SCHEMA, default=graphene_settings.SCHEMA,
help='Django app containing schema to dump, e.g. myproject.core.schema.schema') help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
)
parser.add_argument( parser.add_argument(
'--out', "--out",
type=str, type=str,
dest='out', dest="out",
default=graphene_settings.SCHEMA_OUTPUT, default=graphene_settings.SCHEMA_OUTPUT,
help='Output file (default: schema.json)') help="Output file (default: schema.json)",
)
parser.add_argument( parser.add_argument(
'--indent', "--indent",
type=int, type=int,
dest='indent', dest="indent",
default=graphene_settings.SCHEMA_INDENT, default=graphene_settings.SCHEMA_INDENT,
help='Output file indent (default: None)') help="Output file indent (default: None)",
)
class Command(CommandArguments): class Command(CommandArguments):
help = 'Dump Graphene schema JSON to file' help = "Dump Graphene schema JSON to file"
can_import_settings = True can_import_settings = True
def save_file(self, out, schema_dict, indent): def save_file(self, out, schema_dict, indent):
with open(out, 'w') as outfile: with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent) json.dump(schema_dict, outfile, indent=indent)
def handle(self, *args, **options): def handle(self, *args, **options):
options_schema = options.get('schema') options_schema = options.get("schema")
if options_schema and type(options_schema) is str: if options_schema and type(options_schema) is str:
module_str, schema_name = options_schema.rsplit('.', 1) module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str) mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name) schema = getattr(mod, schema_name)
@ -53,16 +55,18 @@ class Command(CommandArguments):
else: else:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
out = options.get('out') or graphene_settings.SCHEMA_OUTPUT out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
if not schema: if not schema:
raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema') raise CommandError(
"Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
)
indent = options.get('indent') indent = options.get("indent")
schema_dict = {'data': schema.introspect()} schema_dict = {"data": schema.introspect()}
self.save_file(out, schema_dict, indent) self.save_file(out, schema_dict, indent)
style = getattr(self, 'style', None) style = getattr(self, "style", None)
success = getattr(style, 'SUCCESS', lambda x: x) success = getattr(style, "SUCCESS", lambda x: x)
self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out)) self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out))

View File

@ -1,20 +1,21 @@
class Registry(object): class Registry(object):
def __init__(self): def __init__(self):
self._registry = {} self._registry = {}
self._field_registry = {} self._field_registry = {}
def register(self, cls): def register(self, cls):
from .types import DjangoObjectType from .types import DjangoObjectType
assert issubclass( assert issubclass(
cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format( cls, DjangoObjectType
cls.__name__) ), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
assert cls._meta.registry == self, 'Registry for a Model have to match.' cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, ( # assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# ) # )
if not getattr(cls._meta, 'skip_registry', False): if not getattr(cls._meta, "skip_registry", False):
self._registry[cls._meta.model] = cls self._registry[cls._meta.model] = cls
def get_type_for_model(self, model): def get_type_for_model(self, model):

View File

@ -6,20 +6,16 @@ import graphene
from graphene.types import Field, InputField from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation from graphene.relay.mutation import ClientIDMutation
from graphene.types.objecttype import ( from graphene.types.objecttype import yank_fields_from_attrs
yank_fields_from_attrs
)
from .serializer_converter import ( from .serializer_converter import convert_serializer_field
convert_serializer_field
)
from .types import ErrorType from .types import ErrorType
class SerializerMutationOptions(MutationOptions): class SerializerMutationOptions(MutationOptions):
lookup_field = None lookup_field = None
model_class = None model_class = None
model_operations = ['create', 'update'] model_operations = ["create", "update"]
serializer_class = None serializer_class = None
@ -28,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals
for name, field in serializer.fields.items(): for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_excluded = ( is_excluded = (
name in exclude_fields # or name
in exclude_fields # or
# name in already_created_fields # name in already_created_fields
) )
@ -44,49 +41,54 @@ class SerializerMutation(ClientIDMutation):
abstract = True abstract = True
errors = graphene.List( errors = graphene.List(
ErrorType, ErrorType, description="May contain more than one error for same field."
description='May contain more than one error for same field.'
) )
@classmethod @classmethod
def __init_subclass_with_meta__(cls, lookup_field=None, def __init_subclass_with_meta__(
serializer_class=None, model_class=None, cls,
model_operations=['create', 'update'], lookup_field=None,
only_fields=(), exclude_fields=(), **options): serializer_class=None,
model_class=None,
model_operations=["create", "update"],
only_fields=(),
exclude_fields=(),
**options
):
if not serializer_class: if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation') raise Exception("serializer_class is required for the SerializerMutation")
if 'update' not in model_operations and 'create' not in model_operations: if "update" not in model_operations and "create" not in model_operations:
raise Exception('model_operations must contain "create" and/or "update"') raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class() serializer = serializer_class()
if model_class is None: if model_class is None:
serializer_meta = getattr(serializer_class, 'Meta', None) serializer_meta = getattr(serializer_class, "Meta", None)
if serializer_meta: if serializer_meta:
model_class = getattr(serializer_meta, 'model', None) model_class = getattr(serializer_meta, "model", None)
if lookup_field is None and model_class: if lookup_field is None and model_class:
lookup_field = model_class._meta.pk.name lookup_field = model_class._meta.pk.name
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) input_fields = fields_for_serializer(
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) serializer, only_fields, exclude_fields, is_input=True
)
output_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=False
)
_meta = SerializerMutationOptions(cls) _meta = SerializerMutationOptions(cls)
_meta.lookup_field = lookup_field _meta.lookup_field = lookup_field
_meta.model_operations = model_operations _meta.model_operations = model_operations
_meta.serializer_class = serializer_class _meta.serializer_class = serializer_class
_meta.model_class = model_class _meta.model_class = model_class
_meta.fields = yank_fields_from_attrs( _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
output_fields,
_as=Field,
)
input_fields = yank_fields_from_attrs( input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
input_fields, super(SerializerMutation, cls).__init_subclass_with_meta__(
_as=InputField, _meta=_meta, input_fields=input_fields, **options
) )
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod @classmethod
def get_serializer_kwargs(cls, root, info, **input): def get_serializer_kwargs(cls, root, info, **input):
@ -94,24 +96,26 @@ class SerializerMutation(ClientIDMutation):
model_class = cls._meta.model_class model_class = cls._meta.model_class
if model_class: if model_class:
if 'update' in cls._meta.model_operations and lookup_field in input: if "update" in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(model_class, **{ instance = get_object_or_404(
lookup_field: input[lookup_field]}) model_class, **{lookup_field: input[lookup_field]}
elif 'create' in cls._meta.model_operations: )
elif "create" in cls._meta.model_operations:
instance = None instance = None
else: else:
raise Exception( raise Exception(
'Invalid update operation. Input parameter "{}" required.'.format( 'Invalid update operation. Input parameter "{}" required.'.format(
lookup_field lookup_field
)) )
)
return { return {
'instance': instance, "instance": instance,
'data': input, "data": input,
'context': {'request': info.context} "context": {"request": info.context},
} }
return {'data': input, 'context': {'request': info.context}} return {"data": input, "context": {"request": info.context}}
@classmethod @classmethod
def mutate_and_get_payload(cls, root, info, **input): def mutate_and_get_payload(cls, root, info, **input):

View File

@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True):
graphql_type = get_graphene_type_from_serializer_field(field) graphql_type = get_graphene_type_from_serializer_field(field)
args = [] args = []
kwargs = { kwargs = {"description": field.help_text, "required": is_input and field.required}
'description': field.help_text,
'required': is_input and field.required,
}
# if it is a tuple or a list it means that we are returning # if it is a tuple or a list it means that we are returning
# the graphql type and the child type # the graphql type and the child type
if isinstance(graphql_type, (list, tuple)): if isinstance(graphql_type, (list, tuple)):
kwargs['of_type'] = graphql_type[1] kwargs["of_type"] = graphql_type[1]
graphql_type = graphql_type[0] graphql_type = graphql_type[0]
if isinstance(field, serializers.ModelSerializer): if isinstance(field, serializers.ModelSerializer):
@ -49,9 +46,9 @@ def convert_serializer_field(field, is_input=True):
elif isinstance(field, serializers.ListSerializer): elif isinstance(field, serializers.ListSerializer):
field = field.child field = field.child
if is_input: if is_input:
kwargs['of_type'] = convert_serializer_to_input_type(field.__class__) kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
else: else:
del kwargs['of_type'] del kwargs["of_type"]
global_registry = get_global_registry() global_registry = get_global_registry()
field_model = field.Meta.model field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)] args = [global_registry.get_type_for_model(field_model)]
@ -68,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class):
} }
return type( return type(
'{}Input'.format(serializer.__class__.__name__), "{}Input".format(serializer.__class__.__name__),
(graphene.InputObjectType,), (graphene.InputObjectType,),
items items,
) )

View File

@ -16,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
# Remove `source=` from the field declaration. # Remove `source=` from the field declaration.
# since we are reusing the same child in when testing the required attribute # since we are reusing the same child in when testing the required attribute
if 'child' in kwargs: if "child" in kwargs:
kwargs['child'] = copy.deepcopy(kwargs['child']) kwargs["child"] = copy.deepcopy(kwargs["child"])
field = rest_framework_field(**kwargs) field = rest_framework_field(**kwargs)
@ -25,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
def assert_conversion(rest_framework_field, graphene_field, **kwargs): def assert_conversion(rest_framework_field, graphene_field, **kwargs):
graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs) graphene_type = _get_type(
rest_framework_field, help_text="Custom Help Text", **kwargs
)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
graphene_type_required = _get_type( graphene_type_required = _get_type(
rest_framework_field, help_text='Custom Help Text', required=True, **kwargs rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
) )
assert isinstance(graphene_type_required, graphene_field) assert isinstance(graphene_type_required, graphene_field)
@ -39,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs):
def test_should_unknown_rest_framework_field_raise_exception(): def test_should_unknown_rest_framework_field_raise_exception():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
convert_serializer_field(None) convert_serializer_field(None)
assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) assert "Don't know how to convert the serializer field" in str(excinfo.value)
def test_should_char_convert_string(): def test_should_char_convert_string():
@ -67,11 +69,11 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string(): def test_should_regex_convert_string():
assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+') assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(serializers, 'UUIDField'): if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String) assert_conversion(serializers.UUIDField, graphene.String)
@ -79,7 +81,7 @@ def test_should_model_convert_field():
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = None model = None
fields = '__all__' fields = "__all__"
assert_conversion(MyModelSerializer, graphene.Field, is_input=False) assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
@ -109,7 +111,9 @@ def test_should_float_convert_float():
def test_should_decimal_convert_float(): def test_should_decimal_convert_float():
assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2) assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
)
def test_should_list_convert_to_list(): def test_should_list_convert_to_list():
@ -119,7 +123,7 @@ def test_should_list_convert_to_list():
field_a = assert_conversion( field_a = assert_conversion(
serializers.ListField, serializers.ListField,
graphene.List, graphene.List,
child=serializers.IntegerField(min_value=0, max_value=100) child=serializers.IntegerField(min_value=0, max_value=100),
) )
assert field_a.of_type == graphene.Int assert field_a.of_type == graphene.Int
@ -136,19 +140,23 @@ def test_should_list_serializer_convert_to_list():
class ChildSerializer(serializers.ModelSerializer): class ChildSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = FooModel model = FooModel
fields = '__all__' fields = "__all__"
class ParentSerializer(serializers.ModelSerializer): class ParentSerializer(serializers.ModelSerializer):
child = ChildSerializer(many=True) child = ChildSerializer(many=True)
class Meta: class Meta:
model = FooModel model = FooModel
fields = '__all__' fields = "__all__"
converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=True) converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=True
)
assert isinstance(converted_type, graphene.List) assert isinstance(converted_type, graphene.List)
converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=False) converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=False
)
assert isinstance(converted_type, graphene.List) assert isinstance(converted_type, graphene.List)
assert converted_type.of_type is None assert converted_type.of_type is None
@ -166,7 +174,7 @@ def test_should_file_convert_string():
def test_should_filepath_convert_string(): def test_should_filepath_convert_string():
assert_conversion(serializers.FilePathField, graphene.String, path='/') assert_conversion(serializers.FilePathField, graphene.String, path="/")
def test_should_ip_convert_string(): def test_should_ip_convert_string():
@ -182,6 +190,8 @@ def test_should_json_convert_jsonstring():
def test_should_multiplechoicefield_convert_to_list_of_string(): def test_should_multiplechoicefield_convert_to_list_of_string():
field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]) field = assert_conversion(
serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
)
assert field.of_type == graphene.String assert field.of_type == graphene.String

View File

@ -10,30 +10,33 @@ from ...types import DjangoObjectType
from ..models import MyFakeModel from ..models import MyFakeModel
from ..mutation import SerializerMutation from ..mutation import SerializerMutation
def mock_info(): def mock_info():
return ResolveInfo( return ResolveInfo(
None, None,
None, None,
None, None,
None, None,
schema=None, schema=None,
fragments=None, fragments=None,
root_value=None, root_value=None,
operation=None, operation=None,
variable_values=None, variable_values=None,
context=None context=None,
) )
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = MyFakeModel model = MyFakeModel
fields = '__all__' fields = "__all__"
class MyModelMutation(SerializerMutation): class MyModelMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MyModelSerializer serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer): class MySerializer(serializers.Serializer):
text = serializers.CharField() text = serializers.CharField()
model = MyModelSerializer() model = MyModelSerializer()
@ -44,10 +47,11 @@ class MySerializer(serializers.Serializer):
def test_needs_serializer_class(): def test_needs_serializer_class():
with raises(Exception) as exc: with raises(Exception) as exc:
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
pass pass
assert str(exc.value) == 'serializer_class is required for the SerializerMutation' assert str(exc.value) == "serializer_class is required for the SerializerMutation"
def test_has_fields(): def test_has_fields():
@ -55,9 +59,9 @@ def test_has_fields():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
assert 'text' in MyMutation._meta.fields assert "text" in MyMutation._meta.fields
assert 'model' in MyMutation._meta.fields assert "model" in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields assert "errors" in MyMutation._meta.fields
def test_has_input_fields(): def test_has_input_fields():
@ -65,25 +69,24 @@ def test_has_input_fields():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
assert 'text' in MyMutation.Input._meta.fields assert "text" in MyMutation.Input._meta.fields
assert 'model' in MyMutation.Input._meta.fields assert "model" in MyMutation.Input._meta.fields
def test_exclude_fields(): def test_exclude_fields():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MyModelSerializer serializer_class = MyModelSerializer
exclude_fields = ['created'] exclude_fields = ["created"]
assert 'cool_name' in MyMutation._meta.fields assert "cool_name" in MyMutation._meta.fields
assert 'created' not in MyMutation._meta.fields assert "created" not in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields assert "errors" in MyMutation._meta.fields
assert 'cool_name' in MyMutation.Input._meta.fields assert "cool_name" in MyMutation.Input._meta.fields
assert 'created' not in MyMutation.Input._meta.fields assert "created" not in MyMutation.Input._meta.fields
def test_nested_model(): def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType): class MyFakeModelGrapheneType(DjangoObjectType):
class Meta: class Meta:
model = MyFakeModel model = MyFakeModel
@ -92,67 +95,64 @@ def test_nested_model():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
model_field = MyMutation._meta.fields['model'] model_field = MyMutation._meta.fields["model"]
assert isinstance(model_field, Field) assert isinstance(model_field, Field)
assert model_field.type == MyFakeModelGrapheneType assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields['model'] model_input = MyMutation.Input._meta.fields["model"]
model_input_type = model_input._type.of_type model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType) assert issubclass(model_input_type, InputObjectType)
assert 'cool_name' in model_input_type._meta.fields assert "cool_name" in model_input_type._meta.fields
assert 'created' in model_input_type._meta.fields assert "created" in model_input_type._meta.fields
def test_mutate_and_get_payload_success(): def test_mutate_and_get_payload_success():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{ result = MyMutation.mutate_and_get_payload(
'text': 'value', None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
'model': { )
'cool_name': 'other_value'
}
})
assert result.errors is None assert result.errors is None
@mark.django_db @mark.django_db
def test_model_add_mutate_and_get_payload_success(): def test_model_add_mutate_and_get_payload_success():
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{ result = MyModelMutation.mutate_and_get_payload(
'cool_name': 'Narf', None, mock_info(), **{"cool_name": "Narf"}
}) )
assert result.errors is None assert result.errors is None
assert result.cool_name == 'Narf' assert result.cool_name == "Narf"
assert isinstance(result.created, datetime.datetime) assert isinstance(result.created, datetime.datetime)
@mark.django_db @mark.django_db
def test_model_update_mutate_and_get_payload_success(): def test_model_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf") instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{ result = MyModelMutation.mutate_and_get_payload(
'id': instance.id, None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
'cool_name': 'New Narf', )
})
assert result.errors is None assert result.errors is None
assert result.cool_name == 'New Narf' assert result.cool_name == "New Narf"
@mark.django_db @mark.django_db
def test_model_invalid_update_mutate_and_get_payload_success(): def test_model_invalid_update_mutate_and_get_payload_success():
class InvalidModelMutation(SerializerMutation): class InvalidModelMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MyModelSerializer serializer_class = MyModelSerializer
model_operations = ['update'] model_operations = ["update"]
with raises(Exception) as exc: with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload(None, mock_info(), **{ result = InvalidModelMutation.mutate_and_get_payload(
'cool_name': 'Narf', None, mock_info(), **{"cool_name": "Narf"}
}) )
assert '"id" required' in str(exc.value) assert '"id" required' in str(exc.value)
def test_mutate_and_get_payload_error():
def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
@ -161,16 +161,19 @@ def test_mutate_and_get_payload_error():
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{}) result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0 assert len(result.errors) > 0
def test_model_mutate_and_get_payload_error(): def test_model_mutate_and_get_payload_error():
# missing required fields # missing required fields
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{}) result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0 assert len(result.errors) > 0
def test_invalid_serializer_operations(): def test_invalid_serializer_operations():
with raises(Exception) as exc: with raises(Exception) as exc:
class MyModelMutation(SerializerMutation): class MyModelMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MyModelSerializer serializer_class = MyModelSerializer
model_operations = ['Add'] model_operations = ["Add"]
assert 'model_operations' in str(exc.value) assert "model_operations" in str(exc.value)

View File

@ -26,27 +26,22 @@ except ImportError:
# Copied shamelessly from Django REST Framework # Copied shamelessly from Django REST Framework
DEFAULTS = { DEFAULTS = {
'SCHEMA': None, "SCHEMA": None,
'SCHEMA_OUTPUT': 'schema.json', "SCHEMA_OUTPUT": "schema.json",
'SCHEMA_INDENT': None, "SCHEMA_INDENT": None,
'MIDDLEWARE': (), "MIDDLEWARE": (),
# Set to True if the connection fields must have # Set to True if the connection fields must have
# either the first or last argument # either the first or last argument
'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False, "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
# Max items returned in ConnectionFields / FilterConnectionFields # Max items returned in ConnectionFields / FilterConnectionFields
'RELAY_CONNECTION_MAX_LIMIT': 100, "RELAY_CONNECTION_MAX_LIMIT": 100,
} }
if settings.DEBUG: if settings.DEBUG:
DEFAULTS['MIDDLEWARE'] += ( DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
'graphene_django.debug.DjangoDebugMiddleware',
)
# List of settings that may be in string import notation. # List of settings that may be in string import notation.
IMPORT_STRINGS = ( IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
'MIDDLEWARE',
'SCHEMA',
)
def perform_import(val, setting_name): def perform_import(val, setting_name):
@ -69,12 +64,17 @@ def import_from_string(val, setting_name):
""" """
try: try:
# Nod to tastypie's use of importlib. # Nod to tastypie's use of importlib.
parts = val.split('.') parts = val.split(".")
module_path, class_name = '.'.join(parts[:-1]), parts[-1] module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg) raise ImportError(msg)
@ -96,8 +96,8 @@ class GrapheneSettings(object):
@property @property
def user_settings(self): def user_settings(self):
if not hasattr(self, '_user_settings'): if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, 'GRAPHENE', {}) self._user_settings = getattr(settings, "GRAPHENE", {})
return self._user_settings return self._user_settings
def __getattr__(self, attr): def __getattr__(self, attr):
@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_graphene_settings(*args, **kwargs): def reload_graphene_settings(*args, **kwargs):
global graphene_settings global graphene_settings
setting, value = kwargs['setting'], kwargs['value'] setting, value = kwargs["setting"], kwargs["value"]
if setting == 'GRAPHENE': if setting == "GRAPHENE":
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS) graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)

View File

@ -3,10 +3,7 @@ from __future__ import absolute_import
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
CHOICES = ( CHOICES = ((1, "this"), (2, _("that")))
(1, 'this'),
(2, _('that'))
)
class Pet(models.Model): class Pet(models.Model):
@ -15,38 +12,43 @@ class Pet(models.Model):
class FilmDetails(models.Model): class FilmDetails(models.Model):
location = models.CharField(max_length=30) location = models.CharField(max_length=30)
film = models.OneToOneField('Film', on_delete=models.CASCADE, related_name='details') film = models.OneToOneField(
"Film", on_delete=models.CASCADE, related_name="details"
)
class Film(models.Model): class Film(models.Model):
genre = models.CharField(max_length=2, help_text='Genre', choices=[ genre = models.CharField(
('do', 'Documentary'), max_length=2,
('ot', 'Other') help_text="Genre",
], default='ot') choices=[("do", "Documentary"), ("ot", "Other")],
reporters = models.ManyToManyField('Reporter', default="ot",
related_name='films') )
reporters = models.ManyToManyField("Reporter", related_name="films")
class DoeReporterManager(models.Manager): class DoeReporterManager(models.Manager):
def get_queryset(self): def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe") return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
class Reporter(models.Model): class Reporter(models.Model):
first_name = models.CharField(max_length=30) first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30)
email = models.EmailField() email = models.EmailField()
pets = models.ManyToManyField('self') pets = models.ManyToManyField("self")
a_choice = models.CharField(max_length=30, choices=CHOICES) a_choice = models.CharField(max_length=30, choices=CHOICES)
objects = models.Manager() objects = models.Manager()
doe_objects = DoeReporterManager() doe_objects = DoeReporterManager()
reporter_type = models.IntegerField( reporter_type = models.IntegerField(
'Reporter Type', "Reporter Type",
null=True, null=True,
blank=True, blank=True,
choices=[(1, u'Regular'), (2, u'CNN Reporter')] choices=[(1, u"Regular"), (2, u"CNN Reporter")],
) )
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name) return "%s %s" % (self.first_name, self.last_name)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -61,11 +63,13 @@ class Reporter(models.Model):
if self.reporter_type == 2: # quick and dirty way without enums if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter self.__class__ = CNNReporter
class CNNReporter(Reporter): class CNNReporter(Reporter):
""" """
This class is a proxy model for Reporter, used for testing This class is a proxy model for Reporter, used for testing
proxy model support proxy model support
""" """
class Meta: class Meta:
proxy = True proxy = True
@ -74,17 +78,27 @@ class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField() pub_date = models.DateField()
pub_date_time = models.DateTimeField() pub_date_time = models.DateTimeField()
reporter = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='articles') reporter = models.ForeignKey(
editor = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='edited_articles_+') Reporter, on_delete=models.CASCADE, related_name="articles"
lang = models.CharField(max_length=2, help_text='Language', choices=[ )
('es', 'Spanish'), editor = models.ForeignKey(
('en', 'English') Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
], default='es') )
importance = models.IntegerField('Importance', null=True, blank=True, lang = models.CharField(
choices=[(1, u'Very important'), (2, u'Not as important')]) max_length=2,
help_text="Language",
choices=[("es", "Spanish"), ("en", "English")],
default="es",
)
importance = models.IntegerField(
"Importance",
null=True,
blank=True,
choices=[(1, u"Very important"), (2, u"Not as important")],
)
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return self.headline return self.headline
class Meta: class Meta:
ordering = ('headline',) ordering = ("headline",)

View File

@ -6,10 +6,9 @@ from .models import Article, Reporter
class Character(DjangoObjectType): class Character(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (relay.Node, ) interfaces = (relay.Node,)
def get_node(self, info, id): def get_node(self, info, id):
pass pass
@ -20,7 +19,7 @@ class Human(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (relay.Node, ) interfaces = (relay.Node,)
def resolve_raises(self, info): def resolve_raises(self, info):
raise Exception("This field should raise exception") raise Exception("This field should raise exception")

View File

@ -12,10 +12,10 @@ class QueryRoot(ObjectType):
raise Exception("Throws!") raise Exception("Throws!")
def resolve_request(self, info): def resolve_request(self, info):
return info.context.GET.get('q') return info.context.GET.get("q")
def resolve_test(self, info, who=None): def resolve_test(self, info, who=None):
return 'Hello %s' % (who or 'World') return "Hello %s" % (who or "World")
class MutationRoot(ObjectType): class MutationRoot(ObjectType):

View File

@ -3,8 +3,8 @@ from mock import patch
from six import StringIO from six import StringIO
@patch('graphene_django.management.commands.graphql_schema.Command.save_file') @patch("graphene_django.management.commands.graphql_schema.Command.save_file")
def test_generate_file_on_call_graphql_schema(savefile_mock, settings): def test_generate_file_on_call_graphql_schema(savefile_mock, settings):
out = StringIO() out = StringIO()
management.call_command('graphql_schema', schema='', stdout=out) management.call_command("graphql_schema", schema="", stdout=out)
assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue() assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()

View File

@ -19,11 +19,11 @@ from .models import Article, Film, FilmDetails, Reporter
def assert_conversion(django_field, graphene_field, *args, **kwargs): def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs) field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs)
graphene_type = convert_django_field(field) graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field() field = graphene_type.Field()
assert field.description == 'Custom Help Text' assert field.description == "Custom Help Text"
nonnull_field = django_field(null=False, *args, **kwargs) nonnull_field = django_field(null=False, *args, **kwargs)
if not nonnull_field.null: if not nonnull_field.null:
nonnull_graphene_type = convert_django_field(nonnull_field) nonnull_graphene_type = convert_django_field(nonnull_field)
@ -36,7 +36,8 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
def test_should_unknown_django_field_raise_exception(): def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
convert_django_field(None) convert_django_field(None)
assert 'Don\'t know how to convert the Django field' in str(excinfo.value) assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string(): def test_should_date_time_convert_string():
assert_conversion(models.DateTimeField, DateTime) assert_conversion(models.DateTimeField, DateTime)
@ -128,70 +129,69 @@ def test_should_nullboolean_convert_boolean():
def test_field_with_choices_convert_enum(): def test_field_with_choices_convert_enum():
field = models.CharField(help_text='Language', choices=( field = models.CharField(
('es', 'Spanish'), help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
('en', 'English') )
))
class TranslatedModel(models.Model): class TranslatedModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
graphene_type = convert_django_field_with_choices(field) graphene_type = convert_django_field_with_choices(field)
assert isinstance(graphene_type, graphene.Enum) assert isinstance(graphene_type, graphene.Enum)
assert graphene_type._meta.name == 'TranslatedModelLanguage' assert graphene_type._meta.name == "TranslatedModelLanguage"
assert graphene_type._meta.enum.__members__['ES'].value == 'es' assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish' assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__['EN'].value == 'en' assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__['EN'].description == 'English' assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_grouped_choices(): def test_field_with_grouped_choices():
field = models.CharField(help_text='Language', choices=( field = models.CharField(
('Europe', ( help_text="Language",
('es', 'Spanish'), choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
('en', 'English'), )
)),
))
class GroupedChoicesModel(models.Model): class GroupedChoicesModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
convert_django_field_with_choices(field) convert_django_field_with_choices(field)
def test_field_with_choices_gettext(): def test_field_with_choices_gettext():
field = models.CharField(help_text='Language', choices=( field = models.CharField(
('es', _('Spanish')), help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
('en', _('English')) )
))
class TranslatedChoicesModel(models.Model): class TranslatedChoicesModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
convert_django_field_with_choices(field) convert_django_field_with_choices(field)
def test_field_with_choices_collision(): def test_field_with_choices_collision():
field = models.CharField(help_text='Timezone', choices=( field = models.CharField(
('Etc/GMT+1+2', 'Fake choice to produce double collision'), help_text="Timezone",
('Etc/GMT+1', 'Greenwich Mean Time +1'), choices=(
('Etc/GMT-1', 'Greenwich Mean Time -1'), ("Etc/GMT+1+2", "Fake choice to produce double collision"),
)) ("Etc/GMT+1", "Greenwich Mean Time +1"),
("Etc/GMT-1", "Greenwich Mean Time -1"),
),
)
class CollisionChoicesModel(models.Model): class CollisionChoicesModel(models.Model):
timezone = field timezone = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
convert_django_field_with_choices(field) convert_django_field_with_choices(field)
@ -208,11 +208,12 @@ def test_should_manytomany_convert_connectionorlist():
def test_should_manytomany_convert_connectionorlist_list(): def test_should_manytomany_convert_connectionorlist_list():
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field) assert isinstance(dynamic_field, graphene.Field)
@ -222,12 +223,13 @@ def test_should_manytomany_convert_connectionorlist_list():
def test_should_manytomany_convert_connectionorlist_connection(): def test_should_manytomany_convert_connectionorlist_connection():
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField) assert isinstance(dynamic_field, ConnectionField)
@ -236,11 +238,11 @@ def test_should_manytomany_convert_connectionorlist_connection():
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
# Django 1.9 uses 'rel', <1.9 uses 'related # Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Reporter.articles, 'rel', None) or \ related = getattr(Reporter.articles, "rel", None) or getattr(
getattr(Reporter.articles, 'related') Reporter.articles, "related"
)
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
@ -254,11 +256,9 @@ def test_should_manytoone_convert_connectionorlist():
def test_should_onetoone_reverse_convert_model(): def test_should_onetoone_reverse_convert_model():
# Django 1.9 uses 'rel', <1.9 uses 'related # Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Film.details, 'rel', None) or \ related = getattr(Film.details, "rel", None) or getattr(Film.details, "related")
getattr(Film.details, 'related')
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = FilmDetails model = FilmDetails
@ -269,41 +269,41 @@ def test_should_onetoone_reverse_convert_model():
assert dynamic_field.type == A assert dynamic_field.type == A
@pytest.mark.skipif(ArrayField is MissingType, @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
reason="ArrayField should exist")
def test_should_postgres_array_convert_list(): def test_should_postgres_array_convert_list():
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100)) field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100)
)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.String assert field.type.of_type.of_type == graphene.String
@pytest.mark.skipif(ArrayField is MissingType, @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
reason="ArrayField should exist")
def test_should_postgres_array_multiple_convert_list(): def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))) field = assert_conversion(
ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List) assert isinstance(field.type.of_type.of_type, graphene.List)
assert field.type.of_type.of_type.of_type == graphene.String assert field.type.of_type.of_type.of_type == graphene.String
@pytest.mark.skipif(HStoreField is MissingType, @pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
reason="HStoreField should exist")
def test_should_postgres_hstore_convert_string(): def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString) assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(JSONField is MissingType, @pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
reason="JSONField should exist")
def test_should_postgres_json_convert_string(): def test_should_postgres_json_convert_string():
assert_conversion(JSONField, JSONString) assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType, @pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
reason="RangeField should exist")
def test_should_postgres_range_convert_list(): def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List) field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from py.test import raises from py.test import raises
from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc' # 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
@ -9,24 +9,24 @@ from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
def test_global_id_valid(): def test_global_id_valid():
field = GlobalIDFormField() field = GlobalIDFormField()
field.clean('TXlUeXBlOmFiYw==') field.clean("TXlUeXBlOmFiYw==")
def test_global_id_invalid(): def test_global_id_invalid():
field = GlobalIDFormField() field = GlobalIDFormField()
with raises(ValidationError): with raises(ValidationError):
field.clean('badvalue') field.clean("badvalue")
def test_global_id_multiple_valid(): def test_global_id_multiple_valid():
field = GlobalIDMultipleChoiceField() field = GlobalIDMultipleChoiceField()
field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw==']) field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
def test_global_id_multiple_invalid(): def test_global_id_multiple_invalid():
field = GlobalIDMultipleChoiceField() field = GlobalIDMultipleChoiceField()
with raises(ValidationError): with raises(ValidationError):
field.clean(['badvalue', 'another bad avue']) field.clean(["badvalue", "another bad avue"])
def test_global_id_none(): def test_global_id_none():

File diff suppressed because it is too large Load Diff

View File

@ -7,48 +7,47 @@ from .models import Reporter
def test_should_raise_if_no_model(): def test_should_raise_if_no_model():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character1(DjangoObjectType): class Character1(DjangoObjectType):
pass pass
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_raise_if_model_is_invalid(): def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character2(DjangoObjectType):
class Character2(DjangoObjectType):
class Meta: class Meta:
model = 1 model = 1
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_map_fields_correctly(): def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType): class ReporterType2(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
registry = Registry() registry = Registry()
fields = list(ReporterType2._meta.fields.keys()) fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [ assert fields[:-2] == [
'id', "id",
'first_name', "first_name",
'last_name', "last_name",
'email', "email",
'pets', "pets",
'a_choice', "a_choice",
'reporter_type' "reporter_type",
] ]
assert sorted(fields[-2:]) == [ assert sorted(fields[-2:]) == ["articles", "films"]
'articles',
'films',
]
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
only_fields = ('id', 'email') only_fields = ("id", "email")
assert list(Reporter2._meta.fields.keys()) == ['id', 'email'] assert list(Reporter2._meta.fields.keys()) == ["id", "email"]

View File

@ -12,27 +12,30 @@ registry.reset_global_registry()
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
'''Reporter description''' """Reporter description"""
class Meta: class Meta:
model = ReporterModel model = ReporterModel
class ArticleConnection(Connection): class ArticleConnection(Connection):
'''Article Connection''' """Article Connection"""
test = String() test = String()
def resolve_test(): def resolve_test():
return 'test' return "test"
class Meta: class Meta:
abstract = True abstract = True
class Article(DjangoObjectType): class Article(DjangoObjectType):
'''Article description''' """Article description"""
class Meta: class Meta:
model = ArticleModel model = ArticleModel
interfaces = (Node, ) interfaces = (Node,)
connection_class = ArticleConnection connection_class = ArticleConnection
@ -48,7 +51,7 @@ def test_django_interface():
assert issubclass(Node, Node) assert issubclass(Node, Node)
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) @patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
def test_django_get_node(get): def test_django_get_node(get):
article = Article.get_node(None, 1) article = Article.get_node(None, 1)
get.assert_called_with(pk=1) get.assert_called_with(pk=1)
@ -58,18 +61,35 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields(): def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields fields = Reporter._meta.fields
fields = list(fields.keys()) fields = list(fields.keys())
assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice', 'reporter_type'] assert fields[:-2] == [
assert sorted(fields[-2:]) == ['articles', 'films'] "id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_django_objecttype_with_node_have_correct_fields(): def test_django_objecttype_with_node_have_correct_fields():
fields = Article._meta.fields fields = Article._meta.fields
assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'pub_date_time', 'reporter', 'editor', 'lang', 'importance'] assert list(fields.keys()) == [
"id",
"headline",
"pub_date",
"pub_date_time",
"reporter",
"editor",
"lang",
"importance",
]
def test_django_objecttype_with_custom_meta(): def test_django_objecttype_with_custom_meta():
class ArticleTypeOptions(DjangoObjectTypeOptions): class ArticleTypeOptions(DjangoObjectTypeOptions):
'''Article Type Options''' """Article Type Options"""
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
@ -77,7 +97,7 @@ def test_django_objecttype_with_custom_meta():
@classmethod @classmethod
def __init_subclass_with_meta__(cls, **options): def __init_subclass_with_meta__(cls, **options):
options.setdefault('_meta', ArticleTypeOptions(cls)) options.setdefault("_meta", ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(**options) super(ArticleType, cls).__init_subclass_with_meta__(**options)
class Article(ArticleType): class Article(ArticleType):
@ -180,6 +200,7 @@ def with_local_registry(func):
else: else:
registry.registry = old registry.registry = old
return retval return retval
return inner return inner
@ -188,11 +209,10 @@ def test_django_objecttype_only_fields():
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
only_fields = ('id', 'email', 'films') only_fields = ("id", "email", "films")
fields = list(Reporter._meta.fields.keys()) fields = list(Reporter._meta.fields.keys())
assert fields == ['id', 'email', 'films'] assert fields == ["id", "email", "films"]
@with_local_registry @with_local_registry
@ -200,8 +220,7 @@ def test_django_objecttype_exclude_fields():
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude_fields = ('email') exclude_fields = "email"
fields = list(Reporter._meta.fields.keys()) fields = list(Reporter._meta.fields.keys())
assert 'email' not in fields assert "email" not in fields

View File

@ -8,15 +8,15 @@ except ImportError:
from urllib.parse import urlencode from urllib.parse import urlencode
def url_string(string='/graphql', **url_params): def url_string(string="/graphql", **url_params):
if url_params: if url_params:
string += '?' + urlencode(url_params) string += "?" + urlencode(url_params)
return string return string
def batch_url_string(**url_params): def batch_url_string(**url_params):
return url_string('/graphql/batch', **url_params) return url_string("/graphql/batch", **url_params)
def response_json(response): def response_json(response):
@ -28,441 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs])
def test_graphiql_is_enabled(client): def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT='text/html') response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == 200 assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'text/html' assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_graphiql(client): def test_qfactor_graphiql(client):
response = client.get(url_string(query='{test}'), HTTP_ACCEPT='application/json;q=0.8, text/html;q=0.9') response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
)
assert response.status_code == 200 assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'text/html' assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_json(client): def test_qfactor_json(client):
response = client.get(url_string(query='{test}'), HTTP_ACCEPT='text/html;q=0.8, application/json;q=0.9') response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
)
assert response.status_code == 200 assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'application/json' assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_allows_get_with_query_param(client): def test_allows_get_with_query_param(client):
response = client.get(url_string(query='{test}')) response = client.get(url_string(query="{test}"))
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_allows_get_with_variable_values(client): def test_allows_get_with_variable_values(client):
response = client.get(url_string( response = client.get(
query='query helloWho($who: String){ test(who: $who) }', url_string(
variables=json.dumps({'who': "Dolly"}) query="query helloWho($who: String){ test(who: $who) }",
)) variables=json.dumps({"who": "Dolly"}),
)
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_allows_get_with_operation_name(client): def test_allows_get_with_operation_name(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query helloYou { test(who: "You"), ...shared } query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared } query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared } query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot { fragment shared on QueryRoot {
shared: test(who: "Everyone") shared: test(who: "Everyone")
} }
''', """,
operationName='helloWorld' operationName="helloWorld",
)) )
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': { "data": {"test": "Hello World", "shared": "Hello Everyone"}
'test': 'Hello World',
'shared': 'Hello Everyone'
}
} }
def test_reports_validation_errors(client): def test_reports_validation_errors(client):
response = client.get(url_string( response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
query='{ test, unknownOne, unknownTwo }'
))
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {
'message': 'Cannot query field "unknownOne" on type "QueryRoot".', "message": 'Cannot query field "unknownOne" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 9}] "locations": [{"line": 1, "column": 9}],
}, },
{ {
'message': 'Cannot query field "unknownTwo" on type "QueryRoot".', "message": 'Cannot query field "unknownTwo" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 21}] "locations": [{"line": 1, "column": 21}],
} },
] ]
} }
def test_errors_when_missing_operation_name(client): def test_errors_when_missing_operation_name(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query TestQuery { test } query TestQuery { test }
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''' """
)) )
)
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {
'message': 'Must provide operation name if query contains multiple operations.' "message": "Must provide operation name if query contains multiple operations."
} }
] ]
} }
def test_errors_when_sending_a_mutation_via_get(client): def test_errors_when_sending_a_mutation_via_get(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''' """
)) )
)
assert response.status_code == 405 assert response.status_code == 405
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {"message": "Can only perform a mutation operation from a POST request."}
'message': 'Can only perform a mutation operation from a POST request.'
}
] ]
} }
def test_errors_when_selecting_a_mutation_within_a_get(client): def test_errors_when_selecting_a_mutation_within_a_get(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query TestQuery { test } query TestQuery { test }
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''', """,
operationName='TestMutation' operationName="TestMutation",
)) )
)
assert response.status_code == 405 assert response.status_code == 405
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {"message": "Can only perform a mutation operation from a POST request."}
'message': 'Can only perform a mutation operation from a POST request.'
}
] ]
} }
def test_allows_mutation_to_exist_within_a_get(client): def test_allows_mutation_to_exist_within_a_get(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query TestQuery { test } query TestQuery { test }
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''', """,
operationName='TestQuery' operationName="TestQuery",
)) )
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_allows_post_with_json_encoding(client): def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query='{test}'), 'application/json') response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_batch_allows_post_with_json_encoding(client): def test_batch_allows_post_with_json_encoding(client):
response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json') response = client.post(
batch_url_string(), jl(id=1, query="{test}"), "application/json"
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {"id": 1, "data": {"test": "Hello World"}, "status": 200}
'data': {'test': "Hello World"}, ]
'status': 200,
}]
def test_batch_fails_if_is_empty(client): def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), '[]', 'application/json') response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Received an empty list in the batch request.'}] "errors": [{"message": "Received an empty list in the batch request."}]
} }
def test_allows_sending_a_mutation_via_post(client): def test_allows_sending_a_mutation_via_post(client):
response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json') response = client.post(
url_string(),
j(query="mutation TestMutation { writeTest { test } }"),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
'data': {'writeTest': {'test': 'Hello World'}}
}
def test_allows_post_with_url_encoding(client): def test_allows_post_with_url_encoding(client):
response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded') response = client.post(
url_string(),
urlencode(dict(query="{test}")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_supports_post_json_query_with_string_variables(client): def test_supports_post_json_query_with_string_variables(client):
response = client.post(url_string(), j( response = client.post(
query='query helloWho($who: String){ test(who: $who) }', url_string(),
variables=json.dumps({'who': "Dolly"}) j(
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_batch_supports_post_json_query_with_string_variables(client): def test_batch_supports_post_json_query_with_string_variables(client):
response = client.post(batch_url_string(), jl( response = client.post(
id=1, batch_url_string(),
query='query helloWho($who: String){ test(who: $who) }', jl(
variables=json.dumps({'who': "Dolly"}) id=1,
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
'data': {'test': "Hello Dolly"}, ]
'status': 200,
}]
def test_supports_post_json_query_with_json_variables(client): def test_supports_post_json_query_with_json_variables(client):
response = client.post(url_string(), j( response = client.post(
query='query helloWho($who: String){ test(who: $who) }', url_string(),
variables={'who': "Dolly"} j(
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_batch_supports_post_json_query_with_json_variables(client): def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(batch_url_string(), jl( response = client.post(
id=1, batch_url_string(),
query='query helloWho($who: String){ test(who: $who) }', jl(
variables={'who': "Dolly"} id=1,
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
'data': {'test': "Hello Dolly"}, ]
'status': 200,
}]
def test_supports_post_url_encoded_query_with_string_variables(client): def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post(url_string(), urlencode(dict( response = client.post(
query='query helloWho($who: String){ test(who: $who) }', url_string(),
variables=json.dumps({'who': "Dolly"}) urlencode(
)), 'application/x-www-form-urlencoded') dict(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_supports_post_json_quey_with_get_variable_values(client): def test_supports_post_json_quey_with_get_variable_values(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), j( j(query="query helloWho($who: String){ test(who: $who) }"),
query='query helloWho($who: String){ test(who: $who) }', "application/json",
), 'application/json') )
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_post_url_encoded_query_with_get_variable_values(client): def test_post_url_encoded_query_with_get_variable_values(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), urlencode(dict( urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
query='query helloWho($who: String){ test(who: $who) }', "application/x-www-form-urlencoded",
)), 'application/x-www-form-urlencoded') )
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_supports_post_raw_text_query_with_get_variable_values(client): def test_supports_post_raw_text_query_with_get_variable_values(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), "query helloWho($who: String){ test(who: $who) }",
'query helloWho($who: String){ test(who: $who) }', "application/graphql",
'application/graphql' )
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_post_with_operation_name(client):
response = client.post(
url_string(),
j(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
),
"application/json",
) )
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': {'test': "Hello Dolly"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_allows_post_with_operation_name(client):
response = client.post(url_string(), j(
query='''
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
} }
def test_batch_allows_post_with_operation_name(client): def test_batch_allows_post_with_operation_name(client):
response = client.post(batch_url_string(), jl( response = client.post(
id=1, batch_url_string(),
query=''' jl(
id=1,
query="""
query helloYou { test(who: "You"), ...shared } query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared } query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared } query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot { fragment shared on QueryRoot {
shared: test(who: "Everyone") shared: test(who: "Everyone")
} }
''', """,
operationName='helloWorld' operationName="helloWorld",
), 'application/json') ),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {
'data': { "id": 1,
'test': 'Hello World', "data": {"test": "Hello World", "shared": "Hello Everyone"},
'shared': 'Hello Everyone' "status": 200,
}, }
'status': 200, ]
}]
def test_allows_post_with_get_operation_name(client): def test_allows_post_with_get_operation_name(client):
response = client.post(url_string( response = client.post(
operationName='helloWorld' url_string(operationName="helloWorld"),
), ''' """
query helloYou { test(who: "You"), ...shared } query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared } query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared } query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot { fragment shared on QueryRoot {
shared: test(who: "Everyone") shared: test(who: "Everyone")
} }
''', """,
'application/graphql') "application/graphql",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': { "data": {"test": "Hello World", "shared": "Hello Everyone"}
'test': 'Hello World',
'shared': 'Hello Everyone'
}
} }
@pytest.mark.urls('graphene_django.tests.urls_inherited') @pytest.mark.urls("graphene_django.tests.urls_inherited")
def test_inherited_class_with_attributes_works(client): def test_inherited_class_with_attributes_works(client):
inherited_url = '/graphql/inherited/' inherited_url = "/graphql/inherited/"
# Check schema and pretty attributes work # Check schema and pretty attributes work
response = client.post(url_string(inherited_url, query='{test}')) response = client.post(url_string(inherited_url, query="{test}"))
assert response.content.decode() == ( assert response.content.decode() == (
'{\n' "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
) )
# Check graphiql works # Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT='text/html') response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.urls('graphene_django.tests.urls_pretty') @pytest.mark.urls("graphene_django.tests.urls_pretty")
def test_supports_pretty_printing(client): def test_supports_pretty_printing(client):
response = client.get(url_string(query='{test}')) response = client.get(url_string(query="{test}"))
assert response.content.decode() == ( assert response.content.decode() == (
'{\n' "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
) )
def test_supports_pretty_printing_by_request(client): def test_supports_pretty_printing_by_request(client):
response = client.get(url_string(query='{test}', pretty='1')) response = client.get(url_string(query="{test}", pretty="1"))
assert response.content.decode() == ( assert response.content.decode() == (
'{\n' "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
) )
def test_handles_field_errors_caught_by_graphql(client): def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query='{thrower}')) response = client.get(url_string(query="{thrower}"))
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': None, "data": None,
'errors': [{ "errors": [
'locations': [{'column': 2, 'line': 1}], {
'path': ['thrower'], "locations": [{"column": 2, "line": 1}],
'message': 'Throws!', "path": ["thrower"],
}] "message": "Throws!",
}
],
} }
def test_handles_syntax_errors_caught_by_graphql(client): def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query='syntaxerror')) response = client.get(url_string(query="syntaxerror"))
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'locations': [{'column': 1, 'line': 1}], "errors": [
'message': 'Syntax Error GraphQL (1:1) ' {
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}] "locations": [{"column": 1, "line": 1}],
"message": "Syntax Error GraphQL (1:1) "
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n',
}
]
} }
@ -471,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}] "errors": [{"message": "Must provide query string."}]
} }
def test_handles_not_expected_json_bodies(client): def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), '[]', 'application/json') response = client.post(url_string(), "[]", "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'The received data is not a valid JSON query.'}] "errors": [{"message": "The received data is not a valid JSON query."}]
} }
def test_handles_invalid_json_bodies(client): def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), '[oh}', 'application/json') response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
@ -499,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch):
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read) monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
valid_json = json.dumps(dict(foo='bar')) valid_json = json.dumps(dict(foo="bar"))
response = client.post(url_string(), valid_json, 'application/json') response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
'errors': [{'message': 'foo-bar'}]
}
def test_handles_incomplete_json_bodies(client): def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', 'application/json') response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
def test_handles_plain_post_text(client): def test_handles_plain_post_text(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), "query helloWho($who: String){ test(who: $who) }",
'query helloWho($who: String){ test(who: $who) }', "text/plain",
'text/plain'
) )
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}] "errors": [{"message": "Must provide query string."}]
} }
def test_handles_poorly_formed_variables(client): def test_handles_poorly_formed_variables(client):
response = client.get(url_string( response = client.get(
query='query helloWho($who: String){ test(who: $who) }', url_string(
variables='who:You' query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
)) )
)
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Variables are invalid JSON.'}] "errors": [{"message": "Variables are invalid JSON."}]
} }
def test_handles_unsupported_http_methods(client): def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query='{test}')) response = client.put(url_string(query="{test}"))
assert response.status_code == 405 assert response.status_code == 405
assert response['Allow'] == 'GET, POST' assert response["Allow"] == "GET, POST"
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'GraphQL only supports GET and POST requests.'}] "errors": [{"message": "GraphQL only supports GET and POST requests."}]
} }
def test_passes_request_into_context_request(client): def test_passes_request_into_context_request(client):
response = client.get(url_string(query='{request}', q='testing')) response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"request": "testing"}}
'data': {
'request': 'testing'
}
}

View File

@ -3,6 +3,6 @@ from django.conf.urls import url
from ..views import GraphQLView from ..views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r'^graphql/batch', GraphQLView.as_view(batch=True)), url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
url(r'^graphql', GraphQLView.as_view(graphiql=True)), url(r"^graphql", GraphQLView.as_view(graphiql=True)),
] ]

View File

@ -3,12 +3,11 @@ from django.conf.urls import url
from ..views import GraphQLView from ..views import GraphQLView
from .schema_view import schema from .schema_view import schema
class CustomGraphQLView(GraphQLView): class CustomGraphQLView(GraphQLView):
schema = schema schema = schema
graphiql = True graphiql = True
pretty = True pretty = True
urlpatterns = [ urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]
url(r'^graphql/inherited/$', CustomGraphQLView.as_view()),
]

View File

@ -3,6 +3,4 @@ from django.conf.urls import url
from ..views import GraphQLView from ..views import GraphQLView
from .schema_view import schema from .schema_view import schema
urlpatterns = [ urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]
url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)),
]

View File

@ -8,8 +8,7 @@ from graphene.types.utils import yank_fields_from_attrs
from .converter import convert_django_field_with_choices from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry from .registry import Registry, get_global_registry
from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields, from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model
is_valid_django_model)
def construct_fields(model, registry, only_fields, exclude_fields): def construct_fields(model, registry, only_fields, exclude_fields):
@ -21,7 +20,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
# is_already_created = name in options.fields # is_already_created = name in options.fields
is_excluded = name in exclude_fields # or is_already_created is_excluded = name in exclude_fields # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name # https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith('+') is_no_backref = str(name).endswith("+")
if is_not_in_only or is_excluded or is_no_backref: if is_not_in_only or is_excluded or is_no_backref:
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields. # in there. Or when we exclude this field in exclude_fields.
@ -43,9 +42,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions):
class DjangoObjectType(ObjectType): class DjangoObjectType(ObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, def __init_subclass_with_meta__(
only_fields=(), exclude_fields=(), filter_fields=None, connection=None, cls,
connection_class=None, use_connection=None, interfaces=(), _meta=None, **options): model=None,
registry=None,
skip_registry=False,
only_fields=(),
exclude_fields=(),
filter_fields=None,
connection=None,
connection_class=None,
use_connection=None,
interfaces=(),
_meta=None,
**options
):
assert is_valid_django_model(model), ( assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".' 'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(cls.__name__, model) ).format(cls.__name__, model)
@ -54,7 +65,7 @@ class DjangoObjectType(ObjectType):
registry = get_global_registry() registry = get_global_registry()
assert isinstance(registry, Registry), ( assert isinstance(registry, Registry), (
'The attribute registry in {} needs to be an instance of ' "The attribute registry in {} needs to be an instance of "
'Registry, received "{}".' 'Registry, received "{}".'
).format(cls.__name__, registry) ).format(cls.__name__, registry)
@ -62,12 +73,13 @@ class DjangoObjectType(ObjectType):
raise Exception("Can only set filter_fields if Django-Filter is installed") raise Exception("Can only set filter_fields if Django-Filter is installed")
django_fields = yank_fields_from_attrs( django_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields), construct_fields(model, registry, only_fields, exclude_fields), _as=Field
_as=Field,
) )
if use_connection is None and interfaces: if use_connection is None and interfaces:
use_connection = any((issubclass(interface, Node) for interface in interfaces)) use_connection = any(
(issubclass(interface, Node) for interface in interfaces)
)
if use_connection and not connection: if use_connection and not connection:
# We create the connection automatically # We create the connection automatically
@ -75,7 +87,8 @@ class DjangoObjectType(ObjectType):
connection_class = Connection connection_class = Connection
connection = connection_class.create_type( connection = connection_class.create_type(
'{}Connection'.format(cls.__name__), node=cls) "{}Connection".format(cls.__name__), node=cls
)
if connection is not None: if connection is not None:
assert issubclass(connection, Connection), ( assert issubclass(connection, Connection), (
@ -91,7 +104,9 @@ class DjangoObjectType(ObjectType):
_meta.fields = django_fields _meta.fields = django_fields
_meta.connection = connection _meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options) super(DjangoObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options
)
if not skip_registry: if not skip_registry:
registry.register(cls) registry.register(cls)
@ -107,9 +122,7 @@ class DjangoObjectType(ObjectType):
if isinstance(root, cls): if isinstance(root, cls):
return True return True
if not is_valid_django_model(type(root)): if not is_valid_django_model(type(root)):
raise Exception(( raise Exception(('Received incompatible instance "{}".').format(root))
'Received incompatible instance "{}".'
).format(root))
model = root._meta.model._meta.concrete_model model = root._meta.model._meta.concrete_model
return model == cls._meta.model return model == cls._meta.model

View File

@ -13,6 +13,7 @@ class LazyList(object):
try: try:
import django_filters # noqa import django_filters # noqa
DJANGO_FILTER_INSTALLED = True DJANGO_FILTER_INSTALLED = True
except ImportError: except ImportError:
DJANGO_FILTER_INSTALLED = False DJANGO_FILTER_INSTALLED = False
@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names):
continue continue
# Django =>1.9 uses 'rel', django <1.9 uses 'related' # Django =>1.9 uses 'rel', django <1.9 uses 'related'
related = getattr(attr, 'rel', None) or \ related = getattr(attr, "rel", None) or getattr(attr, "related", None)
getattr(attr, 'related', None)
if isinstance(related, models.ManyToOneRel): if isinstance(related, models.ManyToOneRel):
yield (name, related) yield (name, related)
elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
@ -42,9 +42,9 @@ def maybe_queryset(value):
def get_model_fields(model): def get_model_fields(model):
local_fields = [ local_fields = [
(field.name, field) (field.name, field)
for field for field in sorted(
in sorted(list(model._meta.fields) + list(model._meta.fields) + list(model._meta.local_many_to_many)
list(model._meta.local_many_to_many)) )
] ]
# Make sure we don't duplicate local fields with "reverse" version # Make sure we don't duplicate local fields with "reverse" version

View File

@ -20,7 +20,6 @@ from .settings import graphene_settings
class HttpError(Exception): class HttpError(Exception):
def __init__(self, response, message=None, *args, **kwargs): def __init__(self, response, message=None, *args, **kwargs):
self.response = response self.response = response
self.message = message = message or response.content.decode() self.message = message = message or response.content.decode()
@ -29,18 +28,18 @@ class HttpError(Exception):
def get_accepted_content_types(request): def get_accepted_content_types(request):
def qualify(x): def qualify(x):
parts = x.split(';', 1) parts = x.split(";", 1)
if len(parts) == 2: if len(parts) == 2:
match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)', match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
parts[1])
if match: if match:
return parts[0].strip(), float(match.group(2)) return parts[0].strip(), float(match.group(2))
return parts[0].strip(), 1 return parts[0].strip(), 1
raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',') raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
qualified_content_types = map(qualify, raw_content_types) qualified_content_types = map(qualify, raw_content_types)
return list(x[0] for x in sorted(qualified_content_types, return list(
key=lambda x: x[1], reverse=True)) x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
)
def instantiate_middleware(middlewares): def instantiate_middleware(middlewares):
@ -52,8 +51,8 @@ def instantiate_middleware(middlewares):
class GraphQLView(View): class GraphQLView(View):
graphiql_version = '0.11.10' graphiql_version = "0.11.10"
graphiql_template = 'graphene/graphiql.html' graphiql_template = "graphene/graphiql.html"
schema = None schema = None
graphiql = False graphiql = False
@ -64,8 +63,17 @@ class GraphQLView(View):
pretty = False pretty = False
batch = False batch = False
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False, def __init__(
batch=False, backend=None): self,
schema=None,
executor=None,
middleware=None,
root_value=None,
graphiql=False,
pretty=False,
batch=False,
backend=None,
):
if not schema: if not schema:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
@ -86,9 +94,9 @@ class GraphQLView(View):
self.backend = backend self.backend = backend
assert isinstance( assert isinstance(
self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' self.schema, GraphQLSchema
assert not all((graphiql, batch) ), "A Schema is required to be provided to GraphQLView."
), 'Use either graphiql or batch processing' assert not all((graphiql, batch)), "Use either graphiql or batch processing"
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def get_root_value(self, request): def get_root_value(self, request):
@ -106,59 +114,59 @@ class GraphQLView(View):
@method_decorator(ensure_csrf_cookie) @method_decorator(ensure_csrf_cookie)
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
try: try:
if request.method.lower() not in ('get', 'post'): if request.method.lower() not in ("get", "post"):
raise HttpError(HttpResponseNotAllowed( raise HttpError(
['GET', 'POST'], 'GraphQL only supports GET and POST requests.')) HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request) data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql( show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
request, data)
if self.batch: if self.batch:
responses = [self.get_response(request, entry) for entry in data] responses = [self.get_response(request, entry) for entry in data]
result = '[{}]'.format(','.join([response[0] for response in responses])) result = "[{}]".format(
status_code = responses and max(responses, key=lambda response: response[1])[1] or 200 ",".join([response[0] for response in responses])
)
status_code = (
responses
and max(responses, key=lambda response: response[1])[1]
or 200
)
else: else:
result, status_code = self.get_response( result, status_code = self.get_response(request, data, show_graphiql)
request, data, show_graphiql)
if show_graphiql: if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params( query, variables, operation_name, id = self.get_graphql_params(
request, data) request, data
)
return self.render_graphiql( return self.render_graphiql(
request, request,
graphiql_version=self.graphiql_version, graphiql_version=self.graphiql_version,
query=query or '', query=query or "",
variables=json.dumps(variables) or '', variables=json.dumps(variables) or "",
operation_name=operation_name or '', operation_name=operation_name or "",
result=result or '' result=result or "",
) )
return HttpResponse( return HttpResponse(
status=status_code, status=status_code, content=result, content_type="application/json"
content=result,
content_type='application/json'
) )
except HttpError as e: except HttpError as e:
response = e.response response = e.response
response['Content-Type'] = 'application/json' response["Content-Type"] = "application/json"
response.content = self.json_encode(request, { response.content = self.json_encode(
'errors': [self.format_error(e)] request, {"errors": [self.format_error(e)]}
}) )
return response return response
def get_response(self, request, data, show_graphiql=False): def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params( query, variables, operation_name, id = self.get_graphql_params(request, data)
request, data)
execution_result = self.execute_graphql_request( execution_result = self.execute_graphql_request(
request, request, data, query, variables, operation_name, show_graphiql
data,
query,
variables,
operation_name,
show_graphiql
) )
status_code = 200 status_code = 200
@ -166,17 +174,18 @@ class GraphQLView(View):
response = {} response = {}
if execution_result.errors: if execution_result.errors:
response['errors'] = [self.format_error( response["errors"] = [
e) for e in execution_result.errors] self.format_error(e) for e in execution_result.errors
]
if execution_result.invalid: if execution_result.invalid:
status_code = 400 status_code = 400
else: else:
response['data'] = execution_result.data response["data"] = execution_result.data
if self.batch: if self.batch:
response['id'] = id response["id"] = id
response['status'] = status_code response["status"] = status_code
result = self.json_encode(request, response, pretty=show_graphiql) result = self.json_encode(request, response, pretty=show_graphiql)
else: else:
@ -188,22 +197,21 @@ class GraphQLView(View):
return render(request, self.graphiql_template, data) return render(request, self.graphiql_template, data)
def json_encode(self, request, d, pretty=False): def json_encode(self, request, d, pretty=False):
if not (self.pretty or pretty) and not request.GET.get('pretty'): if not (self.pretty or pretty) and not request.GET.get("pretty"):
return json.dumps(d, separators=(',', ':')) return json.dumps(d, separators=(",", ":"))
return json.dumps(d, sort_keys=True, return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
indent=2, separators=(',', ': '))
def parse_body(self, request): def parse_body(self, request):
content_type = self.get_content_type(request) content_type = self.get_content_type(request)
if content_type == 'application/graphql': if content_type == "application/graphql":
return {'query': request.body.decode()} return {"query": request.body.decode()}
elif content_type == 'application/json': elif content_type == "application/json":
# noinspection PyBroadException # noinspection PyBroadException
try: try:
body = request.body.decode('utf-8') body = request.body.decode("utf-8")
except Exception as e: except Exception as e:
raise HttpError(HttpResponseBadRequest(str(e))) raise HttpError(HttpResponseBadRequest(str(e)))
@ -211,33 +219,36 @@ class GraphQLView(View):
request_json = json.loads(body) request_json = json.loads(body)
if self.batch: if self.batch:
assert isinstance(request_json, list), ( assert isinstance(request_json, list), (
'Batch requests should receive a list, but received {}.' "Batch requests should receive a list, but received {}."
).format(repr(request_json)) ).format(repr(request_json))
assert len(request_json) > 0, ( assert (
'Received an empty list in the batch request.' len(request_json) > 0
) ), "Received an empty list in the batch request."
else: else:
assert isinstance(request_json, dict), ( assert isinstance(
'The received data is not a valid JSON query.' request_json, dict
) ), "The received data is not a valid JSON query."
return request_json return request_json
except AssertionError as e: except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e))) raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError): except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest( raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
'POST body sent invalid JSON.'))
elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']: elif content_type in [
"application/x-www-form-urlencoded",
"multipart/form-data",
]:
return request.POST return request.POST
return {} return {}
def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False): def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False
):
if not query: if not query:
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(HttpResponseBadRequest( raise HttpError(HttpResponseBadRequest("Must provide query string."))
'Must provide query string.'))
try: try:
backend = self.get_backend(request) backend = self.get_backend(request)
@ -245,23 +256,27 @@ class GraphQLView(View):
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e], invalid=True) return ExecutionResult(errors=[e], invalid=True)
if request.method.lower() == 'get': if request.method.lower() == "get":
operation_type = document.get_operation_type(operation_name) operation_type = document.get_operation_type(operation_name)
if operation_type and operation_type != 'query': if operation_type and operation_type != "query":
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(HttpResponseNotAllowed( raise HttpError(
['POST'], 'Can only perform a {} operation from a POST request.'.format( HttpResponseNotAllowed(
operation_type) ["POST"],
)) "Can only perform a {} operation from a POST request.".format(
operation_type
),
)
)
try: try:
extra_options = {} extra_options = {}
if self.executor: if self.executor:
# We only include it optionally since # We only include it optionally since
# executor is not a valid argument in all backends # executor is not a valid argument in all backends
extra_options['executor'] = self.executor extra_options["executor"] = self.executor
return document.execute( return document.execute(
root=self.get_root_value(request), root=self.get_root_value(request),
@ -276,7 +291,7 @@ class GraphQLView(View):
@classmethod @classmethod
def can_display_graphiql(cls, request, data): def can_display_graphiql(cls, request, data):
raw = 'raw' in request.GET or 'raw' in data raw = "raw" in request.GET or "raw" in data
return not raw and cls.request_wants_html(request) return not raw and cls.request_wants_html(request)
@classmethod @classmethod
@ -285,26 +300,32 @@ class GraphQLView(View):
accepted_length = len(accepted) accepted_length = len(accepted)
# the list will be ordered in preferred first - so we have to make # the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number # sure the most preferred gets the highest number
html_priority = accepted_length - accepted.index('text/html') if 'text/html' in accepted else 0 html_priority = (
json_priority = accepted_length - accepted.index('application/json') if 'application/json' in accepted else 0 accepted_length - accepted.index("text/html")
if "text/html" in accepted
else 0
)
json_priority = (
accepted_length - accepted.index("application/json")
if "application/json" in accepted
else 0
)
return html_priority > json_priority return html_priority > json_priority
@staticmethod @staticmethod
def get_graphql_params(request, data): def get_graphql_params(request, data):
query = request.GET.get('query') or data.get('query') query = request.GET.get("query") or data.get("query")
variables = request.GET.get('variables') or data.get('variables') variables = request.GET.get("variables") or data.get("variables")
id = request.GET.get('id') or data.get('id') id = request.GET.get("id") or data.get("id")
if variables and isinstance(variables, six.text_type): if variables and isinstance(variables, six.text_type):
try: try:
variables = json.loads(variables) variables = json.loads(variables)
except Exception: except Exception:
raise HttpError(HttpResponseBadRequest( raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
'Variables are invalid JSON.'))
operation_name = request.GET.get( operation_name = request.GET.get("operationName") or data.get("operationName")
'operationName') or data.get('operationName')
if operation_name == "null": if operation_name == "null":
operation_name = None operation_name = None
@ -315,11 +336,10 @@ class GraphQLView(View):
if isinstance(error, GraphQLError): if isinstance(error, GraphQLError):
return format_graphql_error(error) return format_graphql_error(error)
return {'message': six.text_type(error)} return {"message": six.text_type(error)}
@staticmethod @staticmethod
def get_content_type(request): def get_content_type(request):
meta = request.META meta = request.META
content_type = meta.get( content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) return content_type.split(";", 1)[0].lower()
return content_type.split(';', 1)[0].lower()