diff --git a/graphene_django/__init__.py b/graphene_django/__init__.py index e7fca4d..7b80f49 100644 --- a/graphene_django/__init__.py +++ b/graphene_django/__init__.py @@ -1,14 +1,6 @@ -from .types import ( - DjangoObjectType, -) -from .fields import ( - DjangoConnectionField, -) +from .types import DjangoObjectType +from .fields import DjangoConnectionField -__version__ = '2.1rc1' +__version__ = "2.1rc1" -__all__ = [ - '__version__', - 'DjangoObjectType', - 'DjangoConnectionField' -] +__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"] diff --git a/graphene_django/compat.py b/graphene_django/compat.py index 0269e33..f43db04 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -7,7 +7,7 @@ try: # and we cannot have psycopg2 on PyPy from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField except ImportError: - ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4 + ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4 try: diff --git a/graphene_django/converter.py b/graphene_django/converter.py index da73ede..c40313d 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,8 +1,22 @@ from django.db import models from django.utils.encoding import force_text -from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - NonNull, String, UUID, DateTime, Date, Time) +from graphene import ( + ID, + Boolean, + Dynamic, + Enum, + Field, + Float, + Int, + List, + NonNull, + String, + UUID, + DateTime, + Date, + Time, +) from graphene.types.json import JSONString from graphene.utils.str_converters import to_camel_case, to_const from graphql import assert_valid_name @@ -32,7 +46,7 @@ def get_choices(choices): else: name = convert_choice_name(value) while name in converted_names: - name += '_' + str(len(converted_names)) + name += "_" + str(len(converted_names)) converted_names.append(name) description = help_text yield name, value, description @@ -43,16 +57,15 @@ def convert_django_field_with_choices(field, registry=None): converted = registry.get_converted_field(field) if converted: return converted - choices = getattr(field, 'choices', None) + choices = getattr(field, "choices", None) if choices: meta = field.model._meta - name = to_camel_case('{}_{}'.format(meta.object_name, field.name)) + name = to_camel_case("{}_{}".format(meta.object_name, field.name)) choices = list(get_choices(choices)) named_choices = [(c[0], c[1]) for c in choices] named_choices_descriptions = {c[0]: c[2] for c in choices} class EnumWithDescriptionsType(object): - @property def description(self): return named_choices_descriptions[self.name] @@ -69,8 +82,8 @@ def convert_django_field_with_choices(field, registry=None): @singledispatch def convert_django_field(field, registry=None): raise Exception( - "Don't know how to convert the Django field %s (%s)" % - (field, field.__class__)) + "Don't know how to convert the Django field %s (%s)" % (field, field.__class__) + ) @convert_django_field.register(models.CharField) @@ -147,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None): # We do this for a bug in Django 1.8, where null attr # is not available in the OneToOneRel instance - null = getattr(field, 'null', True) + null = getattr(field, "null", True) return Field(_type, required=not null) return Dynamic(dynamic_type) @@ -171,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None): # defined filter_fields in the DjangoObjectType Meta if _type._meta.filter_fields: from .filter.fields import DjangoFilterConnectionField + return DjangoFilterConnectionField(_type) return DjangoConnectionField(_type) diff --git a/graphene_django/debug/__init__.py b/graphene_django/debug/__init__.py index cd5015e..3e078da 100644 --- a/graphene_django/debug/__init__.py +++ b/graphene_django/debug/__init__.py @@ -1,4 +1,4 @@ from .middleware import DjangoDebugMiddleware from .types import DjangoDebug -__all__ = ['DjangoDebugMiddleware', 'DjangoDebug'] +__all__ = ["DjangoDebugMiddleware", "DjangoDebug"] diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index 2b11f7e..48d471f 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -7,7 +7,6 @@ from .types import DjangoDebug class DjangoDebugContext(object): - def __init__(self): self.debug_promise = None self.promises = [] @@ -38,20 +37,21 @@ class DjangoDebugContext(object): class DjangoDebugMiddleware(object): - def resolve(self, next, root, info, **args): context = info.context - django_debug = getattr(context, 'django_debug', None) + django_debug = getattr(context, "django_debug", None) if not django_debug: if context is None: - raise Exception('DjangoDebug cannot be executed in None contexts') + raise Exception("DjangoDebug cannot be executed in None contexts") try: context.django_debug = DjangoDebugContext() except Exception: - raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format( - context.__class__.__name__ - )) - if info.schema.get_type('DjangoDebug') == info.return_type: + raise Exception( + "DjangoDebug need the context to be writable, context received: {}.".format( + context.__class__.__name__ + ) + ) + if info.schema.get_type("DjangoDebug") == info.return_type: return context.django_debug.get_debug_promise() promise = next(root, info, **args) context.django_debug.add_promise(promise) diff --git a/graphene_django/debug/sql/tracking.py b/graphene_django/debug/sql/tracking.py index 9d14e4b..f96583b 100644 --- a/graphene_django/debug/sql/tracking.py +++ b/graphene_django/debug/sql/tracking.py @@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception): class ThreadLocalState(local): - def __init__(self): self.enabled = True @@ -35,7 +34,7 @@ recording = state.recording # export function def wrap_cursor(connection, panel): - if not hasattr(connection, '_graphene_cursor'): + if not hasattr(connection, "_graphene_cursor"): connection._graphene_cursor = connection.cursor def cursor(): @@ -46,7 +45,7 @@ def wrap_cursor(connection, panel): def unwrap_cursor(connection): - if hasattr(connection, '_graphene_cursor'): + if hasattr(connection, "_graphene_cursor"): previous_cursor = connection._graphene_cursor connection.cursor = previous_cursor del connection._graphene_cursor @@ -87,15 +86,14 @@ class NormalCursorWrapper(object): if not params: return params if isinstance(params, dict): - return dict((key, self._quote_expr(value)) - for key, value in params.items()) + return dict((key, self._quote_expr(value)) for key, value in params.items()) return list(map(self._quote_expr, params)) def _decode(self, param): try: return force_text(param, strings_only=True) except UnicodeDecodeError: - return '(encoded string)' + return "(encoded string)" def _record(self, method, sql, params): start_time = time() @@ -103,45 +101,48 @@ class NormalCursorWrapper(object): return method(sql, params) finally: stop_time = time() - duration = (stop_time - start_time) - _params = '' + duration = stop_time - start_time + _params = "" try: _params = json.dumps(list(map(self._decode, params))) except Exception: pass # object not JSON serializable - alias = getattr(self.db, 'alias', 'default') + alias = getattr(self.db, "alias", "default") conn = self.db.connection - vendor = getattr(conn, 'vendor', 'unknown') + vendor = getattr(conn, "vendor", "unknown") params = { - 'vendor': vendor, - 'alias': alias, - 'sql': self.db.ops.last_executed_query( - self.cursor, sql, self._quote_params(params)), - 'duration': duration, - 'raw_sql': sql, - 'params': _params, - 'start_time': start_time, - 'stop_time': stop_time, - 'is_slow': duration > 10, - 'is_select': sql.lower().strip().startswith('select'), + "vendor": vendor, + "alias": alias, + "sql": self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params) + ), + "duration": duration, + "raw_sql": sql, + "params": _params, + "start_time": start_time, + "stop_time": stop_time, + "is_slow": duration > 10, + "is_select": sql.lower().strip().startswith("select"), } - if vendor == 'postgresql': + if vendor == "postgresql": # If an erroneous query was ran on the connection, it might # be in a state where checking isolation_level raises an # exception. try: iso_level = conn.isolation_level except conn.InternalError: - iso_level = 'unknown' - params.update({ - 'trans_id': self.logger.get_transaction_id(alias), - 'trans_status': conn.get_transaction_status(), - 'iso_level': iso_level, - 'encoding': conn.encoding, - }) + iso_level = "unknown" + params.update( + { + "trans_id": self.logger.get_transaction_id(alias), + "trans_status": conn.get_transaction_status(), + "iso_level": iso_level, + "encoding": conn.encoding, + } + ) _sql = DjangoDebugSQL(**params) # We keep `sql` to maintain backwards compatibility diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py index 72747b2..f2ef096 100644 --- a/graphene_django/debug/tests/test_query.py +++ b/graphene_django/debug/tests/test_query.py @@ -12,31 +12,31 @@ from ..types import DjangoDebug class context(object): pass + # from examples.starwars_django.models import Character pytestmark = pytest.mark.django_db def test_should_query_field(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_reporter(self, info, **args): return Reporter.objects.first() - query = ''' + query = """ query ReporterQuery { reporter { lastName @@ -47,43 +47,40 @@ def test_should_query_field(): } } } - ''' + """ expected = { - 'reporter': { - 'lastName': 'ABA', + "reporter": {"lastName": "ABA"}, + "__debug": { + "sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}] }, - '__debug': { - 'sql': [{ - 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) - }] - } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors assert result.data == expected def test_should_query_list(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = graphene.List(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters { lastName @@ -94,45 +91,38 @@ def test_should_query_list(): } } } - ''' + """ expected = { - 'allReporters': [{ - 'lastName': 'ABA', - }, { - 'lastName': 'Griffin', - }], - '__debug': { - 'sql': [{ - 'rawSql': str(Reporter.objects.all().query) - }] - } + "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}], + "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors assert result.data == expected def test_should_query_connection(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters(first:1) { edges { @@ -147,48 +137,41 @@ def test_should_query_connection(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'lastName': 'ABA', - } - }] - }, - } + """ + expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}} schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors - assert result.data['allReporters'] == expected['allReporters'] - assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] + assert result.data["allReporters"] == expected["allReporters"] + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) - assert result.data['__debug']['sql'][1]['rawSql'] == query + assert result.data["__debug"]["sql"][1]["rawSql"] == query def test_should_query_connectionfilter(): from ...filter import DjangoFilterConnectionField - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): - all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name']) + all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"]) s = graphene.String(resolver=lambda *_: "S") - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters(first:1) { edges { @@ -203,20 +186,14 @@ def test_should_query_connectionfilter(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'lastName': 'ABA', - } - }] - }, - } + """ + expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}} schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors - assert result.data['allReporters'] == expected['allReporters'] - assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] + assert result.data["allReporters"] == expected["allReporters"] + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) - assert result.data['__debug']['sql'][1]['rawSql'] == query + assert result.data["__debug"]["sql"][1]["rawSql"] == query diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 5576454..1ecce45 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -13,7 +13,6 @@ from .utils import maybe_queryset class DjangoListField(Field): - def __init__(self, _type, *args, **kwargs): super(DjangoListField, self).__init__(List(_type), *args, **kwargs) @@ -30,25 +29,28 @@ class DjangoListField(Field): class DjangoConnectionField(ConnectionField): - def __init__(self, *args, **kwargs): - self.on = kwargs.pop('on', False) + self.on = kwargs.pop("on", False) self.max_limit = kwargs.pop( - 'max_limit', - graphene_settings.RELAY_CONNECTION_MAX_LIMIT + "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT ) self.enforce_first_or_last = kwargs.pop( - 'enforce_first_or_last', - graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST + "enforce_first_or_last", + graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST, ) super(DjangoConnectionField, self).__init__(*args, **kwargs) @property def type(self): from .types import DjangoObjectType + _type = super(ConnectionField, self).type - assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types" - assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) + assert issubclass( + _type, DjangoObjectType + ), "DjangoConnectionField only accepts DjangoObjectType types" + assert _type._meta.connection, "The type {} doesn't have a connection".format( + _type.__name__ + ) return _type._meta.connection @property @@ -100,28 +102,37 @@ class DjangoConnectionField(ConnectionField): return connection @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, info, **args): - first = args.get('first') - last = args.get('last') + def connection_resolver( + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + root, + info, + **args + ): + first = args.get("first") + last = args.get("last") if enforce_first_or_last: assert first or last, ( - 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + "You must provide a `first` or `last` value to properly paginate the `{}` connection." ).format(info.field_name) if max_limit: if first: assert first <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records." ).format(first, info.field_name, max_limit) - args['first'] = min(first, max_limit) + args["first"] = min(first, max_limit) if last: assert last <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' + "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records." ).format(last, info.field_name, max_limit) - args['last'] = min(last, max_limit) + args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args) @@ -138,5 +149,5 @@ class DjangoConnectionField(ConnectionField): self.type, self.get_manager(), self.max_limit, - self.enforce_first_or_last + self.enforce_first_or_last, ) diff --git a/graphene_django/filter/__init__.py b/graphene_django/filter/__init__.py index 24fae60..daafe56 100644 --- a/graphene_django/filter/__init__.py +++ b/graphene_django/filter/__init__.py @@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED if not DJANGO_FILTER_INSTALLED: warnings.warn( "Use of django filtering requires the django-filter package " - "be installed. You can do so using `pip install django-filter`", ImportWarning + "be installed. You can do so using `pip install django-filter`", + ImportWarning, ) else: from .fields import DjangoFilterConnectionField from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter - __all__ = ['DjangoFilterConnectionField', - 'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter'] + __all__ = [ + "DjangoFilterConnectionField", + "GlobalIDFilter", + "GlobalIDMultipleChoiceFilter", + ] diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 06e81c2..cb42543 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class class DjangoFilterConnectionField(DjangoConnectionField): - - def __init__(self, type, fields=None, order_by=None, - extra_filter_meta=None, filterset_class=None, - *args, **kwargs): + def __init__( + self, + type, + fields=None, + order_by=None, + extra_filter_meta=None, + filterset_class=None, + *args, + **kwargs + ): self._fields = fields self._provided_filterset_class = filterset_class self._filterset_class = None @@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField): def filterset_class(self): if not self._filterset_class: fields = self._fields or self.node_type._meta.filter_fields - meta = dict(model=self.model, - fields=fields) + meta = dict(model=self.model, fields=fields) if self._extra_filter_meta: meta.update(self._extra_filter_meta) - self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta) + self._filterset_class = get_filterset_class( + self._provided_filterset_class, **meta + ) return self._filterset_class @@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField): # See related PR: https://github.com/graphql-python/graphene-django/pull/126 - assert not (default_queryset.query.low_mark and queryset.query.low_mark), ( - 'Received two sliced querysets (low mark) in the connection, please slice only in one.' - ) - assert not (default_queryset.query.high_mark and queryset.query.high_mark), ( - 'Received two sliced querysets (high mark) in the connection, please slice only in one.' - ) + assert not ( + default_queryset.query.low_mark and queryset.query.low_mark + ), "Received two sliced querysets (low mark) in the connection, please slice only in one." + assert not ( + default_queryset.query.high_mark and queryset.query.high_mark + ), "Received two sliced querysets (high mark) in the connection, please slice only in one." low = default_queryset.query.low_mark or queryset.query.low_mark high = default_queryset.query.high_mark or queryset.query.high_mark default_queryset.query.clear_limits() - queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset) + queryset = super(DjangoFilterConnectionField, cls).merge_querysets( + default_queryset, queryset + ) queryset.query.set_limits(low, high) return queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, filterset_class, filtering_args, - root, info, **args): + def connection_resolver( + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + filterset_class, + filtering_args, + root, + info, + **args + ): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( data=filter_kwargs, queryset=default_manager.get_queryset(), - request=info.context + request=info.context, ).qs return super(DjangoFilterConnectionField, cls).connection_resolver( @@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField): self.max_limit, self.enforce_first_or_last, self.filterset_class, - self.filtering_args + self.filtering_args, ) diff --git a/graphene_django/filter/filterset.py b/graphene_django/filter/filterset.py index 0e33fe6..29a275d 100644 --- a/graphene_django/filter/filterset.py +++ b/graphene_django/filter/filterset.py @@ -28,26 +28,19 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): GRAPHENE_FILTER_SET_OVERRIDES = { - models.AutoField: { - 'filter_class': GlobalIDFilter, - }, - models.OneToOneField: { - 'filter_class': GlobalIDFilter, - }, - models.ForeignKey: { - 'filter_class': GlobalIDFilter, - }, - models.ManyToManyField: { - 'filter_class': GlobalIDMultipleChoiceFilter, - } + models.AutoField: {"filter_class": GlobalIDFilter}, + models.OneToOneField: {"filter_class": GlobalIDFilter}, + models.ForeignKey: {"filter_class": GlobalIDFilter}, + models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter}, } class GrapheneFilterSetMixin(BaseFilterSet): - FILTER_DEFAULTS = dict(itertools.chain( - FILTER_FOR_DBFIELD_DEFAULTS.items(), - GRAPHENE_FILTER_SET_OVERRIDES.items() - )) + FILTER_DEFAULTS = dict( + itertools.chain( + FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items() + ) + ) @classmethod def filter_for_reverse_field(cls, f, name): @@ -62,10 +55,7 @@ class GrapheneFilterSetMixin(BaseFilterSet): except AttributeError: rel = f.field.rel - default = { - 'name': name, - 'label': capfirst(rel.related_name) - } + default = {"name": name, "label": capfirst(rel.related_name)} if rel.multiple: # For to-many relationships return GlobalIDMultipleChoiceFilter(**default) @@ -78,25 +68,20 @@ def setup_filterset(filterset_class): """ Wrap a provided filterset in Graphene-specific functionality """ return type( - 'Graphene{}'.format(filterset_class.__name__), + "Graphene{}".format(filterset_class.__name__), (filterset_class, GrapheneFilterSetMixin), {}, ) -def custom_filterset_factory(model, filterset_base_class=FilterSet, - **meta): +def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta): """ Create a filterset for the given model using the provided meta data """ - meta.update({ - 'model': model, - }) - meta_class = type(str('Meta'), (object,), meta) + meta.update({"model": model}) + meta_class = type(str("Meta"), (object,), meta) filterset = type( - str('%sFilterSet' % model._meta.object_name), + str("%sFilterSet" % model._meta.object_name), (filterset_base_class, GrapheneFilterSetMixin), - { - 'Meta': meta_class - } + {"Meta": meta_class}, ) return filterset diff --git a/graphene_django/filter/tests/filters.py b/graphene_django/filter/tests/filters.py index 4a3fbaa..359d2ba 100644 --- a/graphene_django/filter/tests/filters.py +++ b/graphene_django/filter/tests/filters.py @@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter class ArticleFilter(django_filters.FilterSet): - class Meta: model = Article fields = { - 'headline': ['exact', 'icontains'], - 'pub_date': ['gt', 'lt', 'exact'], - 'reporter': ['exact'], + "headline": ["exact", "icontains"], + "pub_date": ["gt", "lt", "exact"], + "reporter": ["exact"], } - order_by = OrderingFilter(fields=('pub_date',)) + order_by = OrderingFilter(fields=("pub_date",)) class ReporterFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['first_name', 'last_name', 'email', 'pets'] + fields = ["first_name", "last_name", "email", "pets"] - order_by = OrderingFilter(fields=('pub_date',)) + order_by = OrderingFilter(fields=("pub_date",)) class PetFilter(django_filters.FilterSet): - class Meta: model = Pet - fields = ['name'] + fields = ["name"] diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index c730ef3..f9ef0ae 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -5,8 +5,7 @@ import pytest from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.forms import (GlobalIDFormField, - GlobalIDMultipleChoiceField) +from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.utils import DJANGO_FILTER_INSTALLED @@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED: import django_filters from django_filters import FilterSet, NumberFilter - from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField, - GlobalIDMultipleChoiceFilter) - from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter + from graphene_django.filter import ( + GlobalIDFilter, + DjangoFilterConnectionField, + GlobalIDMultipleChoiceFilter, + ) + from graphene_django.filter.tests.filters import ( + ArticleFilter, + PetFilter, + ReporterFilter, + ) else: - pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible')) + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) pytestmark.append(pytest.mark.django_db) if DJANGO_FILTER_INSTALLED: - class ArticleNode(DjangoObjectType): + class ArticleNode(DjangoObjectType): class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('headline', ) - + interfaces = (Node,) + filter_fields = ("headline",) class ReporterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - + interfaces = (Node,) class PetNode(DjangoObjectType): - class Meta: model = Pet - interfaces = (Node, ) + interfaces = (Node,) # schema = Schema() @@ -59,58 +65,47 @@ def get_args(field): def assert_arguments(field, *arguments): - ignore = ('after', 'before', 'first', 'last', 'order_by') + ignore = ("after", "before", "first", "last", "order_by") args = get_args(field) - actual = [ - name - for name in args - if name not in ignore and not name.startswith('_') - ] - assert set(arguments) == set(actual), \ - 'Expected arguments ({}) did not match actual ({})'.format( - arguments, - actual - ) + actual = [name for name in args if name not in ignore and not name.startswith("_")] + assert set(arguments) == set( + actual + ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual) def assert_orderable(field): args = get_args(field) - assert 'order_by' in args, \ - 'Field cannot be ordered' + assert "order_by" in args, "Field cannot be ordered" def assert_not_orderable(field): args = get_args(field) - assert 'order_by' not in args, \ - 'Field can be ordered' + assert "order_by" not in args, "Field can be ordered" def test_filter_explicit_filterset_arguments(): field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter) - assert_arguments(field, - 'headline', 'headline__icontains', - 'pub_date', 'pub_date__gt', 'pub_date__lt', - 'reporter', - ) + assert_arguments( + field, + "headline", + "headline__icontains", + "pub_date", + "pub_date__gt", + "pub_date__lt", + "reporter", + ) def test_filter_shortcut_filterset_arguments_list(): - field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter']) - assert_arguments(field, - 'pub_date', - 'reporter', - ) + field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"]) + assert_arguments(field, "pub_date", "reporter") def test_filter_shortcut_filterset_arguments_dict(): - field = DjangoFilterConnectionField(ArticleNode, fields={ - 'headline': ['exact', 'icontains'], - 'reporter': ['exact'], - }) - assert_arguments(field, - 'headline', 'headline__icontains', - 'reporter', - ) + field = DjangoFilterConnectionField( + ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]} + ) + assert_arguments(field, "headline", "headline__icontains", "reporter") def test_filter_explicit_filterset_orderable(): @@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable(): def test_filter_shortcut_filterset_extra_meta(): - field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ - 'exclude': ('headline', ) - }) - assert 'headline' not in field.filterset_class.get_fields() + field = DjangoFilterConnectionField( + ArticleNode, extra_filter_meta={"exclude": ("headline",)} + ) + assert "headline" not in field.filterset_class.get_fields() def test_filter_shortcut_filterset_context(): class ArticleContextFilter(django_filters.FilterSet): - class Meta: model = Article exclude = set() @@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context(): return qs.filter(reporter=self.request.reporter) class Query(ObjectType): - context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter) + context_articles = DjangoFilterConnectionField( + ArticleNode, filterset_class=ArticleContextFilter + ) - r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') - r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') - Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1) - Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2) + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) class context(object): reporter = r2 - query = ''' + query = """ query { contextArticles { edges { @@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query, context_value=context()) assert not result.errors - assert len(result.data['contextArticles']['edges']) == 1 - assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2' + assert len(result.data["contextArticles"]["edges"]) == 1 + assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2" def test_filter_filterset_information_on_meta(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] field = DjangoFilterConnectionField(ReporterFilterNode) - assert_arguments(field, 'first_name', 'articles') + assert_arguments(field, "first_name", "articles") assert_not_orderable(field) def test_filter_filterset_information_on_meta_related(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] class ArticleFilterNode(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ['headline', 'reporter'] + interfaces = (Node,) + filter_fields = ["headline", "reporter"] class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) @@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related(): article = Field(ArticleFilterNode) schema = Schema(query=Query) - articles_field = ReporterFilterNode._meta.fields['articles'].get_type() - assert_arguments(articles_field, 'headline', 'reporter') + articles_field = ReporterFilterNode._meta.fields["articles"].get_type() + assert_arguments(articles_field, "headline", "reporter") assert_not_orderable(articles_field) def test_filter_filterset_related_results(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] class ArticleFilterNode(DjangoObjectType): - class Meta: - interfaces = (Node, ) + interfaces = (Node,) model = Article - filter_fields = ['headline', 'reporter'] + filter_fields = ["headline", "reporter"] class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) @@ -243,12 +246,22 @@ def test_filter_filterset_related_results(): reporter = Field(ReporterFilterNode) article = Field(ArticleFilterNode) - r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') - r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') - Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1) - Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2) + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + ) - query = ''' + query = """ query { allReporters { edges { @@ -264,123 +277,134 @@ def test_filter_filterset_related_results(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query) assert not result.errors # We should only get back a single article for each reporter - assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1 - assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1 + assert ( + len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1 + ) + assert ( + len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1 + ) def test_global_id_field_implicit(): - field = DjangoFilterConnectionField(ArticleNode, fields=['id']) + field = DjangoFilterConnectionField(ArticleNode, fields=["id"]) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['id'] + id_filter = filterset_class.base_filters["id"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_global_id_field_explicit(): class ArticleIdFilter(django_filters.FilterSet): - class Meta: model = Article - fields = ['id'] + fields = ["id"] field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['id'] + id_filter = filterset_class.base_filters["id"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_filterset_descriptions(): class ArticleIdFilter(django_filters.FilterSet): - class Meta: model = Article - fields = ['id'] + fields = ["id"] - max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time") + max_time = django_filters.NumberFilter( + method="filter_max_time", label="The maximum time" + ) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) - max_time = field.args['max_time'] + max_time = field.args["max_time"] assert isinstance(max_time, Argument) assert max_time.type == Float - assert max_time.description == 'The maximum time' + assert max_time.description == "The maximum time" def test_global_id_field_relation(): - field = DjangoFilterConnectionField(ArticleNode, fields=['reporter']) + field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"]) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['reporter'] + id_filter = filterset_class.base_filters["reporter"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_global_id_multiple_field_implicit(): - field = DjangoFilterConnectionField(ReporterNode, fields=['pets']) + field = DjangoFilterConnectionField(ReporterNode, fields=["pets"]) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['pets'] + multiple_filter = filterset_class.base_filters["pets"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_explicit(): class ReporterPetsFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['pets'] + fields = ["pets"] - field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + field = DjangoFilterConnectionField( + ReporterNode, filterset_class=ReporterPetsFilter + ) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['pets'] + multiple_filter = filterset_class.base_filters["pets"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_implicit_reverse(): - field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) + field = DjangoFilterConnectionField(ReporterNode, fields=["articles"]) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['articles'] + multiple_filter = filterset_class.base_filters["articles"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_explicit_reverse(): class ReporterPetsFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['articles'] + fields = ["articles"] - field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + field = DjangoFilterConnectionField( + ReporterNode, filterset_class=ReporterPetsFilter + ) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['articles'] + multiple_filter = filterset_class.base_filters["articles"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_filter_filterset_related_results(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = { - 'first_name': ['icontains'] - } + interfaces = (Node,) + filter_fields = {"first_name": ["icontains"]} class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com') - r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com') - r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com') + r1 = Reporter.objects.create( + first_name="A test user", last_name="Last Name", email="test1@test.com" + ) + r2 = Reporter.objects.create( + first_name="Other test user", + last_name="Other Last Name", + email="test2@test.com", + ) + r3 = Reporter.objects.create( + first_name="Random", last_name="RandomLast", email="random@test.com" + ) - query = ''' + query = """ query { allReporters(firstName_Icontains: "test") { edges { @@ -390,12 +414,12 @@ def test_filter_filterset_related_results(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query) assert not result.errors # We should only get two reporters - assert len(result.data['allReporters']['edges']) == 2 + assert len(result.data["allReporters"]["edges"]) == 2 def test_recursive_filter_connection(): @@ -407,79 +431,73 @@ def test_recursive_filter_connection(): class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode + assert ( + ReporterFilterNode._meta.fields["child_reporters"].node_type + == ReporterFilterNode + ) def test_should_query_filter_node_limit(): class ReporterFilter(FilterSet): - limit = NumberFilter(method='filter_limit') + limit = NumberFilter(method="filter_limit") def filter_limit(self, queryset, name, value): return queryset[:value] class Meta: model = Reporter - fields = ['first_name', ] + fields = ["first_name"] class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(ObjectType): all_reporters = DjangoFilterConnectionField( - ReporterType, - filterset_class=ReporterFilter + ReporterType, filterset_class=ReporterFilter ) def resolve_all_reporters(self, info, **args): - return Reporter.objects.order_by('a_choice') + return Reporter.objects.order_by("a_choice") Reporter.objects.create( - first_name='Bob', - last_name='Doe', - email='bobdoe@example.com', - a_choice=2 + first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(limit: 1) { edges { @@ -498,24 +516,23 @@ def test_should_query_filter_node_limit(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjI=', - 'firstName': 'John', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==', - 'lang': 'ES' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjI=", + "firstName": "John", + "articles": { + "edges": [ + {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}} + ] + }, } } - }] + ] } } @@ -526,45 +543,37 @@ def test_should_query_filter_node_limit(): def test_should_query_filter_node_double_limit_raises(): class ReporterFilter(FilterSet): - limit = NumberFilter(method='filter_limit') + limit = NumberFilter(method="filter_limit") def filter_limit(self, queryset, name, value): return queryset[:value] class Meta: model = Reporter - fields = ['first_name', ] + fields = ["first_name"] class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(ObjectType): all_reporters = DjangoFilterConnectionField( - ReporterType, - filterset_class=ReporterFilter + ReporterType, filterset_class=ReporterFilter ) def resolve_all_reporters(self, info, **args): - return Reporter.objects.order_by('a_choice')[:2] + return Reporter.objects.order_by("a_choice")[:2] Reporter.objects.create( - first_name='Bob', - last_name='Doe', - email='bobdoe@example.com', - a_choice=2 + first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(limit: 1) { edges { @@ -575,41 +584,40 @@ def test_should_query_filter_node_double_limit_raises(): } } } - ''' + """ result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Received two sliced querysets (high mark) in the connection, please slice only in one.' + "Received two sliced querysets (high mark) in the connection, please slice only in one." ) + def test_order_by_is_perserved(): class ReporterType(DjangoObjectType): class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) filter_fields = () class Query(ObjectType): - all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean()) + all_reporters = DjangoFilterConnectionField( + ReporterType, reverse_order=Boolean() + ) def resolve_all_reporters(self, info, reverse_order=False, **args): - reporters = Reporter.objects.order_by('first_name') + reporters = Reporter.objects.order_by("first_name") if reverse_order: return reporters.reverse() - + return reporters - Reporter.objects.create( - first_name='b', - ) - r = Reporter.objects.create( - first_name='a', - ) + Reporter.objects.create(first_name="b") + r = Reporter.objects.create(first_name="a") schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(first: 1) { edges { @@ -619,23 +627,14 @@ def test_order_by_is_perserved(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'firstName': 'a', - } - }] - } - } + """ + expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected - - reverse_query = ''' + reverse_query = """ query NodeFilteringQuery { allReporters(first: 1, reverseOrder: true) { edges { @@ -645,33 +644,26 @@ def test_order_by_is_perserved(): } } } - ''' + """ - reverse_expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'firstName': 'b', - } - }] - } - } + reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}} reverse_result = schema.execute(reverse_query) assert not reverse_result.errors assert reverse_result.data == reverse_expected + def test_annotation_is_perserved(): class ReporterType(DjangoObjectType): full_name = String() - + def resolve_full_name(instance, info, **args): return instance.full_name class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) filter_fields = () class Query(ObjectType): @@ -679,17 +671,16 @@ def test_annotation_is_perserved(): def resolve_all_reporters(self, info, **args): return Reporter.objects.annotate( - full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField()) + full_name=Concat( + "first_name", Value(" "), "last_name", output_field=TextField() + ) ) - Reporter.objects.create( - first_name='John', - last_name='Doe', - ) + Reporter.objects.create(first_name="John", last_name="Doe") schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(first: 1) { edges { @@ -699,16 +690,8 @@ def test_annotation_is_perserved(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'fullName': 'John Doe', - } - }] - } - } + """ + expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}} result = schema.execute(query) diff --git a/graphene_django/forms/converter.py b/graphene_django/forms/converter.py index 24cece7..87180b2 100644 --- a/graphene_django/forms/converter.py +++ b/graphene_django/forms/converter.py @@ -14,8 +14,7 @@ singledispatch = import_single_dispatch() def convert_form_field(field): raise ImproperlyConfigured( "Don't know how to convert the Django form field %s (%s) " - "to Graphene type" % - (field, field.__class__) + "to Graphene type" % (field, field.__class__) ) diff --git a/graphene_django/forms/forms.py b/graphene_django/forms/forms.py index a54f0a5..14e68c8 100644 --- a/graphene_django/forms/forms.py +++ b/graphene_django/forms/forms.py @@ -8,9 +8,7 @@ from graphql_relay import from_global_id class GlobalIDFormField(Field): - default_error_messages = { - 'invalid': _('Invalid ID specified.'), - } + default_error_messages = {"invalid": _("Invalid ID specified.")} def clean(self, value): if not value and not self.required: @@ -19,21 +17,21 @@ class GlobalIDFormField(Field): try: _type, _id = from_global_id(value) except (TypeError, ValueError, UnicodeDecodeError, binascii.Error): - raise ValidationError(self.error_messages['invalid']) + raise ValidationError(self.error_messages["invalid"]) try: CharField().clean(_id) CharField().clean(_type) except ValidationError: - raise ValidationError(self.error_messages['invalid']) + raise ValidationError(self.error_messages["invalid"]) return value class GlobalIDMultipleChoiceField(MultipleChoiceField): default_error_messages = { - 'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'), - 'invalid_list': _('Enter a list of values.'), + "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."), + "invalid_list": _("Enter a list of values."), } def valid_value(self, value): diff --git a/graphene_django/forms/mutation.py b/graphene_django/forms/mutation.py index f247546..701bac8 100644 --- a/graphene_django/forms/mutation.py +++ b/graphene_django/forms/mutation.py @@ -5,6 +5,7 @@ import graphene from graphene import Field, InputField from graphene.relay.mutation import ClientIDMutation from graphene.types.mutation import MutationOptions + # from graphene.types.inputobjecttype import ( # InputObjectTypeOptions, # InputObjectType, @@ -21,7 +22,8 @@ def fields_for_form(form, only_fields, exclude_fields): for name, field in form.fields.items(): is_not_in_only = only_fields and name not in only_fields is_excluded = ( - name in exclude_fields # or + name + in exclude_fields # or # name in already_created_fields ) @@ -57,12 +59,12 @@ class BaseDjangoFormMutation(ClientIDMutation): @classmethod def get_form_kwargs(cls, root, info, **input): - kwargs = {'data': input} + kwargs = {"data": input} - pk = input.pop('id', None) + pk = input.pop("id", None) if pk: instance = cls._meta.model._default_manager.get(pk=pk) - kwargs['instance'] = instance + kwargs["instance"] = instance return kwargs @@ -100,11 +102,12 @@ class DjangoFormMutation(BaseDjangoFormMutation): errors = graphene.List(ErrorType) @classmethod - def __init_subclass_with_meta__(cls, form_class=None, - only_fields=(), exclude_fields=(), **options): + def __init_subclass_with_meta__( + cls, form_class=None, only_fields=(), exclude_fields=(), **options + ): if not form_class: - raise Exception('form_class is required for DjangoFormMutation') + raise Exception("form_class is required for DjangoFormMutation") form = form_class() input_fields = fields_for_form(form, only_fields, exclude_fields) @@ -112,16 +115,12 @@ class DjangoFormMutation(BaseDjangoFormMutation): _meta = DjangoFormMutationOptions(cls) _meta.form_class = form_class - _meta.fields = yank_fields_from_attrs( - output_fields, - _as=Field, - ) + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) - input_fields = yank_fields_from_attrs( - input_fields, - _as=InputField, + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(DjangoFormMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options ) - super(DjangoFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) @classmethod def perform_mutate(cls, form, info): @@ -141,21 +140,28 @@ class DjangoModelFormMutation(BaseDjangoFormMutation): errors = graphene.List(ErrorType) @classmethod - def __init_subclass_with_meta__(cls, form_class=None, model=None, return_field_name=None, - only_fields=(), exclude_fields=(), **options): + def __init_subclass_with_meta__( + cls, + form_class=None, + model=None, + return_field_name=None, + only_fields=(), + exclude_fields=(), + **options + ): if not form_class: - raise Exception('form_class is required for DjangoModelFormMutation') + raise Exception("form_class is required for DjangoModelFormMutation") if not model: model = form_class._meta.model if not model: - raise Exception('model is required for DjangoModelFormMutation') + raise Exception("model is required for DjangoModelFormMutation") form = form_class() input_fields = fields_for_form(form, only_fields, exclude_fields) - input_fields['id'] = graphene.ID() + input_fields["id"] = graphene.ID() registry = get_global_registry() model_type = registry.get_type_for_model(model) @@ -171,19 +177,11 @@ class DjangoModelFormMutation(BaseDjangoFormMutation): _meta.form_class = form_class _meta.model = model _meta.return_field_name = return_field_name - _meta.fields = yank_fields_from_attrs( - output_fields, - _as=Field, - ) + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) - input_fields = yank_fields_from_attrs( - input_fields, - _as=InputField, - ) + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( - _meta=_meta, - input_fields=input_fields, - **options + _meta=_meta, input_fields=input_fields, **options ) @classmethod diff --git a/graphene_django/forms/tests/test_converter.py b/graphene_django/forms/tests/test_converter.py index 86b5b80..955b952 100644 --- a/graphene_django/forms/tests/test_converter.py +++ b/graphene_django/forms/tests/test_converter.py @@ -2,24 +2,36 @@ from django import forms from py.test import raises import graphene -from graphene import String, Int, Boolean, Float, ID, UUID, List, NonNull, DateTime, Date, Time +from graphene import ( + String, + Int, + Boolean, + Float, + ID, + UUID, + List, + NonNull, + DateTime, + Date, + Time, +) from ..converter import convert_form_field def assert_conversion(django_field, graphene_field, *args): - field = django_field(*args, help_text='Custom Help Text') + field = django_field(*args, help_text="Custom Help Text") graphene_type = convert_form_field(field) assert isinstance(graphene_type, graphene_field) field = graphene_type.Field() - assert field.description == 'Custom Help Text' + assert field.description == "Custom Help Text" return field def test_should_unknown_django_field_raise_exception(): with raises(Exception) as excinfo: convert_form_field(None) - assert 'Don\'t know how to convert the Django form field' in str(excinfo.value) + assert "Don't know how to convert the Django form field" in str(excinfo.value) def test_should_date_convert_date(): @@ -59,11 +71,11 @@ def test_should_base_field_convert_string(): def test_should_regex_convert_string(): - assert_conversion(forms.RegexField, String, '[0-9]+') + assert_conversion(forms.RegexField, String, "[0-9]+") def test_should_uuid_convert_string(): - if hasattr(forms, 'UUIDField'): + if hasattr(forms, "UUIDField"): assert_conversion(forms.UUIDField, UUID) diff --git a/graphene_django/forms/tests/test_mutation.py b/graphene_django/forms/tests/test_mutation.py index 10a15ae..0404010 100644 --- a/graphene_django/forms/tests/test_mutation.py +++ b/graphene_django/forms/tests/test_mutation.py @@ -11,18 +11,18 @@ class MyForm(forms.Form): class PetForm(forms.ModelForm): - class Meta: model = Pet - fields = ('name',) + fields = ("name",) def test_needs_form_class(): with raises(Exception) as exc: + class MyMutation(DjangoFormMutation): pass - assert exc.value.args[0] == 'form_class is required for DjangoFormMutation' + assert exc.value.args[0] == "form_class is required for DjangoFormMutation" def test_has_output_fields(): @@ -30,7 +30,7 @@ def test_has_output_fields(): class Meta: form_class = MyForm - assert 'errors' in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields def test_has_input_fields(): @@ -38,19 +38,18 @@ def test_has_input_fields(): class Meta: form_class = MyForm - assert 'text' in MyMutation.Input._meta.fields + assert "text" in MyMutation.Input._meta.fields class ModelFormMutationTests(TestCase): - def test_default_meta_fields(self): class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm self.assertEqual(PetMutation._meta.model, Pet) - self.assertEqual(PetMutation._meta.return_field_name, 'pet') - self.assertIn('pet', PetMutation._meta.fields) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("pet", PetMutation._meta.fields) def test_return_field_name_is_camelcased(self): class PetMutation(DjangoModelFormMutation): @@ -59,31 +58,31 @@ class ModelFormMutationTests(TestCase): model = FilmDetails self.assertEqual(PetMutation._meta.model, FilmDetails) - self.assertEqual(PetMutation._meta.return_field_name, 'filmDetails') + self.assertEqual(PetMutation._meta.return_field_name, "filmDetails") def test_custom_return_field_name(self): class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm model = Film - return_field_name = 'animal' + return_field_name = "animal" self.assertEqual(PetMutation._meta.model, Film) - self.assertEqual(PetMutation._meta.return_field_name, 'animal') - self.assertIn('animal', PetMutation._meta.fields) + self.assertEqual(PetMutation._meta.return_field_name, "animal") + self.assertIn("animal", PetMutation._meta.fields) def test_model_form_mutation_mutate(self): class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm - pet = Pet.objects.create(name='Axel') + pet = Pet.objects.create(name="Axel") - result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name='Mia') + result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia") self.assertEqual(Pet.objects.count(), 1) pet.refresh_from_db() - self.assertEqual(pet.name, 'Mia') + self.assertEqual(pet.name, "Mia") self.assertEqual(result.errors, []) def test_model_form_mutation_updates_existing_(self): @@ -91,11 +90,11 @@ class ModelFormMutationTests(TestCase): class Meta: form_class = PetForm - result = PetMutation.mutate_and_get_payload(None, None, name='Mia') + result = PetMutation.mutate_and_get_payload(None, None, name="Mia") self.assertEqual(Pet.objects.count(), 1) pet = Pet.objects.get() - self.assertEqual(pet.name, 'Mia') + self.assertEqual(pet.name, "Mia") self.assertEqual(result.errors, []) def test_model_form_mutation_mutate_invalid_form(self): @@ -109,5 +108,5 @@ class ModelFormMutationTests(TestCase): self.assertEqual(Pet.objects.count(), 0) self.assertEqual(len(result.errors), 1) - self.assertEqual(result.errors[0].field, 'name') - self.assertEqual(result.errors[0].messages, ['This field is required.']) + self.assertEqual(result.errors[0].field, "name") + self.assertEqual(result.errors[0].messages, ["This field is required."]) diff --git a/graphene_django/management/commands/graphql_schema.py b/graphene_django/management/commands/graphql_schema.py index 14ecf0c..099af5d 100644 --- a/graphene_django/management/commands/graphql_schema.py +++ b/graphene_django/management/commands/graphql_schema.py @@ -7,43 +7,45 @@ from graphene_django.settings import graphene_settings class CommandArguments(BaseCommand): - def add_arguments(self, parser): parser.add_argument( - '--schema', + "--schema", type=str, - dest='schema', + dest="schema", default=graphene_settings.SCHEMA, - help='Django app containing schema to dump, e.g. myproject.core.schema.schema') + help="Django app containing schema to dump, e.g. myproject.core.schema.schema", + ) parser.add_argument( - '--out', + "--out", type=str, - dest='out', + dest="out", default=graphene_settings.SCHEMA_OUTPUT, - help='Output file (default: schema.json)') + help="Output file (default: schema.json)", + ) parser.add_argument( - '--indent', + "--indent", type=int, - dest='indent', + dest="indent", default=graphene_settings.SCHEMA_INDENT, - help='Output file indent (default: None)') + help="Output file indent (default: None)", + ) class Command(CommandArguments): - help = 'Dump Graphene schema JSON to file' + help = "Dump Graphene schema JSON to file" can_import_settings = True def save_file(self, out, schema_dict, indent): - with open(out, 'w') as outfile: + with open(out, "w") as outfile: json.dump(schema_dict, outfile, indent=indent) def handle(self, *args, **options): - options_schema = options.get('schema') + options_schema = options.get("schema") if options_schema and type(options_schema) is str: - module_str, schema_name = options_schema.rsplit('.', 1) + module_str, schema_name = options_schema.rsplit(".", 1) mod = importlib.import_module(module_str) schema = getattr(mod, schema_name) @@ -53,16 +55,18 @@ class Command(CommandArguments): else: schema = graphene_settings.SCHEMA - out = options.get('out') or graphene_settings.SCHEMA_OUTPUT + out = options.get("out") or graphene_settings.SCHEMA_OUTPUT if not schema: - raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema') + raise CommandError( + "Specify schema on GRAPHENE.SCHEMA setting or by using --schema" + ) - indent = options.get('indent') - schema_dict = {'data': schema.introspect()} + indent = options.get("indent") + schema_dict = {"data": schema.introspect()} self.save_file(out, schema_dict, indent) - style = getattr(self, 'style', None) - success = getattr(style, 'SUCCESS', lambda x: x) + style = getattr(self, "style", None) + success = getattr(style, "SUCCESS", lambda x: x) - self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out)) + self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out)) diff --git a/graphene_django/registry.py b/graphene_django/registry.py index b28268d..50a8ae5 100644 --- a/graphene_django/registry.py +++ b/graphene_django/registry.py @@ -1,20 +1,21 @@ - class Registry(object): - def __init__(self): self._registry = {} self._field_registry = {} def register(self, cls): from .types import DjangoObjectType + assert issubclass( - cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format( - cls.__name__) - assert cls._meta.registry == self, 'Registry for a Model have to match.' + cls, DjangoObjectType + ), 'Only DjangoObjectTypes can be registered, received "{}"'.format( + cls.__name__ + ) + assert cls._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) == cls, ( # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) # ) - if not getattr(cls._meta, 'skip_registry', False): + if not getattr(cls._meta, "skip_registry", False): self._registry[cls._meta.model] = cls def get_type_for_model(self, model): diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index a694553..5e343aa 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -6,20 +6,16 @@ import graphene from graphene.types import Field, InputField from graphene.types.mutation import MutationOptions from graphene.relay.mutation import ClientIDMutation -from graphene.types.objecttype import ( - yank_fields_from_attrs -) +from graphene.types.objecttype import yank_fields_from_attrs -from .serializer_converter import ( - convert_serializer_field -) +from .serializer_converter import convert_serializer_field from .types import ErrorType class SerializerMutationOptions(MutationOptions): lookup_field = None model_class = None - model_operations = ['create', 'update'] + model_operations = ["create", "update"] serializer_class = None @@ -28,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals for name, field in serializer.fields.items(): is_not_in_only = only_fields and name not in only_fields is_excluded = ( - name in exclude_fields # or + name + in exclude_fields # or # name in already_created_fields ) @@ -44,49 +41,54 @@ class SerializerMutation(ClientIDMutation): abstract = True errors = graphene.List( - ErrorType, - description='May contain more than one error for same field.' + ErrorType, description="May contain more than one error for same field." ) @classmethod - def __init_subclass_with_meta__(cls, lookup_field=None, - serializer_class=None, model_class=None, - model_operations=['create', 'update'], - only_fields=(), exclude_fields=(), **options): + def __init_subclass_with_meta__( + cls, + lookup_field=None, + serializer_class=None, + model_class=None, + model_operations=["create", "update"], + only_fields=(), + exclude_fields=(), + **options + ): if not serializer_class: - raise Exception('serializer_class is required for the SerializerMutation') + raise Exception("serializer_class is required for the SerializerMutation") - if 'update' not in model_operations and 'create' not in model_operations: + if "update" not in model_operations and "create" not in model_operations: raise Exception('model_operations must contain "create" and/or "update"') serializer = serializer_class() if model_class is None: - serializer_meta = getattr(serializer_class, 'Meta', None) + serializer_meta = getattr(serializer_class, "Meta", None) if serializer_meta: - model_class = getattr(serializer_meta, 'model', None) + model_class = getattr(serializer_meta, "model", None) if lookup_field is None and model_class: lookup_field = model_class._meta.pk.name - input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) - output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) + input_fields = fields_for_serializer( + serializer, only_fields, exclude_fields, is_input=True + ) + output_fields = fields_for_serializer( + serializer, only_fields, exclude_fields, is_input=False + ) _meta = SerializerMutationOptions(cls) _meta.lookup_field = lookup_field _meta.model_operations = model_operations _meta.serializer_class = serializer_class _meta.model_class = model_class - _meta.fields = yank_fields_from_attrs( - output_fields, - _as=Field, - ) + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) - input_fields = yank_fields_from_attrs( - input_fields, - _as=InputField, + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(SerializerMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options ) - super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) @classmethod def get_serializer_kwargs(cls, root, info, **input): @@ -94,24 +96,26 @@ class SerializerMutation(ClientIDMutation): model_class = cls._meta.model_class if model_class: - if 'update' in cls._meta.model_operations and lookup_field in input: - instance = get_object_or_404(model_class, **{ - lookup_field: input[lookup_field]}) - elif 'create' in cls._meta.model_operations: + if "update" in cls._meta.model_operations and lookup_field in input: + instance = get_object_or_404( + model_class, **{lookup_field: input[lookup_field]} + ) + elif "create" in cls._meta.model_operations: instance = None else: raise Exception( 'Invalid update operation. Input parameter "{}" required.'.format( lookup_field - )) + ) + ) return { - 'instance': instance, - 'data': input, - 'context': {'request': info.context} + "instance": instance, + "data": input, + "context": {"request": info.context}, } - return {'data': input, 'context': {'request': info.context}} + return {"data": input, "context": {"request": info.context}} @classmethod def mutate_and_get_payload(cls, root, info, **input): diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py index 014d42a..9f8e516 100644 --- a/graphene_django/rest_framework/serializer_converter.py +++ b/graphene_django/rest_framework/serializer_converter.py @@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True): graphql_type = get_graphene_type_from_serializer_field(field) args = [] - kwargs = { - 'description': field.help_text, - 'required': is_input and field.required, - } + kwargs = {"description": field.help_text, "required": is_input and field.required} # if it is a tuple or a list it means that we are returning # the graphql type and the child type if isinstance(graphql_type, (list, tuple)): - kwargs['of_type'] = graphql_type[1] + kwargs["of_type"] = graphql_type[1] graphql_type = graphql_type[0] if isinstance(field, serializers.ModelSerializer): @@ -49,9 +46,9 @@ def convert_serializer_field(field, is_input=True): elif isinstance(field, serializers.ListSerializer): field = field.child if is_input: - kwargs['of_type'] = convert_serializer_to_input_type(field.__class__) + kwargs["of_type"] = convert_serializer_to_input_type(field.__class__) else: - del kwargs['of_type'] + del kwargs["of_type"] global_registry = get_global_registry() field_model = field.Meta.model args = [global_registry.get_type_for_model(field_model)] @@ -68,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class): } return type( - '{}Input'.format(serializer.__class__.__name__), + "{}Input".format(serializer.__class__.__name__), (graphene.InputObjectType,), - items + items, ) diff --git a/graphene_django/rest_framework/tests/test_field_converter.py b/graphene_django/rest_framework/tests/test_field_converter.py index 22a0ba9..6fa4ca8 100644 --- a/graphene_django/rest_framework/tests/test_field_converter.py +++ b/graphene_django/rest_framework/tests/test_field_converter.py @@ -16,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs): # Remove `source=` from the field declaration. # since we are reusing the same child in when testing the required attribute - if 'child' in kwargs: - kwargs['child'] = copy.deepcopy(kwargs['child']) + if "child" in kwargs: + kwargs["child"] = copy.deepcopy(kwargs["child"]) field = rest_framework_field(**kwargs) @@ -25,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs): def assert_conversion(rest_framework_field, graphene_field, **kwargs): - graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs) + graphene_type = _get_type( + rest_framework_field, help_text="Custom Help Text", **kwargs + ) assert isinstance(graphene_type, graphene_field) graphene_type_required = _get_type( - rest_framework_field, help_text='Custom Help Text', required=True, **kwargs + rest_framework_field, help_text="Custom Help Text", required=True, **kwargs ) assert isinstance(graphene_type_required, graphene_field) @@ -39,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs): def test_should_unknown_rest_framework_field_raise_exception(): with raises(Exception) as excinfo: convert_serializer_field(None) - assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) + assert "Don't know how to convert the serializer field" in str(excinfo.value) def test_should_char_convert_string(): @@ -67,11 +69,11 @@ def test_should_base_field_convert_string(): def test_should_regex_convert_string(): - assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+') + assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+") def test_should_uuid_convert_string(): - if hasattr(serializers, 'UUIDField'): + if hasattr(serializers, "UUIDField"): assert_conversion(serializers.UUIDField, graphene.String) @@ -79,7 +81,7 @@ def test_should_model_convert_field(): class MyModelSerializer(serializers.ModelSerializer): class Meta: model = None - fields = '__all__' + fields = "__all__" assert_conversion(MyModelSerializer, graphene.Field, is_input=False) @@ -109,7 +111,9 @@ def test_should_float_convert_float(): def test_should_decimal_convert_float(): - assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2) + assert_conversion( + serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2 + ) def test_should_list_convert_to_list(): @@ -119,7 +123,7 @@ def test_should_list_convert_to_list(): field_a = assert_conversion( serializers.ListField, graphene.List, - child=serializers.IntegerField(min_value=0, max_value=100) + child=serializers.IntegerField(min_value=0, max_value=100), ) assert field_a.of_type == graphene.Int @@ -136,19 +140,23 @@ def test_should_list_serializer_convert_to_list(): class ChildSerializer(serializers.ModelSerializer): class Meta: model = FooModel - fields = '__all__' + fields = "__all__" class ParentSerializer(serializers.ModelSerializer): child = ChildSerializer(many=True) class Meta: model = FooModel - fields = '__all__' + fields = "__all__" - converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=True) + converted_type = convert_serializer_field( + ParentSerializer().get_fields()["child"], is_input=True + ) assert isinstance(converted_type, graphene.List) - converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=False) + converted_type = convert_serializer_field( + ParentSerializer().get_fields()["child"], is_input=False + ) assert isinstance(converted_type, graphene.List) assert converted_type.of_type is None @@ -166,7 +174,7 @@ def test_should_file_convert_string(): def test_should_filepath_convert_string(): - assert_conversion(serializers.FilePathField, graphene.String, path='/') + assert_conversion(serializers.FilePathField, graphene.String, path="/") def test_should_ip_convert_string(): @@ -182,6 +190,8 @@ def test_should_json_convert_jsonstring(): def test_should_multiplechoicefield_convert_to_list_of_string(): - field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]) + field = assert_conversion( + serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3] + ) assert field.of_type == graphene.String diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 35acab7..4dccc18 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -10,30 +10,33 @@ from ...types import DjangoObjectType from ..models import MyFakeModel from ..mutation import SerializerMutation + def mock_info(): - return ResolveInfo( - None, - None, - None, - None, - schema=None, - fragments=None, - root_value=None, - operation=None, - variable_values=None, - context=None - ) + return ResolveInfo( + None, + None, + None, + None, + schema=None, + fragments=None, + root_value=None, + operation=None, + variable_values=None, + context=None, + ) class MyModelSerializer(serializers.ModelSerializer): class Meta: model = MyFakeModel - fields = '__all__' + fields = "__all__" + class MyModelMutation(SerializerMutation): class Meta: serializer_class = MyModelSerializer + class MySerializer(serializers.Serializer): text = serializers.CharField() model = MyModelSerializer() @@ -44,10 +47,11 @@ class MySerializer(serializers.Serializer): def test_needs_serializer_class(): with raises(Exception) as exc: + class MyMutation(SerializerMutation): pass - assert str(exc.value) == 'serializer_class is required for the SerializerMutation' + assert str(exc.value) == "serializer_class is required for the SerializerMutation" def test_has_fields(): @@ -55,9 +59,9 @@ def test_has_fields(): class Meta: serializer_class = MySerializer - assert 'text' in MyMutation._meta.fields - assert 'model' in MyMutation._meta.fields - assert 'errors' in MyMutation._meta.fields + assert "text" in MyMutation._meta.fields + assert "model" in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields def test_has_input_fields(): @@ -65,25 +69,24 @@ def test_has_input_fields(): class Meta: serializer_class = MySerializer - assert 'text' in MyMutation.Input._meta.fields - assert 'model' in MyMutation.Input._meta.fields + assert "text" in MyMutation.Input._meta.fields + assert "model" in MyMutation.Input._meta.fields def test_exclude_fields(): class MyMutation(SerializerMutation): class Meta: serializer_class = MyModelSerializer - exclude_fields = ['created'] + exclude_fields = ["created"] - assert 'cool_name' in MyMutation._meta.fields - assert 'created' not in MyMutation._meta.fields - assert 'errors' in MyMutation._meta.fields - assert 'cool_name' in MyMutation.Input._meta.fields - assert 'created' not in MyMutation.Input._meta.fields + assert "cool_name" in MyMutation._meta.fields + assert "created" not in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields + assert "cool_name" in MyMutation.Input._meta.fields + assert "created" not in MyMutation.Input._meta.fields def test_nested_model(): - class MyFakeModelGrapheneType(DjangoObjectType): class Meta: model = MyFakeModel @@ -92,67 +95,64 @@ def test_nested_model(): class Meta: serializer_class = MySerializer - model_field = MyMutation._meta.fields['model'] + model_field = MyMutation._meta.fields["model"] assert isinstance(model_field, Field) assert model_field.type == MyFakeModelGrapheneType - model_input = MyMutation.Input._meta.fields['model'] + model_input = MyMutation.Input._meta.fields["model"] model_input_type = model_input._type.of_type assert issubclass(model_input_type, InputObjectType) - assert 'cool_name' in model_input_type._meta.fields - assert 'created' in model_input_type._meta.fields + assert "cool_name" in model_input_type._meta.fields + assert "created" in model_input_type._meta.fields def test_mutate_and_get_payload_success(): - class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer - result = MyMutation.mutate_and_get_payload(None, mock_info(), **{ - 'text': 'value', - 'model': { - 'cool_name': 'other_value' - } - }) + result = MyMutation.mutate_and_get_payload( + None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}} + ) assert result.errors is None @mark.django_db def test_model_add_mutate_and_get_payload_success(): - result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{ - 'cool_name': 'Narf', - }) + result = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) assert result.errors is None - assert result.cool_name == 'Narf' + assert result.cool_name == "Narf" assert isinstance(result.created, datetime.datetime) + @mark.django_db def test_model_update_mutate_and_get_payload_success(): instance = MyFakeModel.objects.create(cool_name="Narf") - result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{ - 'id': instance.id, - 'cool_name': 'New Narf', - }) + result = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"} + ) assert result.errors is None - assert result.cool_name == 'New Narf' + assert result.cool_name == "New Narf" + @mark.django_db def test_model_invalid_update_mutate_and_get_payload_success(): class InvalidModelMutation(SerializerMutation): class Meta: serializer_class = MyModelSerializer - model_operations = ['update'] + model_operations = ["update"] with raises(Exception) as exc: - result = InvalidModelMutation.mutate_and_get_payload(None, mock_info(), **{ - 'cool_name': 'Narf', - }) + result = InvalidModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) assert '"id" required' in str(exc.value) -def test_mutate_and_get_payload_error(): +def test_mutate_and_get_payload_error(): class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer @@ -161,16 +161,19 @@ def test_mutate_and_get_payload_error(): result = MyMutation.mutate_and_get_payload(None, mock_info(), **{}) assert len(result.errors) > 0 + def test_model_mutate_and_get_payload_error(): # missing required fields result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{}) assert len(result.errors) > 0 + def test_invalid_serializer_operations(): with raises(Exception) as exc: + class MyModelMutation(SerializerMutation): class Meta: serializer_class = MyModelSerializer - model_operations = ['Add'] + model_operations = ["Add"] - assert 'model_operations' in str(exc.value) + assert "model_operations" in str(exc.value) diff --git a/graphene_django/settings.py b/graphene_django/settings.py index 46d70ee..7cd750a 100644 --- a/graphene_django/settings.py +++ b/graphene_django/settings.py @@ -26,27 +26,22 @@ except ImportError: # Copied shamelessly from Django REST Framework DEFAULTS = { - 'SCHEMA': None, - 'SCHEMA_OUTPUT': 'schema.json', - 'SCHEMA_INDENT': None, - 'MIDDLEWARE': (), + "SCHEMA": None, + "SCHEMA_OUTPUT": "schema.json", + "SCHEMA_INDENT": None, + "MIDDLEWARE": (), # Set to True if the connection fields must have # either the first or last argument - 'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False, + "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False, # Max items returned in ConnectionFields / FilterConnectionFields - 'RELAY_CONNECTION_MAX_LIMIT': 100, + "RELAY_CONNECTION_MAX_LIMIT": 100, } if settings.DEBUG: - DEFAULTS['MIDDLEWARE'] += ( - 'graphene_django.debug.DjangoDebugMiddleware', - ) + DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",) # List of settings that may be in string import notation. -IMPORT_STRINGS = ( - 'MIDDLEWARE', - 'SCHEMA', -) +IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA") def perform_import(val, setting_name): @@ -69,12 +64,17 @@ def import_from_string(val, setting_name): """ try: # Nod to tastypie's use of importlib. - parts = val.split('.') - module_path, class_name = '.'.join(parts[:-1]), parts[-1] + parts = val.split(".") + module_path, class_name = ".".join(parts[:-1]), parts[-1] module = importlib.import_module(module_path) return getattr(module, class_name) except (ImportError, AttributeError) as e: - msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) + msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % ( + val, + setting_name, + e.__class__.__name__, + e, + ) raise ImportError(msg) @@ -96,8 +96,8 @@ class GrapheneSettings(object): @property def user_settings(self): - if not hasattr(self, '_user_settings'): - self._user_settings = getattr(settings, 'GRAPHENE', {}) + if not hasattr(self, "_user_settings"): + self._user_settings = getattr(settings, "GRAPHENE", {}) return self._user_settings def __getattr__(self, attr): @@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS) def reload_graphene_settings(*args, **kwargs): global graphene_settings - setting, value = kwargs['setting'], kwargs['value'] - if setting == 'GRAPHENE': + setting, value = kwargs["setting"], kwargs["value"] + if setting == "GRAPHENE": graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS) diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index eab70ec..45ab737 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -3,10 +3,7 @@ from __future__ import absolute_import from django.db import models from django.utils.translation import ugettext_lazy as _ -CHOICES = ( - (1, 'this'), - (2, _('that')) -) +CHOICES = ((1, "this"), (2, _("that"))) class Pet(models.Model): @@ -15,38 +12,43 @@ class Pet(models.Model): class FilmDetails(models.Model): location = models.CharField(max_length=30) - film = models.OneToOneField('Film', on_delete=models.CASCADE, related_name='details') + film = models.OneToOneField( + "Film", on_delete=models.CASCADE, related_name="details" + ) class Film(models.Model): - genre = models.CharField(max_length=2, help_text='Genre', choices=[ - ('do', 'Documentary'), - ('ot', 'Other') - ], default='ot') - reporters = models.ManyToManyField('Reporter', - related_name='films') + genre = models.CharField( + max_length=2, + help_text="Genre", + choices=[("do", "Documentary"), ("ot", "Other")], + default="ot", + ) + reporters = models.ManyToManyField("Reporter", related_name="films") + class DoeReporterManager(models.Manager): def get_queryset(self): return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe") + class Reporter(models.Model): first_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30) email = models.EmailField() - pets = models.ManyToManyField('self') + pets = models.ManyToManyField("self") a_choice = models.CharField(max_length=30, choices=CHOICES) objects = models.Manager() doe_objects = DoeReporterManager() reporter_type = models.IntegerField( - 'Reporter Type', + "Reporter Type", null=True, blank=True, - choices=[(1, u'Regular'), (2, u'CNN Reporter')] + choices=[(1, u"Regular"), (2, u"CNN Reporter")], ) - def __str__(self): # __unicode__ on Python 2 + def __str__(self): # __unicode__ on Python 2 return "%s %s" % (self.first_name, self.last_name) def __init__(self, *args, **kwargs): @@ -61,11 +63,13 @@ class Reporter(models.Model): if self.reporter_type == 2: # quick and dirty way without enums self.__class__ = CNNReporter + class CNNReporter(Reporter): """ This class is a proxy model for Reporter, used for testing proxy model support """ + class Meta: proxy = True @@ -74,17 +78,27 @@ class Article(models.Model): headline = models.CharField(max_length=100) pub_date = models.DateField() pub_date_time = models.DateTimeField() - reporter = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='articles') - editor = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='edited_articles_+') - lang = models.CharField(max_length=2, help_text='Language', choices=[ - ('es', 'Spanish'), - ('en', 'English') - ], default='es') - importance = models.IntegerField('Importance', null=True, blank=True, - choices=[(1, u'Very important'), (2, u'Not as important')]) + reporter = models.ForeignKey( + Reporter, on_delete=models.CASCADE, related_name="articles" + ) + editor = models.ForeignKey( + Reporter, on_delete=models.CASCADE, related_name="edited_articles_+" + ) + lang = models.CharField( + max_length=2, + help_text="Language", + choices=[("es", "Spanish"), ("en", "English")], + default="es", + ) + importance = models.IntegerField( + "Importance", + null=True, + blank=True, + choices=[(1, u"Very important"), (2, u"Not as important")], + ) - def __str__(self): # __unicode__ on Python 2 + def __str__(self): # __unicode__ on Python 2 return self.headline class Meta: - ordering = ('headline',) + ordering = ("headline",) diff --git a/graphene_django/tests/schema.py b/graphene_django/tests/schema.py index 3134604..d0d9e47 100644 --- a/graphene_django/tests/schema.py +++ b/graphene_django/tests/schema.py @@ -6,10 +6,9 @@ from .models import Article, Reporter class Character(DjangoObjectType): - class Meta: model = Reporter - interfaces = (relay.Node, ) + interfaces = (relay.Node,) def get_node(self, info, id): pass @@ -20,7 +19,7 @@ class Human(DjangoObjectType): class Meta: model = Article - interfaces = (relay.Node, ) + interfaces = (relay.Node,) def resolve_raises(self, info): raise Exception("This field should raise exception") diff --git a/graphene_django/tests/schema_view.py b/graphene_django/tests/schema_view.py index c750433..9b3bd1e 100644 --- a/graphene_django/tests/schema_view.py +++ b/graphene_django/tests/schema_view.py @@ -12,10 +12,10 @@ class QueryRoot(ObjectType): raise Exception("Throws!") def resolve_request(self, info): - return info.context.GET.get('q') + return info.context.GET.get("q") def resolve_test(self, info, who=None): - return 'Hello %s' % (who or 'World') + return "Hello %s" % (who or "World") class MutationRoot(ObjectType): diff --git a/graphene_django/tests/test_command.py b/graphene_django/tests/test_command.py index caf9f7a..ff6e6e1 100644 --- a/graphene_django/tests/test_command.py +++ b/graphene_django/tests/test_command.py @@ -3,8 +3,8 @@ from mock import patch from six import StringIO -@patch('graphene_django.management.commands.graphql_schema.Command.save_file') +@patch("graphene_django.management.commands.graphql_schema.Command.save_file") def test_generate_file_on_call_graphql_schema(savefile_mock, settings): out = StringIO() - management.call_command('graphql_schema', schema='', stdout=out) + management.call_command("graphql_schema", schema="", stdout=out) assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue() diff --git a/graphene_django/tests/test_converter.py b/graphene_django/tests/test_converter.py index bcf4564..196f008 100644 --- a/graphene_django/tests/test_converter.py +++ b/graphene_django/tests/test_converter.py @@ -19,11 +19,11 @@ from .models import Article, Film, FilmDetails, Reporter def assert_conversion(django_field, graphene_field, *args, **kwargs): - field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs) + field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs) graphene_type = convert_django_field(field) assert isinstance(graphene_type, graphene_field) field = graphene_type.Field() - assert field.description == 'Custom Help Text' + assert field.description == "Custom Help Text" nonnull_field = django_field(null=False, *args, **kwargs) if not nonnull_field.null: nonnull_graphene_type = convert_django_field(nonnull_field) @@ -36,7 +36,8 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs): def test_should_unknown_django_field_raise_exception(): with raises(Exception) as excinfo: convert_django_field(None) - assert 'Don\'t know how to convert the Django field' in str(excinfo.value) + assert "Don't know how to convert the Django field" in str(excinfo.value) + def test_should_date_time_convert_string(): assert_conversion(models.DateTimeField, DateTime) @@ -128,70 +129,69 @@ def test_should_nullboolean_convert_boolean(): def test_field_with_choices_convert_enum(): - field = models.CharField(help_text='Language', choices=( - ('es', 'Spanish'), - ('en', 'English') - )) + field = models.CharField( + help_text="Language", choices=(("es", "Spanish"), ("en", "English")) + ) class TranslatedModel(models.Model): language = field class Meta: - app_label = 'test' + app_label = "test" graphene_type = convert_django_field_with_choices(field) assert isinstance(graphene_type, graphene.Enum) - assert graphene_type._meta.name == 'TranslatedModelLanguage' - assert graphene_type._meta.enum.__members__['ES'].value == 'es' - assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish' - assert graphene_type._meta.enum.__members__['EN'].value == 'en' - assert graphene_type._meta.enum.__members__['EN'].description == 'English' + assert graphene_type._meta.name == "TranslatedModelLanguage" + assert graphene_type._meta.enum.__members__["ES"].value == "es" + assert graphene_type._meta.enum.__members__["ES"].description == "Spanish" + assert graphene_type._meta.enum.__members__["EN"].value == "en" + assert graphene_type._meta.enum.__members__["EN"].description == "English" def test_field_with_grouped_choices(): - field = models.CharField(help_text='Language', choices=( - ('Europe', ( - ('es', 'Spanish'), - ('en', 'English'), - )), - )) + field = models.CharField( + help_text="Language", + choices=(("Europe", (("es", "Spanish"), ("en", "English"))),), + ) class GroupedChoicesModel(models.Model): language = field class Meta: - app_label = 'test' + app_label = "test" convert_django_field_with_choices(field) def test_field_with_choices_gettext(): - field = models.CharField(help_text='Language', choices=( - ('es', _('Spanish')), - ('en', _('English')) - )) + field = models.CharField( + help_text="Language", choices=(("es", _("Spanish")), ("en", _("English"))) + ) class TranslatedChoicesModel(models.Model): language = field class Meta: - app_label = 'test' + app_label = "test" convert_django_field_with_choices(field) def test_field_with_choices_collision(): - field = models.CharField(help_text='Timezone', choices=( - ('Etc/GMT+1+2', 'Fake choice to produce double collision'), - ('Etc/GMT+1', 'Greenwich Mean Time +1'), - ('Etc/GMT-1', 'Greenwich Mean Time -1'), - )) + field = models.CharField( + help_text="Timezone", + choices=( + ("Etc/GMT+1+2", "Fake choice to produce double collision"), + ("Etc/GMT+1", "Greenwich Mean Time +1"), + ("Etc/GMT-1", "Greenwich Mean Time -1"), + ), + ) class CollisionChoicesModel(models.Model): timezone = field class Meta: - app_label = 'test' + app_label = "test" convert_django_field_with_choices(field) @@ -208,11 +208,12 @@ def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist_list(): class A(DjangoObjectType): - class Meta: model = Reporter - graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) + graphene_field = convert_django_field( + Reporter._meta.local_many_to_many[0], A._meta.registry + ) assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, graphene.Field) @@ -222,12 +223,13 @@ def test_should_manytomany_convert_connectionorlist_list(): def test_should_manytomany_convert_connectionorlist_connection(): class A(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) - graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) + graphene_field = convert_django_field( + Reporter._meta.local_many_to_many[0], A._meta.registry + ) assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, ConnectionField) @@ -236,11 +238,11 @@ def test_should_manytomany_convert_connectionorlist_connection(): def test_should_manytoone_convert_connectionorlist(): # Django 1.9 uses 'rel', <1.9 uses 'related - related = getattr(Reporter.articles, 'rel', None) or \ - getattr(Reporter.articles, 'related') + related = getattr(Reporter.articles, "rel", None) or getattr( + Reporter.articles, "related" + ) class A(DjangoObjectType): - class Meta: model = Article @@ -254,11 +256,9 @@ def test_should_manytoone_convert_connectionorlist(): def test_should_onetoone_reverse_convert_model(): # Django 1.9 uses 'rel', <1.9 uses 'related - related = getattr(Film.details, 'rel', None) or \ - getattr(Film.details, 'related') + related = getattr(Film.details, "rel", None) or getattr(Film.details, "related") class A(DjangoObjectType): - class Meta: model = FilmDetails @@ -269,41 +269,41 @@ def test_should_onetoone_reverse_convert_model(): assert dynamic_field.type == A -@pytest.mark.skipif(ArrayField is MissingType, - reason="ArrayField should exist") +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") def test_should_postgres_array_convert_list(): - field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100)) + field = assert_conversion( + ArrayField, graphene.List, models.CharField(max_length=100) + ) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) assert field.type.of_type.of_type == graphene.String -@pytest.mark.skipif(ArrayField is MissingType, - reason="ArrayField should exist") +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") def test_should_postgres_array_multiple_convert_list(): - field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))) + field = assert_conversion( + ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)) + ) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type.of_type, graphene.List) assert field.type.of_type.of_type.of_type == graphene.String -@pytest.mark.skipif(HStoreField is MissingType, - reason="HStoreField should exist") +@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist") def test_should_postgres_hstore_convert_string(): assert_conversion(HStoreField, JSONString) -@pytest.mark.skipif(JSONField is MissingType, - reason="JSONField should exist") +@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist") def test_should_postgres_json_convert_string(): assert_conversion(JSONField, JSONString) -@pytest.mark.skipif(RangeField is MissingType, - reason="RangeField should exist") +@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist") def test_should_postgres_range_convert_list(): from django.contrib.postgres.fields import IntegerRangeField + field = assert_conversion(IntegerRangeField, graphene.List) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) diff --git a/graphene_django/tests/test_forms.py b/graphene_django/tests/test_forms.py index b15e866..fa6628d 100644 --- a/graphene_django/tests/test_forms.py +++ b/graphene_django/tests/test_forms.py @@ -1,7 +1,7 @@ from django.core.exceptions import ValidationError from py.test import raises -from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField +from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField # 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc' @@ -9,24 +9,24 @@ from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField def test_global_id_valid(): field = GlobalIDFormField() - field.clean('TXlUeXBlOmFiYw==') + field.clean("TXlUeXBlOmFiYw==") def test_global_id_invalid(): field = GlobalIDFormField() with raises(ValidationError): - field.clean('badvalue') + field.clean("badvalue") def test_global_id_multiple_valid(): field = GlobalIDMultipleChoiceField() - field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw==']) + field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="]) def test_global_id_multiple_invalid(): field = GlobalIDMultipleChoiceField() with raises(ValidationError): - field.clean(['badvalue', 'another bad avue']) + field.clean(["badvalue", "another bad avue"]) def test_global_id_none(): diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 09ca702..1716034 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -15,41 +15,34 @@ from ..compat import MissingType, JSONField from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..settings import graphene_settings -from .models import ( - Article, - CNNReporter, - Reporter, - Film, - FilmDetails, -) +from .models import Article, CNNReporter, Reporter, Film, FilmDetails pytestmark = pytest.mark.django_db def test_should_query_only_fields(): with raises(Exception): - class ReporterType(DjangoObjectType): + class ReporterType(DjangoObjectType): class Meta: model = Reporter - only_fields = ('articles', ) + only_fields = ("articles",) schema = graphene.Schema(query=ReporterType) - query = ''' + query = """ query ReporterQuery { articles } - ''' + """ result = schema.execute(query) assert not result.errors def test_should_query_simplelazy_objects(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - only_fields = ('id', ) + only_fields = ("id",) class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) @@ -58,25 +51,20 @@ def test_should_query_simplelazy_objects(): return SimpleLazyObject(lambda: Reporter(id=1)) schema = graphene.Schema(query=Query) - query = ''' + query = """ query { reporter { id } } - ''' + """ result = schema.execute(query) assert not result.errors - assert result.data == { - 'reporter': { - 'id': '1' - } - } + assert result.data == {"reporter": {"id": "1"}} def test_should_query_well(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter @@ -84,9 +72,9 @@ def test_should_query_well(): reporter = graphene.Field(ReporterType) def resolve_reporter(self, info): - return Reporter(first_name='ABA', last_name='X') + return Reporter(first_name="ABA", last_name="X") - query = ''' + query = """ query ReporterQuery { reporter { firstName, @@ -94,33 +82,30 @@ def test_should_query_well(): email } } - ''' - expected = { - 'reporter': { - 'firstName': 'ABA', - 'lastName': 'X', - 'email': '' - } - } + """ + expected = {"reporter": {"firstName": "ABA", "lastName": "X", "email": ""}} schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors assert result.data == expected -@pytest.mark.skipif(JSONField is MissingType, - reason="RangeField should exist") +@pytest.mark.skipif(JSONField is MissingType, reason="RangeField should exist") def test_should_query_postgres_fields(): - from django.contrib.postgres.fields import IntegerRangeField, ArrayField, JSONField, HStoreField + from django.contrib.postgres.fields import ( + IntegerRangeField, + ArrayField, + JSONField, + HStoreField, + ) class Event(models.Model): - ages = IntegerRangeField(help_text='The age ranges') - data = JSONField(help_text='Data') + ages = IntegerRangeField(help_text="The age ranges") + data = JSONField(help_text="Data") store = HStoreField() tags = ArrayField(models.CharField(max_length=50)) class EventType(DjangoObjectType): - class Meta: model = Event @@ -130,13 +115,13 @@ def test_should_query_postgres_fields(): def resolve_event(self, info): return Event( ages=(0, 10), - data={'angry_babies': True}, - store={'h': 'store'}, - tags=['child', 'angry', 'babies'] + data={"angry_babies": True}, + store={"h": "store"}, + tags=["child", "angry", "babies"], ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query myQuery { event { ages @@ -145,14 +130,14 @@ def test_should_query_postgres_fields(): store } } - ''' + """ expected = { - 'event': { - 'ages': [0, 10], - 'tags': ['child', 'angry', 'babies'], - 'data': '{"angry_babies": true}', - 'store': '{"h": "store"}', - }, + "event": { + "ages": [0, 10], + "tags": ["child", "angry", "babies"], + "data": '{"angry_babies": true}', + "store": '{"h": "store"}', + } } result = schema.execute(query) assert not result.errors @@ -164,27 +149,27 @@ def test_should_node(): # Node._meta.registry = get_global_registry() class ReporterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, info, id): - return Reporter(id=2, first_name='Cookie Monster') + return Reporter(id=2, first_name="Cookie Monster") def resolve_articles(self, info, **args): - return [Article(headline='Hi!')] + return [Article(headline="Hi!")] class ArticleNode(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, info, id): - return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)) + return Article( + id=1, headline="Article node", pub_date=datetime.date(2002, 3, 11) + ) class Query(graphene.ObjectType): node = Node.Field() @@ -192,9 +177,9 @@ def test_should_node(): article = graphene.Field(ArticleNode) def resolve_reporter(self, info): - return Reporter(id=1, first_name='ABA', last_name='X') + return Reporter(id=1, first_name="ABA", last_name="X") - query = ''' + query = """ query ReporterQuery { reporter { id, @@ -220,26 +205,20 @@ def test_should_node(): } } } - ''' + """ expected = { - 'reporter': { - 'id': 'UmVwb3J0ZXJOb2RlOjE=', - 'firstName': 'ABA', - 'lastName': 'X', - 'email': '', - 'articles': { - 'edges': [{ - 'node': { - 'headline': 'Hi!' - } - }] - }, + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "ABA", + "lastName": "X", + "email": "", + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "myArticle": { + "id": "QXJ0aWNsZU5vZGU6MQ==", + "headline": "Article node", + "pubDate": "2002-03-11", }, - 'myArticle': { - 'id': 'QXJ0aWNsZU5vZGU6MQ==', - 'headline': 'Article node', - 'pubDate': '2002-03-11', - } } schema = graphene.Schema(query=Query) result = schema.execute(query) @@ -249,11 +228,10 @@ def test_should_node(): def test_should_query_connectionfields(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - only_fields = ('articles', ) + interfaces = (Node,) + only_fields = ("articles",) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -262,7 +240,7 @@ def test_should_query_connectionfields(): return [Reporter(id=1)] schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterConnectionQuery { allReporters { pageInfo { @@ -275,55 +253,48 @@ def test_should_query_connectionfields(): } } } - ''' + """ result = schema.execute(query) assert not result.errors assert result.data == { - 'allReporters': { - 'pageInfo': { - 'hasNextPage': False, - }, - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] + "allReporters": { + "pageInfo": {"hasNextPage": False}, + "edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}], } } def test_should_keep_annotations(): - from django.db.models import ( - Count, - Avg, - ) + from django.db.models import Count, Avg class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - only_fields = ('articles', ) + interfaces = (Node,) + only_fields = ("articles",) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) all_articles = DjangoConnectionField(ArticleType) def resolve_all_reporters(self, info, **args): - return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') + return Reporter.objects.annotate(articles_c=Count("articles")).order_by( + "articles_c" + ) def resolve_all_articles(self, info, **args): - return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') + return Article.objects.annotate(import_avg=Avg("importance")).order_by( + "import_avg" + ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterConnectionQuery { allReporters { pageInfo { @@ -346,55 +317,51 @@ def test_should_keep_annotations(): } } } - ''' + """ result = schema.execute(query) assert not result.errors -@pytest.mark.skipif(not DJANGO_FILTER_INSTALLED, - reason="django-filter should be installed") +@pytest.mark.skipif( + not DJANGO_FILTER_INSTALLED, reason="django-filter should be installed" +) def test_should_query_node_filtering(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters { edges { @@ -411,22 +378,20 @@ def test_should_query_node_filtering(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [{"node": {"id": "QXJ0aWNsZVR5cGU6MQ=="}}] + }, } } - }] + ] } } @@ -435,15 +400,15 @@ def test_should_query_node_filtering(): assert result.data == expected -@pytest.mark.skipif(not DJANGO_FILTER_INSTALLED, - reason="django-filter should be installed") +@pytest.mark.skipif( + not DJANGO_FILTER_INSTALLED, reason="django-filter should be installed" +) def test_should_query_node_filtering_with_distinct_queryset(): class FilmType(DjangoObjectType): - class Meta: model = Film - interfaces = (Node, ) - filter_fields = ('genre',) + interfaces = (Node,) + filter_fields = ("genre",) class Query(graphene.ObjectType): films = DjangoConnectionField(FilmType) @@ -452,17 +417,15 @@ def test_should_query_node_filtering_with_distinct_queryset(): # return Reporter.objects.filter(Q(films__film__location__contains="Berlin") | Q(a_choice=1)) def resolve_films(self, info, **args): - return Film.objects.filter(Q(details__location__contains="Berlin") | Q(genre__in=['ot'])).distinct() + return Film.objects.filter( + Q(details__location__contains="Berlin") | Q(genre__in=["ot"]) + ).distinct() - f = Film.objects.create( - ) - fd = FilmDetails.objects.create( - location="Berlin", - film=f - ) + f = Film.objects.create() + fd = FilmDetails.objects.create(location="Berlin", film=f) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { films { edges { @@ -472,75 +435,63 @@ def test_should_query_node_filtering_with_distinct_queryset(): } } } - ''' + """ - expected = { - 'films': { - 'edges': [{ - 'node': { - 'genre': 'OT' - } - }] - } - } + expected = {"films": {"edges": [{"node": {"genre": "OT"}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected -@pytest.mark.skipif(not DJANGO_FILTER_INSTALLED, - reason="django-filter should be installed") +@pytest.mark.skipif( + not DJANGO_FILTER_INSTALLED, reason="django-filter should be installed" +) def test_should_query_node_multiple_filtering(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', 'headline') + interfaces = (Node,) + filter_fields = ("lang", "headline") class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 3', + headline="Article Node 3", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters { edges { @@ -557,22 +508,20 @@ def test_should_query_node_multiple_filtering(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [{"node": {"id": "QXJ0aWNsZVR5cGU6MQ=="}}] + }, } } - }] + ] } } @@ -585,23 +534,19 @@ def test_should_enforce_first_or_last(): graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = True class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters { edges { @@ -611,17 +556,15 @@ def test_should_enforce_first_or_last(): } } } - ''' + """ - expected = { - 'allReporters': None - } + expected = {"allReporters": None} result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'You must provide a `first` or `last` value to properly ' - 'paginate the `allReporters` connection.' + "You must provide a `first` or `last` value to properly " + "paginate the `allReporters` connection." ) assert result.data == expected @@ -630,23 +573,19 @@ def test_should_error_if_first_is_greater_than_max(): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 100 class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(first: 101) { edges { @@ -656,17 +595,15 @@ def test_should_error_if_first_is_greater_than_max(): } } } - ''' + """ - expected = { - 'allReporters': None - } + expected = {"allReporters": None} result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Requesting 101 records on the `allReporters` connection ' - 'exceeds the `first` limit of 100 records.' + "Requesting 101 records on the `allReporters` connection " + "exceeds the `first` limit of 100 records." ) assert result.data == expected @@ -677,23 +614,19 @@ def test_should_error_if_last_is_greater_than_max(): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 100 class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(last: 101) { edges { @@ -703,17 +636,15 @@ def test_should_error_if_last_is_greater_than_max(): } } } - ''' + """ - expected = { - 'allReporters': None - } + expected = {"allReporters": None} result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Requesting 101 records on the `allReporters` connection ' - 'exceeds the `last` limit of 100 records.' + "Requesting 101 records on the `allReporters` connection " + "exceeds the `last` limit of 100 records." ) assert result.data == expected @@ -724,10 +655,9 @@ def test_should_query_promise_connectionfields(): from promise import Promise class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -736,7 +666,7 @@ def test_should_query_promise_connectionfields(): return Promise.resolve([Reporter(id=1)]) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterPromiseConnectionQuery { allReporters(first: 1) { edges { @@ -746,36 +676,25 @@ def test_should_query_promise_connectionfields(): } } } - ''' + """ - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] - } - } + expected = {"allReporters": {"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected + def test_should_query_connectionfields_with_last(): r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -784,7 +703,7 @@ def test_should_query_connectionfields_with_last(): return Reporter.objects.all() schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterLastQuery { allReporters(last: 1) { edges { @@ -794,52 +713,38 @@ def test_should_query_connectionfields_with_last(): } } } - ''' + """ - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] - } - } + expected = {"allReporters": {"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected + def test_should_query_connectionfields_with_manager(): r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) r = Reporter.objects.create( - first_name='John', - last_name='NotDoe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="NotDoe", email="johndoe@example.com", a_choice=1 ) class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): - all_reporters = DjangoConnectionField(ReporterType, on='doe_objects') + all_reporters = DjangoConnectionField(ReporterType, on="doe_objects") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterLastQuery { allReporters(first: 2) { edges { @@ -849,17 +754,9 @@ def test_should_query_connectionfields_with_manager(): } } } - ''' + """ - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] - } - } + expected = {"allReporters": {"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}]}} result = schema.execute(query) assert not result.errors @@ -872,24 +769,24 @@ def test_should_query_dataloader_fields(): def article_batch_load_fn(keys): queryset = Article.objects.filter(reporter_id__in=keys) - return Promise.resolve([ - [article for article in queryset if article.reporter_id == id] - for id in keys - ]) + return Promise.resolve( + [ + [article for article in queryset if article.reporter_id == id] + for id in keys + ] + ) article_loader = DataLoader(article_batch_load_fn) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) use_connection = True articles = DjangoConnectionField(ArticleType) @@ -901,31 +798,28 @@ def test_should_query_dataloader_fields(): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.date.today(), pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterPromiseConnectionQuery { allReporters(first: 1) { edges { @@ -942,26 +836,23 @@ def test_should_query_dataloader_fields(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - 'articles': { - 'edges': [{ - 'node': { - 'headline': 'Article Node 1', - } - }, { - 'node': { - 'headline': 'Article Node 2' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [ + {"node": {"headline": "Article Node 1"}}, + {"node": {"headline": "Article Node 2"}}, + ] + }, } } - }] + ] } } @@ -972,7 +863,7 @@ def test_should_query_dataloader_fields(): def test_should_handle_inherited_choices(): class BaseModel(models.Model): - choice_field = models.IntegerField(choices=((0, 'zero'), (1, 'one'))) + choice_field = models.IntegerField(choices=((0, "zero"), (1, "one"))) class ChildModel(BaseModel): class Meta: @@ -991,13 +882,13 @@ def test_should_handle_inherited_choices(): child = graphene.Field(ChildType) schema = graphene.Schema(query=Query) - query = ''' + query = """ query { child { choiceField } } - ''' + """ result = schema.execute(query) assert not result.errors @@ -1007,24 +898,21 @@ def test_proxy_model_support(): This test asserts that we can query for all Reporters, even if some are of a proxy model type at runtime. """ - class ReporterType(DjangoObjectType): + class ReporterType(DjangoObjectType): class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) use_connection = True reporter_1 = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) reporter_2 = CNNReporter.objects.create( - first_name='Some', - last_name='Guy', - email='someguy@cnn.com', + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", a_choice=1, reporter_type=2, # set this guy to be CNN ) @@ -1033,7 +921,7 @@ def test_proxy_model_support(): all_reporters = DjangoConnectionField(ReporterType) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ProxyModelQuery { allReporters { edges { @@ -1043,20 +931,13 @@ def test_proxy_model_support(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - }, - }, - { - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjI=', - }, - } + "allReporters": { + "edges": [ + {"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}, + {"node": {"id": "UmVwb3J0ZXJUeXBlOjI="}}, ] } } @@ -1080,24 +961,21 @@ def test_proxy_model_fails(): represent the type, and it doesn't seem like there is a clear way to enforce this pattern across all projects """ - class CNNReporterType(DjangoObjectType): + class CNNReporterType(DjangoObjectType): class Meta: model = CNNReporter - interfaces = (Node, ) + interfaces = (Node,) use_connection = True reporter_1 = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) reporter_2 = CNNReporter.objects.create( - first_name='Some', - last_name='Guy', - email='someguy@cnn.com', + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", a_choice=1, reporter_type=2, # set this guy to be CNN ) @@ -1106,7 +984,7 @@ def test_proxy_model_fails(): all_reporters = DjangoConnectionField(CNNReporterType) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ProxyModelQuery { allReporters { edges { @@ -1116,20 +994,13 @@ def test_proxy_model_fails(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - }, - }, - { - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjI=', - }, - } + "allReporters": { + "edges": [ + {"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}, + {"node": {"id": "UmVwb3J0ZXJUeXBlOjI="}}, ] } } diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index 904c043..452449b 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -7,48 +7,47 @@ from .models import Reporter def test_should_raise_if_no_model(): with raises(Exception) as excinfo: + class Character1(DjangoObjectType): pass - assert 'valid Django Model' in str(excinfo.value) + + assert "valid Django Model" in str(excinfo.value) def test_should_raise_if_model_is_invalid(): with raises(Exception) as excinfo: - class Character2(DjangoObjectType): + class Character2(DjangoObjectType): class Meta: model = 1 - assert 'valid Django Model' in str(excinfo.value) + + assert "valid Django Model" in str(excinfo.value) def test_should_map_fields_correctly(): class ReporterType2(DjangoObjectType): - class Meta: model = Reporter registry = Registry() + fields = list(ReporterType2._meta.fields.keys()) assert fields[:-2] == [ - 'id', - 'first_name', - 'last_name', - 'email', - 'pets', - 'a_choice', - 'reporter_type' + "id", + "first_name", + "last_name", + "email", + "pets", + "a_choice", + "reporter_type", ] - assert sorted(fields[-2:]) == [ - 'articles', - 'films', - ] + assert sorted(fields[-2:]) == ["articles", "films"] def test_should_map_only_few_fields(): class Reporter2(DjangoObjectType): - class Meta: model = Reporter - only_fields = ('id', 'email') + only_fields = ("id", "email") - assert list(Reporter2._meta.fields.keys()) == ['id', 'email'] + assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 18d8055..8a8643b 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -12,27 +12,30 @@ registry.reset_global_registry() class Reporter(DjangoObjectType): - '''Reporter description''' + """Reporter description""" + class Meta: model = ReporterModel class ArticleConnection(Connection): - '''Article Connection''' + """Article Connection""" + test = String() def resolve_test(): - return 'test' + return "test" class Meta: abstract = True class Article(DjangoObjectType): - '''Article description''' + """Article description""" + class Meta: model = ArticleModel - interfaces = (Node, ) + interfaces = (Node,) connection_class = ArticleConnection @@ -48,7 +51,7 @@ def test_django_interface(): assert issubclass(Node, Node) -@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) +@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1)) def test_django_get_node(get): article = Article.get_node(None, 1) get.assert_called_with(pk=1) @@ -58,18 +61,35 @@ def test_django_get_node(get): def test_django_objecttype_map_correct_fields(): fields = Reporter._meta.fields fields = list(fields.keys()) - assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice', 'reporter_type'] - assert sorted(fields[-2:]) == ['articles', 'films'] + assert fields[:-2] == [ + "id", + "first_name", + "last_name", + "email", + "pets", + "a_choice", + "reporter_type", + ] + assert sorted(fields[-2:]) == ["articles", "films"] def test_django_objecttype_with_node_have_correct_fields(): fields = Article._meta.fields - assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'pub_date_time', 'reporter', 'editor', 'lang', 'importance'] + assert list(fields.keys()) == [ + "id", + "headline", + "pub_date", + "pub_date_time", + "reporter", + "editor", + "lang", + "importance", + ] def test_django_objecttype_with_custom_meta(): class ArticleTypeOptions(DjangoObjectTypeOptions): - '''Article Type Options''' + """Article Type Options""" class ArticleType(DjangoObjectType): class Meta: @@ -77,7 +97,7 @@ def test_django_objecttype_with_custom_meta(): @classmethod def __init_subclass_with_meta__(cls, **options): - options.setdefault('_meta', ArticleTypeOptions(cls)) + options.setdefault("_meta", ArticleTypeOptions(cls)) super(ArticleType, cls).__init_subclass_with_meta__(**options) class Article(ArticleType): @@ -180,6 +200,7 @@ def with_local_registry(func): else: registry.registry = old return retval + return inner @@ -188,11 +209,10 @@ def test_django_objecttype_only_fields(): class Reporter(DjangoObjectType): class Meta: model = ReporterModel - only_fields = ('id', 'email', 'films') - + only_fields = ("id", "email", "films") fields = list(Reporter._meta.fields.keys()) - assert fields == ['id', 'email', 'films'] + assert fields == ["id", "email", "films"] @with_local_registry @@ -200,8 +220,7 @@ def test_django_objecttype_exclude_fields(): class Reporter(DjangoObjectType): class Meta: model = ReporterModel - exclude_fields = ('email') - + exclude_fields = "email" fields = list(Reporter._meta.fields.keys()) - assert 'email' not in fields + assert "email" not in fields diff --git a/graphene_django/tests/test_views.py b/graphene_django/tests/test_views.py index dd3bfb2..db6cc4e 100644 --- a/graphene_django/tests/test_views.py +++ b/graphene_django/tests/test_views.py @@ -8,15 +8,15 @@ except ImportError: from urllib.parse import urlencode -def url_string(string='/graphql', **url_params): +def url_string(string="/graphql", **url_params): if url_params: - string += '?' + urlencode(url_params) + string += "?" + urlencode(url_params) return string def batch_url_string(**url_params): - return url_string('/graphql/batch', **url_params) + return url_string("/graphql/batch", **url_params) def response_json(response): @@ -28,441 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs]) def test_graphiql_is_enabled(client): - response = client.get(url_string(), HTTP_ACCEPT='text/html') + response = client.get(url_string(), HTTP_ACCEPT="text/html") assert response.status_code == 200 - assert response['Content-Type'].split(';')[0] == 'text/html' + assert response["Content-Type"].split(";")[0] == "text/html" + def test_qfactor_graphiql(client): - response = client.get(url_string(query='{test}'), HTTP_ACCEPT='application/json;q=0.8, text/html;q=0.9') + response = client.get( + url_string(query="{test}"), + HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9", + ) assert response.status_code == 200 - assert response['Content-Type'].split(';')[0] == 'text/html' + assert response["Content-Type"].split(";")[0] == "text/html" + def test_qfactor_json(client): - response = client.get(url_string(query='{test}'), HTTP_ACCEPT='text/html;q=0.8, application/json;q=0.9') + response = client.get( + url_string(query="{test}"), + HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9", + ) assert response.status_code == 200 - assert response['Content-Type'].split(';')[0] == 'application/json' - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response["Content-Type"].split(";")[0] == "application/json" + assert response_json(response) == {"data": {"test": "Hello World"}} def test_allows_get_with_query_param(client): - response = client.get(url_string(query='{test}')) + response = client.get(url_string(query="{test}")) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_allows_get_with_variable_values(client): - response = client.get(url_string( - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - )) + response = client.get( + url_string( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ) + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_allows_get_with_operation_name(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query helloYou { test(who: "You"), ...shared } query helloWorld { test(who: "World"), ...shared } query helloDolly { test(who: "Dolly"), ...shared } fragment shared on QueryRoot { shared: test(who: "Everyone") } - ''', - operationName='helloWorld' - )) + """, + operationName="helloWorld", + ) + ) assert response.status_code == 200 assert response_json(response) == { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } + "data": {"test": "Hello World", "shared": "Hello Everyone"} } def test_reports_validation_errors(client): - response = client.get(url_string( - query='{ test, unknownOne, unknownTwo }' - )) + response = client.get(url_string(query="{ test, unknownOne, unknownTwo }")) assert response.status_code == 400 assert response_json(response) == { - 'errors': [ + "errors": [ { - 'message': 'Cannot query field "unknownOne" on type "QueryRoot".', - 'locations': [{'line': 1, 'column': 9}] + "message": 'Cannot query field "unknownOne" on type "QueryRoot".', + "locations": [{"line": 1, "column": 9}], }, { - 'message': 'Cannot query field "unknownTwo" on type "QueryRoot".', - 'locations': [{'line': 1, 'column': 21}] - } + "message": 'Cannot query field "unknownTwo" on type "QueryRoot".', + "locations": [{"line": 1, "column": 21}], + }, ] } def test_errors_when_missing_operation_name(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query TestQuery { test } mutation TestMutation { writeTest { test } } - ''' - )) + """ + ) + ) assert response.status_code == 400 assert response_json(response) == { - 'errors': [ + "errors": [ { - 'message': 'Must provide operation name if query contains multiple operations.' + "message": "Must provide operation name if query contains multiple operations." } ] } def test_errors_when_sending_a_mutation_via_get(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" mutation TestMutation { writeTest { test } } - ''' - )) + """ + ) + ) assert response.status_code == 405 assert response_json(response) == { - 'errors': [ - { - 'message': 'Can only perform a mutation operation from a POST request.' - } + "errors": [ + {"message": "Can only perform a mutation operation from a POST request."} ] } def test_errors_when_selecting_a_mutation_within_a_get(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query TestQuery { test } mutation TestMutation { writeTest { test } } - ''', - operationName='TestMutation' - )) + """, + operationName="TestMutation", + ) + ) assert response.status_code == 405 assert response_json(response) == { - 'errors': [ - { - 'message': 'Can only perform a mutation operation from a POST request.' - } + "errors": [ + {"message": "Can only perform a mutation operation from a POST request."} ] } def test_allows_mutation_to_exist_within_a_get(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query TestQuery { test } mutation TestMutation { writeTest { test } } - ''', - operationName='TestQuery' - )) + """, + operationName="TestQuery", + ) + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_allows_post_with_json_encoding(client): - response = client.post(url_string(), j(query='{test}'), 'application/json') + response = client.post(url_string(), j(query="{test}"), "application/json") assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_batch_allows_post_with_json_encoding(client): - response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json') + response = client.post( + batch_url_string(), jl(id=1, query="{test}"), "application/json" + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': {'test': "Hello World"}, - 'status': 200, - }] + assert response_json(response) == [ + {"id": 1, "data": {"test": "Hello World"}, "status": 200} + ] def test_batch_fails_if_is_empty(client): - response = client.post(batch_url_string(), '[]', 'application/json') + response = client.post(batch_url_string(), "[]", "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Received an empty list in the batch request.'}] + "errors": [{"message": "Received an empty list in the batch request."}] } def test_allows_sending_a_mutation_via_post(client): - response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json') + response = client.post( + url_string(), + j(query="mutation TestMutation { writeTest { test } }"), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'writeTest': {'test': 'Hello World'}} - } + assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}} def test_allows_post_with_url_encoding(client): - response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded') + response = client.post( + url_string(), + urlencode(dict(query="{test}")), + "application/x-www-form-urlencoded", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_supports_post_json_query_with_string_variables(client): - response = client.post(url_string(), j( - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - ), 'application/json') + response = client.post( + url_string(), + j( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } - + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_batch_supports_post_json_query_with_string_variables(client): - response = client.post(batch_url_string(), jl( - id=1, - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - ), 'application/json') + response = client.post( + batch_url_string(), + jl( + id=1, + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': {'test': "Hello Dolly"}, - 'status': 200, - }] + assert response_json(response) == [ + {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} + ] def test_supports_post_json_query_with_json_variables(client): - response = client.post(url_string(), j( - query='query helloWho($who: String){ test(who: $who) }', - variables={'who': "Dolly"} - ), 'application/json') + response = client.post( + url_string(), + j( + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_batch_supports_post_json_query_with_json_variables(client): - response = client.post(batch_url_string(), jl( - id=1, - query='query helloWho($who: String){ test(who: $who) }', - variables={'who': "Dolly"} - ), 'application/json') + response = client.post( + batch_url_string(), + jl( + id=1, + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': {'test': "Hello Dolly"}, - 'status': 200, - }] + assert response_json(response) == [ + {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} + ] def test_supports_post_url_encoded_query_with_string_variables(client): - response = client.post(url_string(), urlencode(dict( - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - )), 'application/x-www-form-urlencoded') + response = client.post( + url_string(), + urlencode( + dict( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ) + ), + "application/x-www-form-urlencoded", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_supports_post_json_quey_with_get_variable_values(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), j( - query='query helloWho($who: String){ test(who: $who) }', - ), 'application/json') + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + j(query="query helloWho($who: String){ test(who: $who) }"), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_post_url_encoded_query_with_get_variable_values(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), urlencode(dict( - query='query helloWho($who: String){ test(who: $who) }', - )), 'application/x-www-form-urlencoded') + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")), + "application/x-www-form-urlencoded", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_supports_post_raw_text_query_with_get_variable_values(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), - 'query helloWho($who: String){ test(who: $who) }', - 'application/graphql' + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + "query helloWho($who: String){ test(who: $who) }", + "application/graphql", + ) + + assert response.status_code == 200 + assert response_json(response) == {"data": {"test": "Hello Dolly"}} + + +def test_allows_post_with_operation_name(client): + response = client.post( + url_string(), + j( + query=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + operationName="helloWorld", + ), + "application/json", ) assert response.status_code == 200 assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } - - -def test_allows_post_with_operation_name(client): - response = client.post(url_string(), j( - query=''' - query helloYou { test(who: "You"), ...shared } - query helloWorld { test(who: "World"), ...shared } - query helloDolly { test(who: "Dolly"), ...shared } - fragment shared on QueryRoot { - shared: test(who: "Everyone") - } - ''', - operationName='helloWorld' - ), 'application/json') - - assert response.status_code == 200 - assert response_json(response) == { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } + "data": {"test": "Hello World", "shared": "Hello Everyone"} } def test_batch_allows_post_with_operation_name(client): - response = client.post(batch_url_string(), jl( - id=1, - query=''' + response = client.post( + batch_url_string(), + jl( + id=1, + query=""" query helloYou { test(who: "You"), ...shared } query helloWorld { test(who: "World"), ...shared } query helloDolly { test(who: "Dolly"), ...shared } fragment shared on QueryRoot { shared: test(who: "Everyone") } - ''', - operationName='helloWorld' - ), 'application/json') + """, + operationName="helloWorld", + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - }, - 'status': 200, - }] + assert response_json(response) == [ + { + "id": 1, + "data": {"test": "Hello World", "shared": "Hello Everyone"}, + "status": 200, + } + ] def test_allows_post_with_get_operation_name(client): - response = client.post(url_string( - operationName='helloWorld' - ), ''' + response = client.post( + url_string(operationName="helloWorld"), + """ query helloYou { test(who: "You"), ...shared } query helloWorld { test(who: "World"), ...shared } query helloDolly { test(who: "Dolly"), ...shared } fragment shared on QueryRoot { shared: test(who: "Everyone") } - ''', - 'application/graphql') + """, + "application/graphql", + ) assert response.status_code == 200 assert response_json(response) == { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } + "data": {"test": "Hello World", "shared": "Hello Everyone"} } -@pytest.mark.urls('graphene_django.tests.urls_inherited') +@pytest.mark.urls("graphene_django.tests.urls_inherited") def test_inherited_class_with_attributes_works(client): - inherited_url = '/graphql/inherited/' + inherited_url = "/graphql/inherited/" # Check schema and pretty attributes work - response = client.post(url_string(inherited_url, query='{test}')) + response = client.post(url_string(inherited_url, query="{test}")) assert response.content.decode() == ( - '{\n' - ' "data": {\n' - ' "test": "Hello World"\n' - ' }\n' - '}' + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" ) # Check graphiql works - response = client.get(url_string(inherited_url), HTTP_ACCEPT='text/html') + response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html") assert response.status_code == 200 -@pytest.mark.urls('graphene_django.tests.urls_pretty') +@pytest.mark.urls("graphene_django.tests.urls_pretty") def test_supports_pretty_printing(client): - response = client.get(url_string(query='{test}')) + response = client.get(url_string(query="{test}")) assert response.content.decode() == ( - '{\n' - ' "data": {\n' - ' "test": "Hello World"\n' - ' }\n' - '}' + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" ) def test_supports_pretty_printing_by_request(client): - response = client.get(url_string(query='{test}', pretty='1')) + response = client.get(url_string(query="{test}", pretty="1")) assert response.content.decode() == ( - '{\n' - ' "data": {\n' - ' "test": "Hello World"\n' - ' }\n' - '}' + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" ) def test_handles_field_errors_caught_by_graphql(client): - response = client.get(url_string(query='{thrower}')) + response = client.get(url_string(query="{thrower}")) assert response.status_code == 200 assert response_json(response) == { - 'data': None, - 'errors': [{ - 'locations': [{'column': 2, 'line': 1}], - 'path': ['thrower'], - 'message': 'Throws!', - }] + "data": None, + "errors": [ + { + "locations": [{"column": 2, "line": 1}], + "path": ["thrower"], + "message": "Throws!", + } + ], } def test_handles_syntax_errors_caught_by_graphql(client): - response = client.get(url_string(query='syntaxerror')) + response = client.get(url_string(query="syntaxerror")) assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'locations': [{'column': 1, 'line': 1}], - 'message': 'Syntax Error GraphQL (1:1) ' - 'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}] + "errors": [ + { + "locations": [{"column": 1, "line": 1}], + "message": "Syntax Error GraphQL (1:1) " + 'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n', + } + ] } @@ -471,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client): assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Must provide query string.'}] + "errors": [{"message": "Must provide query string."}] } def test_handles_not_expected_json_bodies(client): - response = client.post(url_string(), '[]', 'application/json') + response = client.post(url_string(), "[]", "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'The received data is not a valid JSON query.'}] + "errors": [{"message": "The received data is not a valid JSON query."}] } def test_handles_invalid_json_bodies(client): - response = client.post(url_string(), '[oh}', 'application/json') + response = client.post(url_string(), "[oh}", "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'POST body sent invalid JSON.'}] + "errors": [{"message": "POST body sent invalid JSON."}] } @@ -499,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch): monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read) - valid_json = json.dumps(dict(foo='bar')) - response = client.post(url_string(), valid_json, 'application/json') + valid_json = json.dumps(dict(foo="bar")) + response = client.post(url_string(), valid_json, "application/json") assert response.status_code == 400 - assert response_json(response) == { - 'errors': [{'message': 'foo-bar'}] - } + assert response_json(response) == {"errors": [{"message": "foo-bar"}]} def test_handles_incomplete_json_bodies(client): - response = client.post(url_string(), '{"query":', 'application/json') + response = client.post(url_string(), '{"query":', "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'POST body sent invalid JSON.'}] + "errors": [{"message": "POST body sent invalid JSON."}] } def test_handles_plain_post_text(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), - 'query helloWho($who: String){ test(who: $who) }', - 'text/plain' + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + "query helloWho($who: String){ test(who: $who) }", + "text/plain", ) assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Must provide query string.'}] + "errors": [{"message": "Must provide query string."}] } def test_handles_poorly_formed_variables(client): - response = client.get(url_string( - query='query helloWho($who: String){ test(who: $who) }', - variables='who:You' - )) + response = client.get( + url_string( + query="query helloWho($who: String){ test(who: $who) }", variables="who:You" + ) + ) assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Variables are invalid JSON.'}] + "errors": [{"message": "Variables are invalid JSON."}] } def test_handles_unsupported_http_methods(client): - response = client.put(url_string(query='{test}')) + response = client.put(url_string(query="{test}")) assert response.status_code == 405 - assert response['Allow'] == 'GET, POST' + assert response["Allow"] == "GET, POST" assert response_json(response) == { - 'errors': [{'message': 'GraphQL only supports GET and POST requests.'}] + "errors": [{"message": "GraphQL only supports GET and POST requests."}] } def test_passes_request_into_context_request(client): - response = client.get(url_string(query='{request}', q='testing')) + response = client.get(url_string(query="{request}", q="testing")) assert response.status_code == 200 - assert response_json(response) == { - 'data': { - 'request': 'testing' - } - } + assert response_json(response) == {"data": {"request": "testing"}} diff --git a/graphene_django/tests/urls.py b/graphene_django/tests/urls.py index 8597baa..66b3fc4 100644 --- a/graphene_django/tests/urls.py +++ b/graphene_django/tests/urls.py @@ -3,6 +3,6 @@ from django.conf.urls import url from ..views import GraphQLView urlpatterns = [ - url(r'^graphql/batch', GraphQLView.as_view(batch=True)), - url(r'^graphql', GraphQLView.as_view(graphiql=True)), + url(r"^graphql/batch", GraphQLView.as_view(batch=True)), + url(r"^graphql", GraphQLView.as_view(graphiql=True)), ] diff --git a/graphene_django/tests/urls_inherited.py b/graphene_django/tests/urls_inherited.py index 55ec18f..6fa8019 100644 --- a/graphene_django/tests/urls_inherited.py +++ b/graphene_django/tests/urls_inherited.py @@ -3,12 +3,11 @@ from django.conf.urls import url from ..views import GraphQLView from .schema_view import schema + class CustomGraphQLView(GraphQLView): schema = schema graphiql = True pretty = True -urlpatterns = [ - url(r'^graphql/inherited/$', CustomGraphQLView.as_view()), -] +urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())] diff --git a/graphene_django/tests/urls_pretty.py b/graphene_django/tests/urls_pretty.py index dfe4e5b..1133c87 100644 --- a/graphene_django/tests/urls_pretty.py +++ b/graphene_django/tests/urls_pretty.py @@ -3,6 +3,4 @@ from django.conf.urls import url from ..views import GraphQLView from .schema_view import schema -urlpatterns = [ - url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)), -] +urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))] diff --git a/graphene_django/types.py b/graphene_django/types.py index d14e3ab..aa8b5a3 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -8,8 +8,7 @@ from graphene.types.utils import yank_fields_from_attrs from .converter import convert_django_field_with_choices from .registry import Registry, get_global_registry -from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields, - is_valid_django_model) +from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model def construct_fields(model, registry, only_fields, exclude_fields): @@ -21,7 +20,7 @@ def construct_fields(model, registry, only_fields, exclude_fields): # is_already_created = name in options.fields is_excluded = name in exclude_fields # or is_already_created # https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name - is_no_backref = str(name).endswith('+') + is_no_backref = str(name).endswith("+") if is_not_in_only or is_excluded or is_no_backref: # We skip this field if we specify only_fields and is not # in there. Or when we exclude this field in exclude_fields. @@ -43,9 +42,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions): class DjangoObjectType(ObjectType): @classmethod - def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, - only_fields=(), exclude_fields=(), filter_fields=None, connection=None, - connection_class=None, use_connection=None, interfaces=(), _meta=None, **options): + def __init_subclass_with_meta__( + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + filter_fields=None, + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + _meta=None, + **options + ): assert is_valid_django_model(model), ( 'You need to pass a valid Django Model in {}.Meta, received "{}".' ).format(cls.__name__, model) @@ -54,7 +65,7 @@ class DjangoObjectType(ObjectType): registry = get_global_registry() assert isinstance(registry, Registry), ( - 'The attribute registry in {} needs to be an instance of ' + "The attribute registry in {} needs to be an instance of " 'Registry, received "{}".' ).format(cls.__name__, registry) @@ -62,12 +73,13 @@ class DjangoObjectType(ObjectType): raise Exception("Can only set filter_fields if Django-Filter is installed") django_fields = yank_fields_from_attrs( - construct_fields(model, registry, only_fields, exclude_fields), - _as=Field, + construct_fields(model, registry, only_fields, exclude_fields), _as=Field ) if use_connection is None and interfaces: - use_connection = any((issubclass(interface, Node) for interface in interfaces)) + use_connection = any( + (issubclass(interface, Node) for interface in interfaces) + ) if use_connection and not connection: # We create the connection automatically @@ -75,7 +87,8 @@ class DjangoObjectType(ObjectType): connection_class = Connection connection = connection_class.create_type( - '{}Connection'.format(cls.__name__), node=cls) + "{}Connection".format(cls.__name__), node=cls + ) if connection is not None: assert issubclass(connection, Connection), ( @@ -91,7 +104,9 @@ class DjangoObjectType(ObjectType): _meta.fields = django_fields _meta.connection = connection - super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options) + super(DjangoObjectType, cls).__init_subclass_with_meta__( + _meta=_meta, interfaces=interfaces, **options + ) if not skip_registry: registry.register(cls) @@ -107,9 +122,7 @@ class DjangoObjectType(ObjectType): if isinstance(root, cls): return True if not is_valid_django_model(type(root)): - raise Exception(( - 'Received incompatible instance "{}".' - ).format(root)) + raise Exception(('Received incompatible instance "{}".').format(root)) model = root._meta.model._meta.concrete_model return model == cls._meta.model diff --git a/graphene_django/utils.py b/graphene_django/utils.py index f8d83bf..560f604 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -13,6 +13,7 @@ class LazyList(object): try: import django_filters # noqa + DJANGO_FILTER_INSTALLED = True except ImportError: DJANGO_FILTER_INSTALLED = False @@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names): continue # Django =>1.9 uses 'rel', django <1.9 uses 'related' - related = getattr(attr, 'rel', None) or \ - getattr(attr, 'related', None) + related = getattr(attr, "rel", None) or getattr(attr, "related", None) if isinstance(related, models.ManyToOneRel): yield (name, related) elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: @@ -42,9 +42,9 @@ def maybe_queryset(value): def get_model_fields(model): local_fields = [ (field.name, field) - for field - in sorted(list(model._meta.fields) + - list(model._meta.local_many_to_many)) + for field in sorted( + list(model._meta.fields) + list(model._meta.local_many_to_many) + ) ] # Make sure we don't duplicate local fields with "reverse" version diff --git a/graphene_django/views.py b/graphene_django/views.py index fad49db..be7ccf9 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -20,7 +20,6 @@ from .settings import graphene_settings class HttpError(Exception): - def __init__(self, response, message=None, *args, **kwargs): self.response = response self.message = message = message or response.content.decode() @@ -29,18 +28,18 @@ class HttpError(Exception): def get_accepted_content_types(request): def qualify(x): - parts = x.split(';', 1) + parts = x.split(";", 1) if len(parts) == 2: - match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)', - parts[1]) + match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1]) if match: return parts[0].strip(), float(match.group(2)) return parts[0].strip(), 1 - raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',') + raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",") qualified_content_types = map(qualify, raw_content_types) - return list(x[0] for x in sorted(qualified_content_types, - key=lambda x: x[1], reverse=True)) + return list( + x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True) + ) def instantiate_middleware(middlewares): @@ -52,8 +51,8 @@ def instantiate_middleware(middlewares): class GraphQLView(View): - graphiql_version = '0.11.10' - graphiql_template = 'graphene/graphiql.html' + graphiql_version = "0.11.10" + graphiql_template = "graphene/graphiql.html" schema = None graphiql = False @@ -64,8 +63,17 @@ class GraphQLView(View): pretty = False batch = False - def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False, - batch=False, backend=None): + def __init__( + self, + schema=None, + executor=None, + middleware=None, + root_value=None, + graphiql=False, + pretty=False, + batch=False, + backend=None, + ): if not schema: schema = graphene_settings.SCHEMA @@ -86,9 +94,9 @@ class GraphQLView(View): self.backend = backend assert isinstance( - self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' - assert not all((graphiql, batch) - ), 'Use either graphiql or batch processing' + self.schema, GraphQLSchema + ), "A Schema is required to be provided to GraphQLView." + assert not all((graphiql, batch)), "Use either graphiql or batch processing" # noinspection PyUnusedLocal def get_root_value(self, request): @@ -106,59 +114,59 @@ class GraphQLView(View): @method_decorator(ensure_csrf_cookie) def dispatch(self, request, *args, **kwargs): try: - if request.method.lower() not in ('get', 'post'): - raise HttpError(HttpResponseNotAllowed( - ['GET', 'POST'], 'GraphQL only supports GET and POST requests.')) + if request.method.lower() not in ("get", "post"): + raise HttpError( + HttpResponseNotAllowed( + ["GET", "POST"], "GraphQL only supports GET and POST requests." + ) + ) data = self.parse_body(request) - show_graphiql = self.graphiql and self.can_display_graphiql( - request, data) + show_graphiql = self.graphiql and self.can_display_graphiql(request, data) if self.batch: responses = [self.get_response(request, entry) for entry in data] - result = '[{}]'.format(','.join([response[0] for response in responses])) - status_code = responses and max(responses, key=lambda response: response[1])[1] or 200 + result = "[{}]".format( + ",".join([response[0] for response in responses]) + ) + status_code = ( + responses + and max(responses, key=lambda response: response[1])[1] + or 200 + ) else: - result, status_code = self.get_response( - request, data, show_graphiql) + result, status_code = self.get_response(request, data, show_graphiql) if show_graphiql: query, variables, operation_name, id = self.get_graphql_params( - request, data) + request, data + ) return self.render_graphiql( request, graphiql_version=self.graphiql_version, - query=query or '', - variables=json.dumps(variables) or '', - operation_name=operation_name or '', - result=result or '' + query=query or "", + variables=json.dumps(variables) or "", + operation_name=operation_name or "", + result=result or "", ) return HttpResponse( - status=status_code, - content=result, - content_type='application/json' + status=status_code, content=result, content_type="application/json" ) except HttpError as e: response = e.response - response['Content-Type'] = 'application/json' - response.content = self.json_encode(request, { - 'errors': [self.format_error(e)] - }) + response["Content-Type"] = "application/json" + response.content = self.json_encode( + request, {"errors": [self.format_error(e)]} + ) return response def get_response(self, request, data, show_graphiql=False): - query, variables, operation_name, id = self.get_graphql_params( - request, data) + query, variables, operation_name, id = self.get_graphql_params(request, data) execution_result = self.execute_graphql_request( - request, - data, - query, - variables, - operation_name, - show_graphiql + request, data, query, variables, operation_name, show_graphiql ) status_code = 200 @@ -166,17 +174,18 @@ class GraphQLView(View): response = {} if execution_result.errors: - response['errors'] = [self.format_error( - e) for e in execution_result.errors] + response["errors"] = [ + self.format_error(e) for e in execution_result.errors + ] if execution_result.invalid: status_code = 400 else: - response['data'] = execution_result.data + response["data"] = execution_result.data if self.batch: - response['id'] = id - response['status'] = status_code + response["id"] = id + response["status"] = status_code result = self.json_encode(request, response, pretty=show_graphiql) else: @@ -188,22 +197,21 @@ class GraphQLView(View): return render(request, self.graphiql_template, data) def json_encode(self, request, d, pretty=False): - if not (self.pretty or pretty) and not request.GET.get('pretty'): - return json.dumps(d, separators=(',', ':')) + if not (self.pretty or pretty) and not request.GET.get("pretty"): + return json.dumps(d, separators=(",", ":")) - return json.dumps(d, sort_keys=True, - indent=2, separators=(',', ': ')) + return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) def parse_body(self, request): content_type = self.get_content_type(request) - if content_type == 'application/graphql': - return {'query': request.body.decode()} + if content_type == "application/graphql": + return {"query": request.body.decode()} - elif content_type == 'application/json': + elif content_type == "application/json": # noinspection PyBroadException try: - body = request.body.decode('utf-8') + body = request.body.decode("utf-8") except Exception as e: raise HttpError(HttpResponseBadRequest(str(e))) @@ -211,33 +219,36 @@ class GraphQLView(View): request_json = json.loads(body) if self.batch: assert isinstance(request_json, list), ( - 'Batch requests should receive a list, but received {}.' + "Batch requests should receive a list, but received {}." ).format(repr(request_json)) - assert len(request_json) > 0, ( - 'Received an empty list in the batch request.' - ) + assert ( + len(request_json) > 0 + ), "Received an empty list in the batch request." else: - assert isinstance(request_json, dict), ( - 'The received data is not a valid JSON query.' - ) + assert isinstance( + request_json, dict + ), "The received data is not a valid JSON query." return request_json except AssertionError as e: raise HttpError(HttpResponseBadRequest(str(e))) except (TypeError, ValueError): - raise HttpError(HttpResponseBadRequest( - 'POST body sent invalid JSON.')) + raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON.")) - elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']: + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: return request.POST return {} - def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False): + def execute_graphql_request( + self, request, data, query, variables, operation_name, show_graphiql=False + ): if not query: if show_graphiql: return None - raise HttpError(HttpResponseBadRequest( - 'Must provide query string.')) + raise HttpError(HttpResponseBadRequest("Must provide query string.")) try: backend = self.get_backend(request) @@ -245,23 +256,27 @@ class GraphQLView(View): except Exception as e: return ExecutionResult(errors=[e], invalid=True) - if request.method.lower() == 'get': + if request.method.lower() == "get": operation_type = document.get_operation_type(operation_name) - if operation_type and operation_type != 'query': + if operation_type and operation_type != "query": if show_graphiql: return None - raise HttpError(HttpResponseNotAllowed( - ['POST'], 'Can only perform a {} operation from a POST request.'.format( - operation_type) - )) + raise HttpError( + HttpResponseNotAllowed( + ["POST"], + "Can only perform a {} operation from a POST request.".format( + operation_type + ), + ) + ) try: extra_options = {} if self.executor: # We only include it optionally since # executor is not a valid argument in all backends - extra_options['executor'] = self.executor + extra_options["executor"] = self.executor return document.execute( root=self.get_root_value(request), @@ -276,7 +291,7 @@ class GraphQLView(View): @classmethod def can_display_graphiql(cls, request, data): - raw = 'raw' in request.GET or 'raw' in data + raw = "raw" in request.GET or "raw" in data return not raw and cls.request_wants_html(request) @classmethod @@ -285,26 +300,32 @@ class GraphQLView(View): accepted_length = len(accepted) # the list will be ordered in preferred first - so we have to make # sure the most preferred gets the highest number - html_priority = accepted_length - accepted.index('text/html') if 'text/html' in accepted else 0 - json_priority = accepted_length - accepted.index('application/json') if 'application/json' in accepted else 0 + html_priority = ( + accepted_length - accepted.index("text/html") + if "text/html" in accepted + else 0 + ) + json_priority = ( + accepted_length - accepted.index("application/json") + if "application/json" in accepted + else 0 + ) return html_priority > json_priority @staticmethod def get_graphql_params(request, data): - query = request.GET.get('query') or data.get('query') - variables = request.GET.get('variables') or data.get('variables') - id = request.GET.get('id') or data.get('id') + query = request.GET.get("query") or data.get("query") + variables = request.GET.get("variables") or data.get("variables") + id = request.GET.get("id") or data.get("id") if variables and isinstance(variables, six.text_type): try: variables = json.loads(variables) except Exception: - raise HttpError(HttpResponseBadRequest( - 'Variables are invalid JSON.')) + raise HttpError(HttpResponseBadRequest("Variables are invalid JSON.")) - operation_name = request.GET.get( - 'operationName') or data.get('operationName') + operation_name = request.GET.get("operationName") or data.get("operationName") if operation_name == "null": operation_name = None @@ -315,11 +336,10 @@ class GraphQLView(View): if isinstance(error, GraphQLError): return format_graphql_error(error) - return {'message': six.text_type(error)} + return {"message": six.text_type(error)} @staticmethod def get_content_type(request): meta = request.META - content_type = meta.get( - 'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) - return content_type.split(';', 1)[0].lower() + content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) + return content_type.split(";", 1)[0].lower()