mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-04-14 06:04:23 +03:00
Delete graphene_django directory
This commit is contained in:
parent
e7f7d8da07
commit
00e27d6a66
|
@ -1,11 +0,0 @@
|
|||
from .fields import DjangoConnectionField, DjangoListField
|
||||
from .types import DjangoObjectType
|
||||
|
||||
__version__ = "3.0.0b7"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"DjangoObjectType",
|
||||
"DjangoListField",
|
||||
"DjangoConnectionField",
|
||||
]
|
|
@ -1,25 +0,0 @@
|
|||
class MissingType(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
# Postgres fields are only available in Django with psycopg2 installed
|
||||
# and we cannot have psycopg2 on PyPy
|
||||
from django.contrib.postgres.fields import (
|
||||
IntegerRangeField,
|
||||
ArrayField,
|
||||
HStoreField,
|
||||
JSONField as PGJSONField,
|
||||
RangeField,
|
||||
)
|
||||
except ImportError:
|
||||
IntegerRangeField, ArrayField, HStoreField, PGJSONField, RangeField = (
|
||||
MissingType,
|
||||
) * 5
|
||||
|
||||
try:
|
||||
# JSONField is only available from Django 3.1
|
||||
from django.db.models import JSONField
|
||||
except ImportError:
|
||||
JSONField = MissingType
|
|
@ -1,18 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from graphene_django.settings import graphene_settings as gsettings
|
||||
|
||||
from .registry import reset_global_registry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_registry_fixture(db):
|
||||
yield None
|
||||
reset_global_registry()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def graphene_settings():
|
||||
settings = dict(gsettings.__dict__)
|
||||
yield gsettings
|
||||
gsettings.__dict__ = settings
|
|
@ -1 +0,0 @@
|
|||
MUTATION_ERRORS_FLAG = "graphene_mutation_has_errors"
|
|
@ -1,356 +0,0 @@
|
|||
from collections import OrderedDict
|
||||
from functools import singledispatch, wraps
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import Promise
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from graphene import (
|
||||
ID,
|
||||
UUID,
|
||||
Boolean,
|
||||
Date,
|
||||
DateTime,
|
||||
Dynamic,
|
||||
Enum,
|
||||
Field,
|
||||
Float,
|
||||
Int,
|
||||
List,
|
||||
NonNull,
|
||||
String,
|
||||
Time,
|
||||
Decimal,
|
||||
)
|
||||
from graphene.types.json import JSONString
|
||||
from graphene.utils.str_converters import to_camel_case
|
||||
from graphql import GraphQLError, assert_valid_name
|
||||
from graphql.pyutils import register_description
|
||||
|
||||
from .compat import ArrayField, HStoreField, JSONField, PGJSONField, RangeField
|
||||
from .fields import DjangoListField, DjangoConnectionField
|
||||
from .settings import graphene_settings
|
||||
from .utils.str_converters import to_const
|
||||
|
||||
|
||||
class BlankValueField(Field):
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
resolver = self.resolver or parent_resolver
|
||||
|
||||
# create custom resolver
|
||||
def blank_field_wrapper(func):
|
||||
@wraps(func)
|
||||
def wrapped_resolver(*args, **kwargs):
|
||||
return_value = func(*args, **kwargs)
|
||||
if return_value == "":
|
||||
return None
|
||||
return return_value
|
||||
|
||||
return wrapped_resolver
|
||||
|
||||
return blank_field_wrapper(resolver)
|
||||
|
||||
|
||||
def convert_choice_name(name):
|
||||
name = to_const(force_str(name))
|
||||
try:
|
||||
assert_valid_name(name)
|
||||
except GraphQLError:
|
||||
name = "A_%s" % name
|
||||
return name
|
||||
|
||||
|
||||
def get_choices(choices):
|
||||
converted_names = []
|
||||
if isinstance(choices, OrderedDict):
|
||||
choices = choices.items()
|
||||
for value, help_text in choices:
|
||||
if isinstance(help_text, (tuple, list)):
|
||||
for choice in get_choices(help_text):
|
||||
yield choice
|
||||
else:
|
||||
name = convert_choice_name(value)
|
||||
while name in converted_names:
|
||||
name += "_" + str(len(converted_names))
|
||||
converted_names.append(name)
|
||||
description = str(
|
||||
help_text
|
||||
) # TODO: translatable description: https://github.com/graphql-python/graphql-core-next/issues/58
|
||||
yield name, value, description
|
||||
|
||||
|
||||
def convert_choices_to_named_enum_with_descriptions(name, choices):
|
||||
choices = list(get_choices(choices))
|
||||
named_choices = [(c[0], c[1]) for c in choices]
|
||||
named_choices_descriptions = {c[0]: c[2] for c in choices}
|
||||
|
||||
class EnumWithDescriptionsType(object):
|
||||
@property
|
||||
def description(self):
|
||||
return str(named_choices_descriptions[self.name])
|
||||
|
||||
return_type = Enum(name, list(named_choices), type=EnumWithDescriptionsType)
|
||||
return return_type
|
||||
|
||||
|
||||
def generate_enum_name(django_model_meta, field):
|
||||
if graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME:
|
||||
# Try and import custom function
|
||||
custom_func = import_string(
|
||||
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME
|
||||
)
|
||||
name = custom_func(field)
|
||||
elif graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING is True:
|
||||
name = to_camel_case("{}_{}".format(django_model_meta.object_name, field.name))
|
||||
else:
|
||||
name = "{app_label}{object_name}{field_name}Choices".format(
|
||||
app_label=to_camel_case(django_model_meta.app_label.title()),
|
||||
object_name=django_model_meta.object_name,
|
||||
field_name=to_camel_case(field.name.title()),
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def convert_choice_field_to_enum(field, name=None):
|
||||
if name is None:
|
||||
name = generate_enum_name(field.model._meta, field)
|
||||
choices = field.choices
|
||||
return convert_choices_to_named_enum_with_descriptions(name, choices)
|
||||
|
||||
|
||||
def convert_django_field_with_choices(
|
||||
field, registry=None, convert_choices_to_enum=True
|
||||
):
|
||||
if registry is not None:
|
||||
converted = registry.get_converted_field(field)
|
||||
if converted:
|
||||
return converted
|
||||
choices = getattr(field, "choices", None)
|
||||
if choices and convert_choices_to_enum:
|
||||
EnumCls = convert_choice_field_to_enum(field)
|
||||
required = not (field.blank or field.null)
|
||||
|
||||
converted = EnumCls(
|
||||
description=get_django_field_description(field), required=required
|
||||
).mount_as(BlankValueField)
|
||||
else:
|
||||
converted = convert_django_field(field, registry)
|
||||
if registry is not None:
|
||||
registry.register_converted_field(field, converted)
|
||||
return converted
|
||||
|
||||
|
||||
def get_django_field_description(field):
|
||||
return str(field.help_text) if field.help_text else None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def convert_django_field(field, registry=None):
|
||||
raise Exception(
|
||||
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.CharField)
|
||||
@convert_django_field.register(models.TextField)
|
||||
@convert_django_field.register(models.EmailField)
|
||||
@convert_django_field.register(models.SlugField)
|
||||
@convert_django_field.register(models.URLField)
|
||||
@convert_django_field.register(models.GenericIPAddressField)
|
||||
@convert_django_field.register(models.FileField)
|
||||
@convert_django_field.register(models.FilePathField)
|
||||
def convert_field_to_string(field, registry=None):
|
||||
return String(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.BigAutoField)
|
||||
@convert_django_field.register(models.AutoField)
|
||||
def convert_field_to_id(field, registry=None):
|
||||
return ID(description=get_django_field_description(field), required=not field.null)
|
||||
|
||||
|
||||
if hasattr(models, "SmallAutoField"):
|
||||
|
||||
@convert_django_field.register(models.SmallAutoField)
|
||||
def convert_field_small_to_id(field, registry=None):
|
||||
return convert_field_to_id(field, registry)
|
||||
|
||||
|
||||
@convert_django_field.register(models.UUIDField)
|
||||
def convert_field_to_uuid(field, registry=None):
|
||||
return UUID(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.PositiveIntegerField)
|
||||
@convert_django_field.register(models.PositiveSmallIntegerField)
|
||||
@convert_django_field.register(models.SmallIntegerField)
|
||||
@convert_django_field.register(models.BigIntegerField)
|
||||
@convert_django_field.register(models.IntegerField)
|
||||
def convert_field_to_int(field, registry=None):
|
||||
return Int(description=get_django_field_description(field), required=not field.null)
|
||||
|
||||
|
||||
@convert_django_field.register(models.NullBooleanField)
|
||||
@convert_django_field.register(models.BooleanField)
|
||||
def convert_field_to_boolean(field, registry=None):
|
||||
return Boolean(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.DecimalField)
|
||||
def convert_field_to_decimal(field, registry=None):
|
||||
return Decimal(description=field.help_text, required=not field.null)
|
||||
|
||||
|
||||
@convert_django_field.register(models.FloatField)
|
||||
@convert_django_field.register(models.DurationField)
|
||||
def convert_field_to_float(field, registry=None):
|
||||
return Float(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.DateTimeField)
|
||||
def convert_datetime_to_string(field, registry=None):
|
||||
return DateTime(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.DateField)
|
||||
def convert_date_to_string(field, registry=None):
|
||||
return Date(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.TimeField)
|
||||
def convert_time_to_string(field, registry=None):
|
||||
return Time(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.OneToOneRel)
|
||||
def convert_onetoone_field_to_djangomodel(field, registry=None):
|
||||
model = field.related_model
|
||||
|
||||
def dynamic_type():
|
||||
_type = registry.get_type_for_model(model)
|
||||
if not _type:
|
||||
return
|
||||
|
||||
return Field(_type, required=not field.null)
|
||||
|
||||
return Dynamic(dynamic_type)
|
||||
|
||||
|
||||
@convert_django_field.register(models.ManyToManyField)
|
||||
@convert_django_field.register(models.ManyToManyRel)
|
||||
@convert_django_field.register(models.ManyToOneRel)
|
||||
def convert_field_to_list_or_connection(field, registry=None):
|
||||
model = field.related_model
|
||||
|
||||
def dynamic_type():
|
||||
_type = registry.get_type_for_model(model)
|
||||
if not _type:
|
||||
return
|
||||
|
||||
if isinstance(field, models.ManyToManyField):
|
||||
description = get_django_field_description(field)
|
||||
else:
|
||||
description = get_django_field_description(field.field)
|
||||
|
||||
# If there is a connection, we should transform the field
|
||||
# into a DjangoConnectionField
|
||||
if _type._meta.connection:
|
||||
# Use a DjangoFilterConnectionField if there are
|
||||
# defined filter_fields or a filterset_class in the
|
||||
# DjangoObjectType Meta
|
||||
if _type._meta.filter_fields or _type._meta.filterset_class:
|
||||
from .filter.fields import DjangoFilterConnectionField
|
||||
|
||||
return DjangoFilterConnectionField(
|
||||
_type, required=True, description=description
|
||||
)
|
||||
|
||||
return DjangoConnectionField(_type, required=True, description=description)
|
||||
|
||||
return DjangoListField(
|
||||
_type,
|
||||
required=True, # A Set is always returned, never None.
|
||||
description=description,
|
||||
)
|
||||
|
||||
return Dynamic(dynamic_type)
|
||||
|
||||
|
||||
@convert_django_field.register(models.OneToOneField)
|
||||
@convert_django_field.register(models.ForeignKey)
|
||||
def convert_field_to_djangomodel(field, registry=None):
|
||||
model = field.related_model
|
||||
|
||||
def dynamic_type():
|
||||
_type = registry.get_type_for_model(model)
|
||||
if not _type:
|
||||
return
|
||||
|
||||
return Field(
|
||||
_type,
|
||||
description=get_django_field_description(field),
|
||||
required=not field.null,
|
||||
)
|
||||
|
||||
return Dynamic(dynamic_type)
|
||||
|
||||
|
||||
@convert_django_field.register(ArrayField)
|
||||
def convert_postgres_array_to_list(field, registry=None):
|
||||
inner_type = convert_django_field(field.base_field)
|
||||
if not isinstance(inner_type, (List, NonNull)):
|
||||
inner_type = (
|
||||
NonNull(type(inner_type))
|
||||
if inner_type.kwargs["required"]
|
||||
else type(inner_type)
|
||||
)
|
||||
return List(
|
||||
inner_type,
|
||||
description=get_django_field_description(field),
|
||||
required=not field.null,
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(HStoreField)
|
||||
@convert_django_field.register(PGJSONField)
|
||||
@convert_django_field.register(JSONField)
|
||||
def convert_pg_and_json_field_to_string(field, registry=None):
|
||||
return JSONString(
|
||||
description=get_django_field_description(field), required=not field.null
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(RangeField)
|
||||
def convert_postgres_range_to_string(field, registry=None):
|
||||
inner_type = convert_django_field(field.base_field)
|
||||
if not isinstance(inner_type, (List, NonNull)):
|
||||
inner_type = (
|
||||
NonNull(type(inner_type))
|
||||
if inner_type.kwargs["required"]
|
||||
else type(inner_type)
|
||||
)
|
||||
return List(
|
||||
inner_type,
|
||||
description=get_django_field_description(field),
|
||||
required=not field.null,
|
||||
)
|
||||
|
||||
|
||||
# Register Django lazy()-wrapped values as GraphQL description/help_text.
|
||||
# This is needed for using lazy translations, see https://github.com/graphql-python/graphql-core-next/issues/58.
|
||||
register_description(Promise)
|
|
@ -1,4 +0,0 @@
|
|||
from .middleware import DjangoDebugMiddleware
|
||||
from .types import DjangoDebug
|
||||
|
||||
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]
|
|
@ -1,17 +0,0 @@
|
|||
import traceback
|
||||
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from .types import DjangoDebugException
|
||||
|
||||
|
||||
def wrap_exception(exception):
|
||||
return DjangoDebugException(
|
||||
message=force_str(exception),
|
||||
exc_type=force_str(type(exception)),
|
||||
stack="".join(
|
||||
traceback.format_exception(
|
||||
etype=type(exception), value=exception, tb=exception.__traceback__
|
||||
)
|
||||
),
|
||||
)
|
|
@ -1,10 +0,0 @@
|
|||
from graphene import ObjectType, String
|
||||
|
||||
|
||||
class DjangoDebugException(ObjectType):
|
||||
class Meta:
|
||||
description = "Represents a single exception raised."
|
||||
|
||||
exc_type = String(required=True, description="The class of the exception")
|
||||
message = String(required=True, description="The message of the exception")
|
||||
stack = String(required=True, description="The stack trace")
|
|
@ -1,71 +0,0 @@
|
|||
from django.db import connections
|
||||
|
||||
from promise import Promise
|
||||
|
||||
from .sql.tracking import unwrap_cursor, wrap_cursor
|
||||
from .exception.formating import wrap_exception
|
||||
from .types import DjangoDebug
|
||||
|
||||
|
||||
class DjangoDebugContext(object):
|
||||
def __init__(self):
|
||||
self.debug_promise = None
|
||||
self.promises = []
|
||||
self.object = DjangoDebug(sql=[], exceptions=[])
|
||||
self.enable_instrumentation()
|
||||
|
||||
def get_debug_promise(self):
|
||||
if not self.debug_promise:
|
||||
self.debug_promise = Promise.all(self.promises)
|
||||
self.promises = []
|
||||
return self.debug_promise.then(self.on_resolve_all_promises).get()
|
||||
|
||||
def on_resolve_error(self, value):
|
||||
if hasattr(self, "object"):
|
||||
self.object.exceptions.append(wrap_exception(value))
|
||||
return Promise.reject(value)
|
||||
|
||||
def on_resolve_all_promises(self, values):
|
||||
if self.promises:
|
||||
self.debug_promise = None
|
||||
return self.get_debug_promise()
|
||||
self.disable_instrumentation()
|
||||
return self.object
|
||||
|
||||
def add_promise(self, promise):
|
||||
if self.debug_promise:
|
||||
self.promises.append(promise)
|
||||
|
||||
def enable_instrumentation(self):
|
||||
# This is thread-safe because database connections are thread-local.
|
||||
for connection in connections.all():
|
||||
wrap_cursor(connection, self)
|
||||
|
||||
def disable_instrumentation(self):
|
||||
for connection in connections.all():
|
||||
unwrap_cursor(connection)
|
||||
|
||||
|
||||
class DjangoDebugMiddleware(object):
|
||||
def resolve(self, next, root, info, **args):
|
||||
context = info.context
|
||||
django_debug = getattr(context, "django_debug", None)
|
||||
if not django_debug:
|
||||
if context is None:
|
||||
raise Exception("DjangoDebug cannot be executed in None contexts")
|
||||
try:
|
||||
context.django_debug = DjangoDebugContext()
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"DjangoDebug need the context to be writable, context received: {}.".format(
|
||||
context.__class__.__name__
|
||||
)
|
||||
)
|
||||
if info.schema.get_type("DjangoDebug") == info.return_type:
|
||||
return context.django_debug.get_debug_promise()
|
||||
try:
|
||||
promise = next(root, info, **args)
|
||||
except Exception as e:
|
||||
return context.django_debug.on_resolve_error(e)
|
||||
context.django_debug.add_promise(promise)
|
||||
return promise
|
|
@ -1,169 +0,0 @@
|
|||
# Code obtained from django-debug-toolbar sql panel tracking
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
|
||||
import json
|
||||
from threading import local
|
||||
from time import time
|
||||
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from .types import DjangoDebugSQL
|
||||
|
||||
|
||||
class SQLQueryTriggered(Exception):
|
||||
"""Thrown when template panel triggers a query"""
|
||||
|
||||
|
||||
class ThreadLocalState(local):
|
||||
def __init__(self):
|
||||
self.enabled = True
|
||||
|
||||
@property
|
||||
def Wrapper(self):
|
||||
if self.enabled:
|
||||
return NormalCursorWrapper
|
||||
return ExceptionCursorWrapper
|
||||
|
||||
def recording(self, v):
|
||||
self.enabled = v
|
||||
|
||||
|
||||
state = ThreadLocalState()
|
||||
recording = state.recording # export function
|
||||
|
||||
|
||||
def wrap_cursor(connection, panel):
|
||||
if not hasattr(connection, "_graphene_cursor"):
|
||||
connection._graphene_cursor = connection.cursor
|
||||
|
||||
def cursor():
|
||||
return state.Wrapper(connection._graphene_cursor(), connection, panel)
|
||||
|
||||
connection.cursor = cursor
|
||||
return cursor
|
||||
|
||||
|
||||
def unwrap_cursor(connection):
|
||||
if hasattr(connection, "_graphene_cursor"):
|
||||
previous_cursor = connection._graphene_cursor
|
||||
connection.cursor = previous_cursor
|
||||
del connection._graphene_cursor
|
||||
|
||||
|
||||
class ExceptionCursorWrapper(object):
|
||||
"""
|
||||
Wraps a cursor and raises an exception on any operation.
|
||||
Used in Templates panel.
|
||||
"""
|
||||
|
||||
def __init__(self, cursor, db, logger):
|
||||
pass
|
||||
|
||||
def __getattr__(self, attr):
|
||||
raise SQLQueryTriggered()
|
||||
|
||||
|
||||
class NormalCursorWrapper(object):
|
||||
"""
|
||||
Wraps a cursor and logs queries.
|
||||
"""
|
||||
|
||||
def __init__(self, cursor, db, logger):
|
||||
self.cursor = cursor
|
||||
# Instance of a BaseDatabaseWrapper subclass
|
||||
self.db = db
|
||||
# logger must implement a ``record`` method
|
||||
self.logger = logger
|
||||
|
||||
def _quote_expr(self, element):
|
||||
if isinstance(element, str):
|
||||
return "'%s'" % force_str(element).replace("'", "''")
|
||||
else:
|
||||
return repr(element)
|
||||
|
||||
def _quote_params(self, params):
|
||||
if not params:
|
||||
return params
|
||||
if isinstance(params, dict):
|
||||
return dict((key, self._quote_expr(value)) for key, value in params.items())
|
||||
return list(map(self._quote_expr, params))
|
||||
|
||||
def _decode(self, param):
|
||||
try:
|
||||
return force_str(param, strings_only=True)
|
||||
except UnicodeDecodeError:
|
||||
return "(encoded string)"
|
||||
|
||||
def _record(self, method, sql, params):
|
||||
start_time = time()
|
||||
try:
|
||||
return method(sql, params)
|
||||
finally:
|
||||
stop_time = time()
|
||||
duration = stop_time - start_time
|
||||
_params = ""
|
||||
try:
|
||||
_params = json.dumps(list(map(self._decode, params)))
|
||||
except Exception:
|
||||
pass # object not JSON serializable
|
||||
|
||||
alias = getattr(self.db, "alias", "default")
|
||||
conn = self.db.connection
|
||||
vendor = getattr(conn, "vendor", "unknown")
|
||||
|
||||
params = {
|
||||
"vendor": vendor,
|
||||
"alias": alias,
|
||||
"sql": self.db.ops.last_executed_query(
|
||||
self.cursor, sql, self._quote_params(params)
|
||||
),
|
||||
"duration": duration,
|
||||
"raw_sql": sql,
|
||||
"params": _params,
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"is_slow": duration > 10,
|
||||
"is_select": sql.lower().strip().startswith("select"),
|
||||
}
|
||||
|
||||
if vendor == "postgresql":
|
||||
# If an erroneous query was ran on the connection, it might
|
||||
# be in a state where checking isolation_level raises an
|
||||
# exception.
|
||||
try:
|
||||
iso_level = conn.isolation_level
|
||||
except conn.InternalError:
|
||||
iso_level = "unknown"
|
||||
params.update(
|
||||
{
|
||||
"trans_id": self.logger.get_transaction_id(alias),
|
||||
"trans_status": conn.get_transaction_status(),
|
||||
"iso_level": iso_level,
|
||||
"encoding": conn.encoding,
|
||||
}
|
||||
)
|
||||
|
||||
_sql = DjangoDebugSQL(**params)
|
||||
# We keep `sql` to maintain backwards compatibility
|
||||
self.logger.object.sql.append(_sql)
|
||||
|
||||
def callproc(self, procname, params=None):
|
||||
return self._record(self.cursor.callproc, procname, params)
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
return self._record(self.cursor.execute, sql, params)
|
||||
|
||||
def executemany(self, sql, param_list):
|
||||
return self._record(self.cursor.executemany, sql, param_list)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.cursor, attr)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
|
@ -1,41 +0,0 @@
|
|||
from graphene import Boolean, Float, ObjectType, String
|
||||
|
||||
|
||||
class DjangoDebugSQL(ObjectType):
|
||||
class Meta:
|
||||
description = "Represents a single database query made to a Django managed DB."
|
||||
|
||||
vendor = String(
|
||||
required=True,
|
||||
description=(
|
||||
"The type of database being used (e.g. postrgesql, mysql, sqlite)."
|
||||
),
|
||||
)
|
||||
alias = String(
|
||||
required=True, description="The Django database alias (e.g. 'default')."
|
||||
)
|
||||
sql = String(description="The actual SQL sent to this database.")
|
||||
duration = Float(
|
||||
required=True, description="Duration of this database query in seconds."
|
||||
)
|
||||
raw_sql = String(
|
||||
required=True, description="The raw SQL of this query, without params."
|
||||
)
|
||||
params = String(
|
||||
required=True, description="JSON encoded database query parameters."
|
||||
)
|
||||
start_time = Float(required=True, description="Start time of this database query.")
|
||||
stop_time = Float(required=True, description="Stop time of this database query.")
|
||||
is_slow = Boolean(
|
||||
required=True,
|
||||
description="Whether this database query took more than 10 seconds.",
|
||||
)
|
||||
is_select = Boolean(
|
||||
required=True, description="Whether this database query was a SELECT."
|
||||
)
|
||||
|
||||
# Postgres
|
||||
trans_id = String(description="Postgres transaction ID if available.")
|
||||
trans_status = String(description="Postgres transaction status if available.")
|
||||
iso_level = String(description="Postgres isolation level if available.")
|
||||
encoding = String(description="Postgres connection encoding if available.")
|
|
@ -1,313 +0,0 @@
|
|||
import graphene
|
||||
import pytest
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoConnectionField, DjangoObjectType
|
||||
|
||||
from ...tests.models import Reporter
|
||||
from ..middleware import DjangoDebugMiddleware
|
||||
from ..types import DjangoDebug
|
||||
|
||||
|
||||
class context(object):
|
||||
pass
|
||||
|
||||
|
||||
def test_should_query_field():
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name="_debug")
|
||||
|
||||
def resolve_reporter(self, info, **args):
|
||||
return Reporter.objects.first()
|
||||
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
reporter {
|
||||
lastName
|
||||
}
|
||||
_debug {
|
||||
sql {
|
||||
rawSql
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected = {
|
||||
"reporter": {"lastName": "ABA"},
|
||||
"_debug": {"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]},
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_limit", [None, 100])
|
||||
def test_should_query_nested_field(graphene_settings, max_limit):
|
||||
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
|
||||
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
r2.pets.add(r1)
|
||||
r1.pets.add(r2)
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name="_debug")
|
||||
|
||||
def resolve_reporter(self, info, **args):
|
||||
return Reporter.objects.first()
|
||||
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
reporter {
|
||||
lastName
|
||||
pets { edges { node {
|
||||
lastName
|
||||
pets { edges { node { lastName } } }
|
||||
} } }
|
||||
}
|
||||
_debug {
|
||||
sql {
|
||||
rawSql
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected = {
|
||||
"reporter": {
|
||||
"lastName": "ABA",
|
||||
"pets": {
|
||||
"edges": [
|
||||
{
|
||||
"node": {
|
||||
"lastName": "Griffin",
|
||||
"pets": {"edges": [{"node": {"lastName": "ABA"}}]},
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
query = str(Reporter.objects.order_by("pk")[:1].query)
|
||||
assert result.data["_debug"]["sql"][0]["rawSql"] == query
|
||||
assert "COUNT" in result.data["_debug"]["sql"][1]["rawSql"]
|
||||
assert "tests_reporter_pets" in result.data["_debug"]["sql"][2]["rawSql"]
|
||||
assert "COUNT" in result.data["_debug"]["sql"][3]["rawSql"]
|
||||
assert "tests_reporter_pets" in result.data["_debug"]["sql"][4]["rawSql"]
|
||||
assert len(result.data["_debug"]["sql"]) == 5
|
||||
|
||||
assert result.data["reporter"] == expected["reporter"]
|
||||
|
||||
|
||||
def test_should_query_list():
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = graphene.List(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name="_debug")
|
||||
|
||||
def resolve_all_reporters(self, info, **args):
|
||||
return Reporter.objects.all()
|
||||
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
allReporters {
|
||||
lastName
|
||||
}
|
||||
_debug {
|
||||
sql {
|
||||
rawSql
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected = {
|
||||
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
|
||||
"_debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_limit", [None, 100])
|
||||
def test_should_query_connection(graphene_settings, max_limit):
|
||||
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
|
||||
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoConnectionField(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name="_debug")
|
||||
|
||||
def resolve_all_reporters(self, info, **args):
|
||||
return Reporter.objects.all()
|
||||
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
allReporters(first:1) {
|
||||
edges {
|
||||
node {
|
||||
lastName
|
||||
}
|
||||
}
|
||||
}
|
||||
_debug {
|
||||
sql {
|
||||
rawSql
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["allReporters"] == expected["allReporters"]
|
||||
assert len(result.data["_debug"]["sql"]) == 2
|
||||
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
|
||||
query = str(Reporter.objects.all()[:1].query)
|
||||
assert result.data["_debug"]["sql"][1]["rawSql"] == query
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_limit", [None, 100])
|
||||
def test_should_query_connectionfilter(graphene_settings, max_limit):
|
||||
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
|
||||
|
||||
from ...filter import DjangoFilterConnectionField
|
||||
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
|
||||
s = graphene.String(resolver=lambda *_: "S")
|
||||
debug = graphene.Field(DjangoDebug, name="_debug")
|
||||
|
||||
def resolve_all_reporters(self, info, **args):
|
||||
return Reporter.objects.all()
|
||||
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
allReporters(first:1) {
|
||||
edges {
|
||||
node {
|
||||
lastName
|
||||
}
|
||||
}
|
||||
}
|
||||
_debug {
|
||||
sql {
|
||||
rawSql
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["allReporters"] == expected["allReporters"]
|
||||
assert len(result.data["_debug"]["sql"]) == 2
|
||||
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
|
||||
query = str(Reporter.objects.all()[:1].query)
|
||||
assert result.data["_debug"]["sql"][1]["rawSql"] == query
|
||||
|
||||
|
||||
def test_should_query_stack_trace():
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name="_debug")
|
||||
|
||||
def resolve_reporter(self, info, **args):
|
||||
raise Exception("caught stack trace")
|
||||
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
reporter {
|
||||
lastName
|
||||
}
|
||||
_debug {
|
||||
exceptions {
|
||||
message
|
||||
stack
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert result.errors
|
||||
assert len(result.data["_debug"]["exceptions"])
|
||||
debug_exception = result.data["_debug"]["exceptions"][0]
|
||||
assert debug_exception["stack"].count("\n") > 1
|
||||
assert "test_query.py" in debug_exception["stack"]
|
||||
assert debug_exception["message"] == "caught stack trace"
|
|
@ -1,14 +0,0 @@
|
|||
from graphene import List, ObjectType
|
||||
|
||||
from .sql.types import DjangoDebugSQL
|
||||
from .exception.types import DjangoDebugException
|
||||
|
||||
|
||||
class DjangoDebug(ObjectType):
|
||||
class Meta:
|
||||
description = "Debugging information for the current query."
|
||||
|
||||
sql = List(DjangoDebugSQL, description="Executed SQL queries for this API query.")
|
||||
exceptions = List(
|
||||
DjangoDebugException, description="Raise exceptions for this API query."
|
||||
)
|
|
@ -1,249 +0,0 @@
|
|||
from functools import partial
|
||||
|
||||
from django.db.models.query import QuerySet
|
||||
from graphql_relay.connection.arrayconnection import (
|
||||
connection_from_array_slice,
|
||||
cursor_to_offset,
|
||||
get_offset_with_default,
|
||||
offset_to_cursor,
|
||||
)
|
||||
from promise import Promise
|
||||
|
||||
from graphene import Int, NonNull
|
||||
from graphene.relay import ConnectionField
|
||||
from graphene.relay.connection import connection_adapter, page_info_adapter
|
||||
from graphene.types import Field, List
|
||||
|
||||
from .settings import graphene_settings
|
||||
from .utils import maybe_queryset
|
||||
|
||||
|
||||
class DjangoListField(Field):
|
||||
def __init__(self, _type, *args, **kwargs):
|
||||
from .types import DjangoObjectType
|
||||
|
||||
if isinstance(_type, NonNull):
|
||||
_type = _type.of_type
|
||||
|
||||
# Django would never return a Set of None vvvvvvv
|
||||
super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs)
|
||||
|
||||
assert issubclass(
|
||||
self._underlying_type, DjangoObjectType
|
||||
), "DjangoListField only accepts DjangoObjectType types"
|
||||
|
||||
@property
|
||||
def _underlying_type(self):
|
||||
_type = self._type
|
||||
while hasattr(_type, "of_type"):
|
||||
_type = _type.of_type
|
||||
return _type
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._underlying_type._meta.model
|
||||
|
||||
def get_manager(self):
|
||||
return self.model._default_manager
|
||||
|
||||
@staticmethod
|
||||
def list_resolver(
|
||||
django_object_type, resolver, default_manager, root, info, **args
|
||||
):
|
||||
queryset = maybe_queryset(resolver(root, info, **args))
|
||||
if queryset is None:
|
||||
queryset = maybe_queryset(default_manager)
|
||||
|
||||
if isinstance(queryset, QuerySet):
|
||||
# Pass queryset to the DjangoObjectType get_queryset method
|
||||
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
|
||||
|
||||
return queryset
|
||||
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
resolver = super(DjangoListField, self).wrap_resolve(parent_resolver)
|
||||
_type = self.type
|
||||
if isinstance(_type, NonNull):
|
||||
_type = _type.of_type
|
||||
django_object_type = _type.of_type.of_type
|
||||
return partial(
|
||||
self.list_resolver, django_object_type, resolver, self.get_manager(),
|
||||
)
|
||||
|
||||
|
||||
class DjangoConnectionField(ConnectionField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.on = kwargs.pop("on", False)
|
||||
self.max_limit = kwargs.pop(
|
||||
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
|
||||
)
|
||||
self.enforce_first_or_last = kwargs.pop(
|
||||
"enforce_first_or_last",
|
||||
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
|
||||
)
|
||||
kwargs.setdefault("offset", Int())
|
||||
super(DjangoConnectionField, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
from .types import DjangoObjectType
|
||||
|
||||
_type = super(ConnectionField, self).type
|
||||
non_null = False
|
||||
if isinstance(_type, NonNull):
|
||||
_type = _type.of_type
|
||||
non_null = True
|
||||
assert issubclass(
|
||||
_type, DjangoObjectType
|
||||
), "DjangoConnectionField only accepts DjangoObjectType types"
|
||||
assert _type._meta.connection, "The type {} doesn't have a connection".format(
|
||||
_type.__name__
|
||||
)
|
||||
connection_type = _type._meta.connection
|
||||
if non_null:
|
||||
return NonNull(connection_type)
|
||||
return connection_type
|
||||
|
||||
@property
|
||||
def connection_type(self):
|
||||
type = self.type
|
||||
if isinstance(type, NonNull):
|
||||
return type.of_type
|
||||
return type
|
||||
|
||||
@property
|
||||
def node_type(self):
|
||||
return self.connection_type._meta.node
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.node_type._meta.model
|
||||
|
||||
def get_manager(self):
|
||||
if self.on:
|
||||
return getattr(self.model, self.on)
|
||||
else:
|
||||
return self.model._default_manager
|
||||
|
||||
@classmethod
|
||||
def resolve_queryset(cls, connection, queryset, info, args):
|
||||
# queryset is the resolved iterable from ObjectType
|
||||
return connection._meta.node.get_queryset(queryset, info)
|
||||
|
||||
@classmethod
|
||||
def resolve_connection(cls, connection, args, iterable, max_limit=None):
|
||||
# Remove the offset parameter and convert it to an after cursor.
|
||||
offset = args.pop("offset", None)
|
||||
after = args.get("after")
|
||||
if offset:
|
||||
if after:
|
||||
offset += cursor_to_offset(after) + 1
|
||||
# input offset starts at 1 while the graphene offset starts at 0
|
||||
args["after"] = offset_to_cursor(offset - 1)
|
||||
|
||||
iterable = maybe_queryset(iterable)
|
||||
|
||||
if isinstance(iterable, QuerySet):
|
||||
list_length = iterable.count()
|
||||
else:
|
||||
list_length = len(iterable)
|
||||
list_slice_length = (
|
||||
min(max_limit, list_length) if max_limit is not None else list_length
|
||||
)
|
||||
|
||||
# If after is higher than list_length, connection_from_list_slice
|
||||
# would try to do a negative slicing which makes django throw an
|
||||
# AssertionError
|
||||
after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)
|
||||
|
||||
if max_limit is not None and args.get("first", None) is None:
|
||||
if args.get("last", None) is not None:
|
||||
after = list_length - args["last"]
|
||||
else:
|
||||
args["first"] = max_limit
|
||||
|
||||
connection = connection_from_array_slice(
|
||||
iterable[after:],
|
||||
args,
|
||||
slice_start=after,
|
||||
array_length=list_length,
|
||||
array_slice_length=list_slice_length,
|
||||
connection_type=partial(connection_adapter, connection),
|
||||
edge_type=connection.Edge,
|
||||
page_info_type=page_info_adapter,
|
||||
)
|
||||
connection.iterable = iterable
|
||||
connection.length = list_length
|
||||
return connection
|
||||
|
||||
@classmethod
|
||||
def connection_resolver(
|
||||
cls,
|
||||
resolver,
|
||||
connection,
|
||||
default_manager,
|
||||
queryset_resolver,
|
||||
max_limit,
|
||||
enforce_first_or_last,
|
||||
root,
|
||||
info,
|
||||
**args
|
||||
):
|
||||
first = args.get("first")
|
||||
last = args.get("last")
|
||||
offset = args.get("offset")
|
||||
before = args.get("before")
|
||||
|
||||
if enforce_first_or_last:
|
||||
assert first or last, (
|
||||
"You must provide a `first` or `last` value to properly paginate the `{}` connection."
|
||||
).format(info.field_name)
|
||||
|
||||
if max_limit:
|
||||
if first:
|
||||
assert first <= max_limit, (
|
||||
"Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
|
||||
).format(first, info.field_name, max_limit)
|
||||
args["first"] = min(first, max_limit)
|
||||
|
||||
if last:
|
||||
assert last <= max_limit, (
|
||||
"Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
|
||||
).format(last, info.field_name, max_limit)
|
||||
args["last"] = min(last, max_limit)
|
||||
|
||||
if offset is not None:
|
||||
assert before is None, (
|
||||
"You can't provide a `before` value at the same time as an `offset` value to properly paginate the `{}` connection."
|
||||
).format(info.field_name)
|
||||
|
||||
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
||||
# or a resolve_foo (does not accept queryset)
|
||||
iterable = resolver(root, info, **args)
|
||||
if iterable is None:
|
||||
iterable = default_manager
|
||||
# thus the iterable gets refiltered by resolve_queryset
|
||||
# but iterable might be promise
|
||||
iterable = queryset_resolver(connection, iterable, info, args)
|
||||
on_resolve = partial(
|
||||
cls.resolve_connection, connection, args, max_limit=max_limit
|
||||
)
|
||||
|
||||
if Promise.is_thenable(iterable):
|
||||
return Promise.resolve(iterable).then(on_resolve)
|
||||
|
||||
return on_resolve(iterable)
|
||||
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
return partial(
|
||||
self.connection_resolver,
|
||||
parent_resolver,
|
||||
self.connection_type,
|
||||
self.get_manager(),
|
||||
self.get_queryset_resolver(),
|
||||
self.max_limit,
|
||||
self.enforce_first_or_last,
|
||||
)
|
||||
|
||||
def get_queryset_resolver(self):
|
||||
return self.resolve_queryset
|
|
@ -1,29 +0,0 @@
|
|||
import warnings
|
||||
from ..utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
if not DJANGO_FILTER_INSTALLED:
|
||||
warnings.warn(
|
||||
"Use of django filtering requires the django-filter package "
|
||||
"be installed. You can do so using `pip install django-filter`",
|
||||
ImportWarning,
|
||||
)
|
||||
else:
|
||||
from .fields import DjangoFilterConnectionField
|
||||
from .filters import (
|
||||
ArrayFilter,
|
||||
GlobalIDFilter,
|
||||
GlobalIDMultipleChoiceFilter,
|
||||
ListFilter,
|
||||
RangeFilter,
|
||||
TypedFilter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DjangoFilterConnectionField",
|
||||
"GlobalIDFilter",
|
||||
"GlobalIDMultipleChoiceFilter",
|
||||
"ArrayFilter",
|
||||
"ListFilter",
|
||||
"RangeFilter",
|
||||
"TypedFilter",
|
||||
]
|
|
@ -1,109 +0,0 @@
|
|||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from graphene.types.enum import EnumType
|
||||
from graphene.types.argument import to_arguments
|
||||
from graphene.utils.str_converters import to_snake_case
|
||||
|
||||
from ..fields import DjangoConnectionField
|
||||
from .utils import get_filtering_args_from_filterset, get_filterset_class
|
||||
|
||||
|
||||
def convert_enum(data):
|
||||
"""
|
||||
Check if the data is a enum option (or potentially nested list of enum option)
|
||||
and convert it to its value.
|
||||
|
||||
This method is used to pre-process the data for the filters as they can take an
|
||||
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
return [convert_enum(item) for item in data]
|
||||
if isinstance(type(data), EnumType):
|
||||
return data.value
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
class DjangoFilterConnectionField(DjangoConnectionField):
|
||||
def __init__(
|
||||
self,
|
||||
type,
|
||||
fields=None,
|
||||
order_by=None,
|
||||
extra_filter_meta=None,
|
||||
filterset_class=None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
self._fields = fields
|
||||
self._provided_filterset_class = filterset_class
|
||||
self._filterset_class = None
|
||||
self._filtering_args = None
|
||||
self._extra_filter_meta = extra_filter_meta
|
||||
self._base_args = None
|
||||
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return to_arguments(self._base_args or OrderedDict(), self.filtering_args)
|
||||
|
||||
@args.setter
|
||||
def args(self, args):
|
||||
self._base_args = args
|
||||
|
||||
@property
|
||||
def filterset_class(self):
|
||||
if not self._filterset_class:
|
||||
fields = self._fields or self.node_type._meta.filter_fields
|
||||
meta = dict(model=self.model, fields=fields)
|
||||
if self._extra_filter_meta:
|
||||
meta.update(self._extra_filter_meta)
|
||||
|
||||
filterset_class = (
|
||||
self._provided_filterset_class or self.node_type._meta.filterset_class
|
||||
)
|
||||
self._filterset_class = get_filterset_class(filterset_class, **meta)
|
||||
|
||||
return self._filterset_class
|
||||
|
||||
@property
|
||||
def filtering_args(self):
|
||||
if not self._filtering_args:
|
||||
self._filtering_args = get_filtering_args_from_filterset(
|
||||
self.filterset_class, self.node_type
|
||||
)
|
||||
return self._filtering_args
|
||||
|
||||
@classmethod
|
||||
def resolve_queryset(
|
||||
cls, connection, iterable, info, args, filtering_args, filterset_class
|
||||
):
|
||||
def filter_kwargs():
|
||||
kwargs = {}
|
||||
for k, v in args.items():
|
||||
if k in filtering_args:
|
||||
if k == "order_by" and v is not None:
|
||||
v = to_snake_case(v)
|
||||
kwargs[k] = convert_enum(v)
|
||||
return kwargs
|
||||
|
||||
qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
|
||||
connection, iterable, info, args
|
||||
)
|
||||
|
||||
filterset = filterset_class(
|
||||
data=filter_kwargs(), queryset=qs, request=info.context
|
||||
)
|
||||
if filterset.is_valid():
|
||||
return filterset.qs
|
||||
raise ValidationError(filterset.form.errors.as_json())
|
||||
|
||||
def get_queryset_resolver(self):
|
||||
return partial(
|
||||
self.resolve_queryset,
|
||||
filterset_class=self.filterset_class,
|
||||
filtering_args=self.filtering_args,
|
||||
)
|
|
@ -1,25 +0,0 @@
|
|||
import warnings
|
||||
from ...utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
if not DJANGO_FILTER_INSTALLED:
|
||||
warnings.warn(
|
||||
"Use of django filtering requires the django-filter package "
|
||||
"be installed. You can do so using `pip install django-filter`",
|
||||
ImportWarning,
|
||||
)
|
||||
else:
|
||||
from .array_filter import ArrayFilter
|
||||
from .global_id_filter import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
from .list_filter import ListFilter
|
||||
from .range_filter import RangeFilter
|
||||
from .typed_filter import TypedFilter
|
||||
|
||||
__all__ = [
|
||||
"DjangoFilterConnectionField",
|
||||
"GlobalIDFilter",
|
||||
"GlobalIDMultipleChoiceFilter",
|
||||
"ArrayFilter",
|
||||
"ListFilter",
|
||||
"RangeFilter",
|
||||
"TypedFilter",
|
||||
]
|
|
@ -1,27 +0,0 @@
|
|||
from django_filters.constants import EMPTY_VALUES
|
||||
|
||||
from .typed_filter import TypedFilter
|
||||
|
||||
|
||||
class ArrayFilter(TypedFilter):
|
||||
"""
|
||||
Filter made for PostgreSQL ArrayField.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first whether the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get the filter applied with
|
||||
an empty list since it's a valid value but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value in EMPTY_VALUES and value != []:
|
||||
return qs
|
||||
if self.distinct:
|
||||
qs = qs.distinct()
|
||||
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
|
||||
qs = self.get_method(qs)(**{lookup: value})
|
||||
return qs
|
|
@ -1,28 +0,0 @@
|
|||
from django_filters import Filter, MultipleChoiceFilter
|
||||
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
class GlobalIDFilter(Filter):
|
||||
"""
|
||||
Filter for Relay global ID.
|
||||
"""
|
||||
|
||||
field_class = GlobalIDFormField
|
||||
|
||||
def filter(self, qs, value):
|
||||
""" Convert the filter value to a primary key before filtering """
|
||||
_id = None
|
||||
if value is not None:
|
||||
_, _id = from_global_id(value)
|
||||
return super(GlobalIDFilter, self).filter(qs, _id)
|
||||
|
||||
|
||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
||||
field_class = GlobalIDMultipleChoiceField
|
||||
|
||||
def filter(self, qs, value):
|
||||
gids = [from_global_id(v)[1] for v in value]
|
||||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
|
@ -1,26 +0,0 @@
|
|||
from .typed_filter import TypedFilter
|
||||
|
||||
|
||||
class ListFilter(TypedFilter):
|
||||
"""
|
||||
Filter that takes a list of value as input.
|
||||
It is for example used for `__in` filters.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first whether the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get an empty output
|
||||
(if not an exclude filter) but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value is not None and len(value) == 0:
|
||||
if self.exclude:
|
||||
return qs
|
||||
else:
|
||||
return qs.none()
|
||||
else:
|
||||
return super(ListFilter, self).filter(qs, value)
|
|
@ -1,24 +0,0 @@
|
|||
from django.core.exceptions import ValidationError
|
||||
from django.forms import Field
|
||||
|
||||
from .typed_filter import TypedFilter
|
||||
|
||||
|
||||
def validate_range(value):
|
||||
"""
|
||||
Validator for range filter input: the list of value must be of length 2.
|
||||
Note that validators are only run if the value is not empty.
|
||||
"""
|
||||
if len(value) != 2:
|
||||
raise ValidationError(
|
||||
"Invalid range specified: it needs to contain 2 values.", code="invalid"
|
||||
)
|
||||
|
||||
|
||||
class RangeField(Field):
|
||||
default_validators = [validate_range]
|
||||
empty_values = [None]
|
||||
|
||||
|
||||
class RangeFilter(TypedFilter):
|
||||
field_class = RangeField
|
|
@ -1,27 +0,0 @@
|
|||
from django_filters import Filter
|
||||
|
||||
from graphene.types.utils import get_type
|
||||
|
||||
|
||||
class TypedFilter(Filter):
|
||||
"""
|
||||
Filter class for which the input GraphQL type can explicitly be provided.
|
||||
If it is not provided, when building the schema, it will try to guess
|
||||
it from the field.
|
||||
"""
|
||||
|
||||
def __init__(self, input_type=None, *args, **kwargs):
|
||||
self._input_type = input_type
|
||||
super(TypedFilter, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def input_type(self):
|
||||
input_type = get_type(self._input_type)
|
||||
if input_type is not None:
|
||||
if not callable(getattr(input_type, "get_type", None)):
|
||||
raise ValueError(
|
||||
"Wrong `input_type` for {}: it only accepts graphene types, got {}".format(
|
||||
self.__class__.__name__, input_type
|
||||
)
|
||||
)
|
||||
return input_type
|
|
@ -1,51 +0,0 @@
|
|||
import itertools
|
||||
|
||||
from django.db import models
|
||||
from django_filters.filterset import BaseFilterSet, FilterSet
|
||||
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
|
||||
|
||||
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
|
||||
|
||||
GRAPHENE_FILTER_SET_OVERRIDES = {
|
||||
models.AutoField: {"filter_class": GlobalIDFilter},
|
||||
models.OneToOneField: {"filter_class": GlobalIDFilter},
|
||||
models.ForeignKey: {"filter_class": GlobalIDFilter},
|
||||
models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
|
||||
models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter},
|
||||
models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter},
|
||||
}
|
||||
|
||||
|
||||
class GrapheneFilterSetMixin(BaseFilterSet):
|
||||
""" A django_filters.filterset.BaseFilterSet with default filter overrides
|
||||
to handle global IDs """
|
||||
|
||||
FILTER_DEFAULTS = dict(
|
||||
itertools.chain(
|
||||
FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def setup_filterset(filterset_class):
|
||||
""" Wrap a provided filterset in Graphene-specific functionality
|
||||
"""
|
||||
return type(
|
||||
"Graphene{}".format(filterset_class.__name__),
|
||||
(filterset_class, GrapheneFilterSetMixin),
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
|
||||
""" Create a filterset for the given model using the provided meta data
|
||||
"""
|
||||
meta.update({"model": model})
|
||||
meta_class = type(str("Meta"), (object,), meta)
|
||||
filterset = type(
|
||||
str("%sFilterSet" % model._meta.object_name),
|
||||
(filterset_base_class, GrapheneFilterSetMixin),
|
||||
{"Meta": meta_class},
|
||||
)
|
||||
return filterset
|
|
@ -1,151 +0,0 @@
|
|||
from mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from django.db import models
|
||||
from django.db.models.query import QuerySet
|
||||
from django_filters import filters
|
||||
from django_filters import FilterSet
|
||||
import graphene
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
from graphene_django.filter import ArrayFilter, ListFilter
|
||||
|
||||
from ...compat import ArrayField
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
STORE = {"events": []}
|
||||
|
||||
|
||||
class Event(models.Model):
|
||||
name = models.CharField(max_length=50)
|
||||
tags = ArrayField(models.CharField(max_length=50))
|
||||
tag_ids = ArrayField(models.IntegerField())
|
||||
random_field = ArrayField(models.BooleanField())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def EventFilterSet():
|
||||
class EventFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Event
|
||||
fields = {
|
||||
"name": ["exact", "contains"],
|
||||
}
|
||||
|
||||
# Those are actually usable with our Query fixture bellow
|
||||
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
|
||||
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
|
||||
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
|
||||
|
||||
# Those are actually not usable and only to check type declarations
|
||||
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
|
||||
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
|
||||
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
|
||||
random_field__contains = ArrayFilter(
|
||||
field_name="random_field", lookup_expr="contains"
|
||||
)
|
||||
random_field__overlap = ArrayFilter(
|
||||
field_name="random_field", lookup_expr="overlap"
|
||||
)
|
||||
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
|
||||
|
||||
return EventFilterSet
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def EventType(EventFilterSet):
|
||||
class EventType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Event
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filterset_class = EventFilterSet
|
||||
|
||||
return EventType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Query(EventType):
|
||||
"""
|
||||
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
|
||||
we are running unit tests in sqlite which does not have ArrayFields.
|
||||
"""
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
events = DjangoFilterConnectionField(EventType)
|
||||
|
||||
def resolve_events(self, info, **kwargs):
|
||||
|
||||
events = [
|
||||
Event(name="Live Show", tags=["concert", "music", "rock"],),
|
||||
Event(name="Musical", tags=["movie", "music"],),
|
||||
Event(name="Ballet", tags=["concert", "dance"],),
|
||||
Event(name="Speech", tags=[],),
|
||||
]
|
||||
|
||||
STORE["events"] = events
|
||||
|
||||
m_queryset = MagicMock(spec=QuerySet)
|
||||
m_queryset.model = Event
|
||||
|
||||
def filter_events(**kwargs):
|
||||
if "tags__contains" in kwargs:
|
||||
STORE["events"] = list(
|
||||
filter(
|
||||
lambda e: set(kwargs["tags__contains"]).issubset(
|
||||
set(e.tags)
|
||||
),
|
||||
STORE["events"],
|
||||
)
|
||||
)
|
||||
if "tags__overlap" in kwargs:
|
||||
STORE["events"] = list(
|
||||
filter(
|
||||
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
|
||||
set(e.tags)
|
||||
),
|
||||
STORE["events"],
|
||||
)
|
||||
)
|
||||
if "tags__exact" in kwargs:
|
||||
STORE["events"] = list(
|
||||
filter(
|
||||
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
|
||||
STORE["events"],
|
||||
)
|
||||
)
|
||||
|
||||
def mock_queryset_filter(*args, **kwargs):
|
||||
filter_events(**kwargs)
|
||||
return m_queryset
|
||||
|
||||
def mock_queryset_none(*args, **kwargs):
|
||||
STORE["events"] = []
|
||||
return m_queryset
|
||||
|
||||
def mock_queryset_count(*args, **kwargs):
|
||||
return len(STORE["events"])
|
||||
|
||||
m_queryset.all.return_value = m_queryset
|
||||
m_queryset.filter.side_effect = mock_queryset_filter
|
||||
m_queryset.none.side_effect = mock_queryset_none
|
||||
m_queryset.count.side_effect = mock_queryset_count
|
||||
m_queryset.__getitem__.side_effect = lambda index: STORE[
|
||||
"events"
|
||||
].__getitem__(index)
|
||||
|
||||
return m_queryset
|
||||
|
||||
return Query
|
|
@ -1,30 +0,0 @@
|
|||
import django_filters
|
||||
from django_filters import OrderingFilter
|
||||
|
||||
from graphene_django.tests.models import Article, Pet, Reporter
|
||||
|
||||
|
||||
class ArticleFilter(django_filters.FilterSet):
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = {
|
||||
"headline": ["exact", "icontains"],
|
||||
"pub_date": ["gt", "lt", "exact"],
|
||||
"reporter": ["exact", "in"],
|
||||
}
|
||||
|
||||
order_by = OrderingFilter(fields=("pub_date",))
|
||||
|
||||
|
||||
class ReporterFilter(django_filters.FilterSet):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = ["first_name", "last_name", "email", "pets"]
|
||||
|
||||
order_by = OrderingFilter(fields=("first_name",))
|
||||
|
||||
|
||||
class PetFilter(django_filters.FilterSet):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = ["name"]
|
|
@ -1,87 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from graphene import Schema
|
||||
|
||||
from ...compat import ArrayField, MissingType
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_contains_multiple(Query):
|
||||
"""
|
||||
Test contains filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags_Contains: ["concert", "music"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Live Show"}},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_contains_one(Query):
|
||||
"""
|
||||
Test contains filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags_Contains: ["music"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Live Show"}},
|
||||
{"node": {"name": "Musical"}},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_contains_empty_list(Query):
|
||||
"""
|
||||
Test contains filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags_Contains: []) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Live Show"}},
|
||||
{"node": {"name": "Musical"}},
|
||||
{"node": {"name": "Ballet"}},
|
||||
{"node": {"name": "Speech"}},
|
||||
]
|
|
@ -1,130 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from graphene import Schema
|
||||
|
||||
from ...compat import ArrayField, MissingType
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_exact_no_match(Query):
|
||||
"""
|
||||
Test exact filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags: ["concert", "music"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_exact_match(Query):
|
||||
"""
|
||||
Test exact filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags: ["movie", "music"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Musical"}},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_exact_empty_list(Query):
|
||||
"""
|
||||
Test exact filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags: []) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Speech"}},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_filter_schema_type(Query):
|
||||
"""
|
||||
Check that the type in the filter is an array field like on the object type.
|
||||
"""
|
||||
schema = Schema(query=Query)
|
||||
schema_str = str(schema)
|
||||
|
||||
assert (
|
||||
'''type EventType implements Node {
|
||||
"""The ID of the object"""
|
||||
id: ID!
|
||||
name: String!
|
||||
tags: [String!]!
|
||||
tagIds: [Int!]!
|
||||
randomField: [Boolean!]!
|
||||
}'''
|
||||
in schema_str
|
||||
)
|
||||
|
||||
filters = {
|
||||
"offset": "Int",
|
||||
"before": "String",
|
||||
"after": "String",
|
||||
"first": "Int",
|
||||
"last": "Int",
|
||||
"name": "String",
|
||||
"name_Contains": "String",
|
||||
"tags_Contains": "[String!]",
|
||||
"tags_Overlap": "[String!]",
|
||||
"tags": "[String!]",
|
||||
"tagsIds_Contains": "[Int!]",
|
||||
"tagsIds_Overlap": "[Int!]",
|
||||
"tagsIds": "[Int!]",
|
||||
"randomField_Contains": "[Boolean!]",
|
||||
"randomField_Overlap": "[Boolean!]",
|
||||
"randomField": "[Boolean!]",
|
||||
}
|
||||
filters_str = ", ".join(
|
||||
[
|
||||
f"{filter_field}: {gql_type} = null"
|
||||
for filter_field, gql_type in filters.items()
|
||||
]
|
||||
)
|
||||
assert (
|
||||
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
|
||||
)
|
|
@ -1,84 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from graphene import Schema
|
||||
|
||||
from ...compat import ArrayField, MissingType
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_overlap_multiple(Query):
|
||||
"""
|
||||
Test overlap filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags_Overlap: ["concert", "music"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Live Show"}},
|
||||
{"node": {"name": "Musical"}},
|
||||
{"node": {"name": "Ballet"}},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_overlap_one(Query):
|
||||
"""
|
||||
Test overlap filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags_Overlap: ["music"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == [
|
||||
{"node": {"name": "Live Show"}},
|
||||
{"node": {"name": "Musical"}},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_array_field_overlap_empty_list(Query):
|
||||
"""
|
||||
Test overlap filter on a array field of string.
|
||||
"""
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
events (tags_Overlap: []) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["events"]["edges"] == []
|
|
@ -1,160 +0,0 @@
|
|||
import pytest
|
||||
|
||||
import graphene
|
||||
from graphene.relay import Node
|
||||
|
||||
from graphene_django import DjangoObjectType, DjangoConnectionField
|
||||
from graphene_django.tests.models import Article, Reporter
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema():
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class ArticleType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filter_fields = {
|
||||
"lang": ["exact", "in"],
|
||||
"reporter__a_choice": ["exact", "in"],
|
||||
}
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoConnectionField(ReporterType)
|
||||
all_articles = DjangoFilterConnectionField(ArticleType)
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
return schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reporter_article_data():
|
||||
john = Reporter.objects.create(
|
||||
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
|
||||
)
|
||||
jane = Reporter.objects.create(
|
||||
first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="Article Node 1", reporter=john, editor=john, lang="es",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="Article Node 2", reporter=john, editor=john, lang="en",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="Article Node 3", reporter=jane, editor=jane, lang="en",
|
||||
)
|
||||
|
||||
|
||||
def test_filter_enum_on_connection(schema, reporter_article_data):
|
||||
"""
|
||||
Check that we can filter with enums on a connection.
|
||||
"""
|
||||
query = """
|
||||
query {
|
||||
allArticles(lang: ES) {
|
||||
edges {
|
||||
node {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
expected = {"allArticles": {"edges": [{"node": {"headline": "Article Node 1"}},]}}
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_filter_on_foreign_key_enum_field(schema, reporter_article_data):
|
||||
"""
|
||||
Check that we can filter with enums on a field from a foreign key.
|
||||
"""
|
||||
query = """
|
||||
query {
|
||||
allArticles(reporter_AChoice: A_1) {
|
||||
edges {
|
||||
node {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
expected = {
|
||||
"allArticles": {
|
||||
"edges": [
|
||||
{"node": {"headline": "Article Node 1"}},
|
||||
{"node": {"headline": "Article Node 2"}},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_filter_enum_field_schema_type(schema):
|
||||
"""
|
||||
Check that the type in the filter is an enum like on the object type.
|
||||
"""
|
||||
schema_str = str(schema)
|
||||
|
||||
assert (
|
||||
'''type ArticleType implements Node {
|
||||
"""The ID of the object"""
|
||||
id: ID!
|
||||
headline: String!
|
||||
pubDate: Date!
|
||||
pubDateTime: DateTime!
|
||||
reporter: ReporterType!
|
||||
editor: ReporterType!
|
||||
|
||||
"""Language"""
|
||||
lang: TestsArticleLangChoices!
|
||||
importance: TestsArticleImportanceChoices
|
||||
}'''
|
||||
in schema_str
|
||||
)
|
||||
|
||||
filters = {
|
||||
"offset": "Int",
|
||||
"before": "String",
|
||||
"after": "String",
|
||||
"first": "Int",
|
||||
"last": "Int",
|
||||
"lang": "TestsArticleLangChoices",
|
||||
"lang_In": "[TestsArticleLangChoices]",
|
||||
"reporter_AChoice": "TestsReporterAChoiceChoices",
|
||||
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
|
||||
}
|
||||
filters_str = ", ".join(
|
||||
[
|
||||
f"{filter_field}: {gql_type} = null"
|
||||
for filter_field, gql_type in filters.items()
|
||||
]
|
||||
)
|
||||
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str
|
File diff suppressed because it is too large
Load Diff
|
@ -1,448 +0,0 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from django_filters import FilterSet
|
||||
from django_filters import rest_framework as filters
|
||||
from graphene import ObjectType, Schema
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.models import Pet, Person, Reporter, Article, Film
|
||||
from graphene_django.filter.tests.filters import ArticleFilter
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def query():
|
||||
class PetNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filter_fields = {
|
||||
"id": ["exact", "in"],
|
||||
"name": ["exact", "in"],
|
||||
"age": ["exact", "in", "range"],
|
||||
}
|
||||
|
||||
class ReporterNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
# choice filter using enum
|
||||
filter_fields = {"reporter_type": ["exact", "in"]}
|
||||
|
||||
class ArticleNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filterset_class = ArticleFilter
|
||||
|
||||
class FilmNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Film
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
# choice filter not using enum
|
||||
filter_fields = {
|
||||
"genre": ["exact", "in"],
|
||||
}
|
||||
convert_choices_to_enum = False
|
||||
|
||||
class PersonFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Person
|
||||
fields = {"name": ["in"]}
|
||||
|
||||
names = filters.BaseInFilter(method="filter_names")
|
||||
|
||||
def filter_names(self, qs, name, value):
|
||||
"""
|
||||
This custom filter take a string as input with comma separated values.
|
||||
Note that the value here is already a list as it has been transformed by the BaseInFilter class.
|
||||
"""
|
||||
return qs.filter(name__in=value)
|
||||
|
||||
class PersonNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Person
|
||||
interfaces = (Node,)
|
||||
filterset_class = PersonFilterSet
|
||||
fields = "__all__"
|
||||
|
||||
class Query(ObjectType):
|
||||
pets = DjangoFilterConnectionField(PetNode)
|
||||
people = DjangoFilterConnectionField(PersonNode)
|
||||
articles = DjangoFilterConnectionField(ArticleNode)
|
||||
films = DjangoFilterConnectionField(FilmNode)
|
||||
reporters = DjangoFilterConnectionField(ReporterNode)
|
||||
|
||||
return Query
|
||||
|
||||
|
||||
def test_string_in_filter(query):
|
||||
"""
|
||||
Test in filter on a string field.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=3)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (name_In: ["Brutus", "Jojo, the rabbit"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Brutus"}},
|
||||
{"node": {"name": "Jojo, the rabbit"}},
|
||||
]
|
||||
|
||||
|
||||
def test_string_in_filter_with_otjer_filter(query):
|
||||
"""
|
||||
Test in filter on a string field which has also a custom filter doing a similar operation.
|
||||
"""
|
||||
Person.objects.create(name="John")
|
||||
Person.objects.create(name="Michael")
|
||||
Person.objects.create(name="Angela")
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
people (name_In: ["John", "Michael"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["people"]["edges"] == [
|
||||
{"node": {"name": "John"}},
|
||||
{"node": {"name": "Michael"}},
|
||||
]
|
||||
|
||||
|
||||
def test_string_in_filter_with_declared_filter(query):
|
||||
"""
|
||||
Test in filter on a string field with a custom filterset class.
|
||||
"""
|
||||
Person.objects.create(name="John")
|
||||
Person.objects.create(name="Michael")
|
||||
Person.objects.create(name="Angela")
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
people (names: "John,Michael") {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["people"]["edges"] == [
|
||||
{"node": {"name": "John"}},
|
||||
{"node": {"name": "Michael"}},
|
||||
]
|
||||
|
||||
|
||||
def test_int_in_filter(query):
|
||||
"""
|
||||
Test in filter on an integer field.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=3)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (age_In: [3]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Mimi"}},
|
||||
{"node": {"name": "Jojo, the rabbit"}},
|
||||
]
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (age_In: [3, 12]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Brutus"}},
|
||||
{"node": {"name": "Mimi"}},
|
||||
{"node": {"name": "Jojo, the rabbit"}},
|
||||
]
|
||||
|
||||
|
||||
def test_in_filter_with_empty_list(query):
|
||||
"""
|
||||
Check that using a in filter with an empty list provided as input returns no objects.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=8)
|
||||
Pet.objects.create(name="Picotin", age=5)
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (name_In: []) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert len(result.data["pets"]["edges"]) == 0
|
||||
|
||||
|
||||
def test_choice_in_filter_without_enum(query):
|
||||
"""
|
||||
Test in filter o an choice field not using an enum (Film.genre).
|
||||
"""
|
||||
|
||||
john_doe = Reporter.objects.create(
|
||||
first_name="John", last_name="Doe", email="john@doe.com"
|
||||
)
|
||||
jean_bon = Reporter.objects.create(
|
||||
first_name="Jean", last_name="Bon", email="jean@bon.com"
|
||||
)
|
||||
documentary_film = Film.objects.create(genre="do")
|
||||
documentary_film.reporters.add(john_doe)
|
||||
action_film = Film.objects.create(genre="ac")
|
||||
action_film.reporters.add(john_doe)
|
||||
other_film = Film.objects.create(genre="ot")
|
||||
other_film.reporters.add(john_doe)
|
||||
other_film.reporters.add(jean_bon)
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
films (genre_In: ["do", "ac"]) {
|
||||
edges {
|
||||
node {
|
||||
genre
|
||||
reporters {
|
||||
edges {
|
||||
node {
|
||||
lastName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["films"]["edges"] == [
|
||||
{
|
||||
"node": {
|
||||
"genre": "do",
|
||||
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
|
||||
}
|
||||
},
|
||||
{
|
||||
"node": {
|
||||
"genre": "ac",
|
||||
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_fk_id_in_filter(query):
|
||||
"""
|
||||
Test in filter on an foreign key relationship.
|
||||
"""
|
||||
john_doe = Reporter.objects.create(
|
||||
first_name="John", last_name="Doe", email="john@doe.com"
|
||||
)
|
||||
jean_bon = Reporter.objects.create(
|
||||
first_name="Jean", last_name="Bon", email="jean@bon.com"
|
||||
)
|
||||
sara_croche = Reporter.objects.create(
|
||||
first_name="Sara", last_name="Croche", email="sara@croche.com"
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="A",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=john_doe,
|
||||
editor=john_doe,
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="B",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=jean_bon,
|
||||
editor=jean_bon,
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="C",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=sara_croche,
|
||||
editor=sara_croche,
|
||||
)
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
articles (reporter_In: [%s, %s]) {
|
||||
edges {
|
||||
node {
|
||||
headline
|
||||
reporter {
|
||||
lastName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""" % (
|
||||
john_doe.id,
|
||||
jean_bon.id,
|
||||
)
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A", "reporter": {"lastName": "Doe"}}},
|
||||
{"node": {"headline": "B", "reporter": {"lastName": "Bon"}}},
|
||||
]
|
||||
|
||||
|
||||
def test_enum_in_filter(query):
|
||||
"""
|
||||
Test in filter on a choice field using an enum (Reporter.reporter_type).
|
||||
"""
|
||||
|
||||
Reporter.objects.create(
|
||||
first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1
|
||||
)
|
||||
Reporter.objects.create(
|
||||
first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2
|
||||
)
|
||||
Reporter.objects.create(
|
||||
first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2
|
||||
)
|
||||
Reporter.objects.create(
|
||||
first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None
|
||||
)
|
||||
|
||||
schema = Schema(query=query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters (reporterType_In: [A_1]) {
|
||||
edges {
|
||||
node {
|
||||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["reporters"]["edges"] == [
|
||||
{"node": {"email": "john@doe.com"}},
|
||||
]
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters (reporterType_In: [A_2]) {
|
||||
edges {
|
||||
node {
|
||||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["reporters"]["edges"] == [
|
||||
{"node": {"email": "jean@bon.com"}},
|
||||
{"node": {"email": "jane@doe.com"}},
|
||||
]
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters (reporterType_In: [A_2, A_1]) {
|
||||
edges {
|
||||
node {
|
||||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["reporters"]["edges"] == [
|
||||
{"node": {"email": "john@doe.com"}},
|
||||
{"node": {"email": "jean@bon.com"}},
|
||||
{"node": {"email": "jane@doe.com"}},
|
||||
]
|
|
@ -1,115 +0,0 @@
|
|||
import json
|
||||
import pytest
|
||||
|
||||
from django_filters import FilterSet
|
||||
from django_filters import rest_framework as filters
|
||||
from graphene import ObjectType, Schema
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.models import Pet
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PetNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filter_fields = {
|
||||
"name": ["exact", "in"],
|
||||
"age": ["exact", "in", "range"],
|
||||
}
|
||||
|
||||
|
||||
class Query(ObjectType):
|
||||
pets = DjangoFilterConnectionField(PetNode)
|
||||
|
||||
|
||||
def test_int_range_filter():
|
||||
"""
|
||||
Test range filter on an integer field.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=8)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
Pet.objects.create(name="Picotin", age=5)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (age_Range: [4, 9]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Mimi"}},
|
||||
{"node": {"name": "Picotin"}},
|
||||
]
|
||||
|
||||
|
||||
def test_range_filter_with_invalid_input():
|
||||
"""
|
||||
Test range filter used with invalid inputs raise an error.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=8)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
Pet.objects.create(name="Picotin", age=5)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query ($rangeValue: [Int]) {
|
||||
pets (age_Range: $rangeValue) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected_error = json.dumps(
|
||||
{
|
||||
"age__range": [
|
||||
{
|
||||
"message": "Invalid range specified: it needs to contain 2 values.",
|
||||
"code": "invalid",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Empty list
|
||||
result = schema.execute(query, variables={"rangeValue": []})
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == expected_error
|
||||
|
||||
# Only one item in the list
|
||||
result = schema.execute(query, variables={"rangeValue": [1]})
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == expected_error
|
||||
|
||||
# More than 2 items in the list
|
||||
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == expected_error
|
|
@ -1,157 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from django_filters import FilterSet
|
||||
|
||||
import graphene
|
||||
from graphene.relay import Node
|
||||
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.models import Article, Reporter
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import (
|
||||
DjangoFilterConnectionField,
|
||||
TypedFilter,
|
||||
ListFilter,
|
||||
)
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema():
|
||||
class ArticleFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = {
|
||||
"lang": ["exact", "in"],
|
||||
}
|
||||
|
||||
lang__contains = TypedFilter(
|
||||
field_name="lang", lookup_expr="icontains", input_type=graphene.String
|
||||
)
|
||||
lang__in_str = ListFilter(
|
||||
field_name="lang",
|
||||
lookup_expr="in",
|
||||
input_type=graphene.List(graphene.String),
|
||||
)
|
||||
first_n = TypedFilter(input_type=graphene.Int, method="first_n_filter")
|
||||
only_first = TypedFilter(
|
||||
input_type=graphene.Boolean, method="only_first_filter"
|
||||
)
|
||||
|
||||
def first_n_filter(self, queryset, _name, value):
|
||||
return queryset[:value]
|
||||
|
||||
def only_first_filter(self, queryset, _name, value):
|
||||
if value:
|
||||
return queryset[:1]
|
||||
else:
|
||||
return queryset
|
||||
|
||||
class ArticleType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filterset_class = ArticleFilterSet
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
articles = DjangoFilterConnectionField(ArticleType)
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
return schema
|
||||
|
||||
|
||||
def test_typed_filter_schema(schema):
|
||||
"""
|
||||
Check that the type provided in the filter is reflected in the schema.
|
||||
"""
|
||||
|
||||
schema_str = str(schema)
|
||||
|
||||
filters = {
|
||||
"offset": "Int",
|
||||
"before": "String",
|
||||
"after": "String",
|
||||
"first": "Int",
|
||||
"last": "Int",
|
||||
"lang": "TestsArticleLangChoices",
|
||||
"lang_In": "[TestsArticleLangChoices]",
|
||||
"lang_Contains": "String",
|
||||
"lang_InStr": "[String]",
|
||||
"firstN": "Int",
|
||||
"onlyFirst": "Boolean",
|
||||
}
|
||||
|
||||
all_articles_filters = (
|
||||
schema_str.split(" articles(")[1]
|
||||
.split("): ArticleTypeConnection\n")[0]
|
||||
.split(", ")
|
||||
)
|
||||
|
||||
for filter_field, gql_type in filters.items():
|
||||
assert "{}: {} = null".format(filter_field, gql_type) in all_articles_filters
|
||||
|
||||
|
||||
def test_typed_filters_work(schema):
|
||||
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
|
||||
Article.objects.create(
|
||||
headline="A", reporter=reporter, editor=reporter, lang="es",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="B", reporter=reporter, editor=reporter, lang="es",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="C", reporter=reporter, editor=reporter, lang="en",
|
||||
)
|
||||
|
||||
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
{"node": {"headline": "B"}},
|
||||
]
|
||||
|
||||
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
{"node": {"headline": "B"}},
|
||||
]
|
||||
|
||||
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "C"}},
|
||||
]
|
||||
|
||||
query = "query { articles (firstN: 2) { edges { node { headline } } } }"
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
{"node": {"headline": "B"}},
|
||||
]
|
||||
|
||||
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
]
|
|
@ -1,155 +0,0 @@
|
|||
import graphene
|
||||
from django import forms
|
||||
from django_filters.utils import get_model_field, get_field_parts
|
||||
from django_filters.filters import Filter, BaseCSVFilter
|
||||
from .filters import ArrayFilter, ListFilter, RangeFilter, TypedFilter
|
||||
from .filterset import custom_filterset_factory, setup_filterset
|
||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
def get_field_type(registry, model, field_name):
|
||||
"""
|
||||
Try to get a model field corresponding Graphql type from the DjangoObjectType.
|
||||
"""
|
||||
object_type = registry.get_type_for_model(model)
|
||||
if object_type:
|
||||
object_type_field = object_type._meta.fields.get(field_name)
|
||||
if object_type_field:
|
||||
field_type = object_type_field.type
|
||||
if isinstance(field_type, graphene.NonNull):
|
||||
field_type = field_type.of_type
|
||||
return field_type
|
||||
return None
|
||||
|
||||
|
||||
def get_filtering_args_from_filterset(filterset_class, type):
|
||||
"""
|
||||
Inspect a FilterSet and produce the arguments to pass to a Graphene Field.
|
||||
These arguments will be available to filter against in the GraphQL API.
|
||||
"""
|
||||
from ..forms.converter import convert_form_field
|
||||
|
||||
args = {}
|
||||
model = filterset_class._meta.model
|
||||
registry = type._meta.registry
|
||||
for name, filter_field in filterset_class.base_filters.items():
|
||||
filter_type = filter_field.lookup_expr
|
||||
required = filter_field.extra.get("required", False)
|
||||
field_type = None
|
||||
form_field = None
|
||||
|
||||
if (
|
||||
isinstance(filter_field, TypedFilter)
|
||||
and filter_field.input_type is not None
|
||||
):
|
||||
# First check if the filter input type has been explicitely given
|
||||
field_type = filter_field.input_type
|
||||
else:
|
||||
if name not in filterset_class.declared_filters or isinstance(
|
||||
filter_field, TypedFilter
|
||||
):
|
||||
# Get the filter field for filters that are no explicitly declared.
|
||||
if filter_type == "isnull":
|
||||
field = graphene.Boolean(required=required)
|
||||
else:
|
||||
model_field = get_model_field(model, filter_field.field_name)
|
||||
|
||||
# Get the form field either from:
|
||||
# 1. the formfield corresponding to the model field
|
||||
# 2. the field defined on filter
|
||||
if hasattr(model_field, "formfield"):
|
||||
form_field = model_field.formfield(required=required)
|
||||
if not form_field:
|
||||
form_field = filter_field.field
|
||||
|
||||
# First try to get the matching field type from the GraphQL DjangoObjectType
|
||||
if model_field:
|
||||
if (
|
||||
isinstance(form_field, forms.ModelChoiceField)
|
||||
or isinstance(form_field, forms.ModelMultipleChoiceField)
|
||||
or isinstance(form_field, GlobalIDMultipleChoiceField)
|
||||
or isinstance(form_field, GlobalIDFormField)
|
||||
):
|
||||
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
|
||||
field_type = get_field_type(
|
||||
registry, model_field.related_model, "id"
|
||||
)
|
||||
else:
|
||||
field_type = get_field_type(
|
||||
registry, model_field.model, model_field.name
|
||||
)
|
||||
|
||||
if not field_type:
|
||||
# Fallback on converting the form field either because:
|
||||
# - it's an explicitly declared filters
|
||||
# - we did not manage to get the type from the model type
|
||||
form_field = form_field or filter_field.field
|
||||
field_type = convert_form_field(form_field).get_type()
|
||||
|
||||
if isinstance(filter_field, ListFilter) or isinstance(
|
||||
filter_field, RangeFilter
|
||||
):
|
||||
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
|
||||
# the same type as the field. See comments in `replace_csv_filters` method for more details.
|
||||
field_type = graphene.List(field_type)
|
||||
|
||||
args[name] = graphene.Argument(
|
||||
field_type, description=filter_field.label, required=required,
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_filterset_class(filterset_class, **meta):
|
||||
"""
|
||||
Get the class to be used as the FilterSet.
|
||||
"""
|
||||
if filterset_class:
|
||||
# If were given a FilterSet class, then set it up.
|
||||
graphene_filterset_class = setup_filterset(filterset_class)
|
||||
else:
|
||||
# Otherwise create one.
|
||||
graphene_filterset_class = custom_filterset_factory(**meta)
|
||||
|
||||
replace_csv_filters(graphene_filterset_class)
|
||||
return graphene_filterset_class
|
||||
|
||||
|
||||
def replace_csv_filters(filterset_class):
|
||||
"""
|
||||
Replace the "in" and "range" filters (that are not explicitly declared)
|
||||
to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
|
||||
but our custom InFilter/RangeFilter filter class that use the input
|
||||
value as filter argument on the queryset.
|
||||
|
||||
This is because those BaseCSVFilter are expecting a string as input with
|
||||
comma separated values.
|
||||
But with GraphQl we can actually have a list as input and have a proper
|
||||
type verification of each value in the list.
|
||||
|
||||
See issue https://github.com/graphql-python/graphene-django/issues/1068.
|
||||
"""
|
||||
for name, filter_field in list(filterset_class.base_filters.items()):
|
||||
# Do not touch any declared filters
|
||||
if name in filterset_class.declared_filters:
|
||||
continue
|
||||
|
||||
filter_type = filter_field.lookup_expr
|
||||
if filter_type == "in":
|
||||
filterset_class.base_filters[name] = ListFilter(
|
||||
field_name=filter_field.field_name,
|
||||
lookup_expr=filter_field.lookup_expr,
|
||||
label=filter_field.label,
|
||||
method=filter_field.method,
|
||||
exclude=filter_field.exclude,
|
||||
**filter_field.extra
|
||||
)
|
||||
elif filter_type == "range":
|
||||
filterset_class.base_filters[name] = RangeFilter(
|
||||
field_name=filter_field.field_name,
|
||||
lookup_expr=filter_field.lookup_expr,
|
||||
label=filter_field.label,
|
||||
method=filter_field.method,
|
||||
exclude=filter_field.exclude,
|
||||
**filter_field.extra
|
||||
)
|
|
@ -1 +0,0 @@
|
|||
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa
|
|
@ -1,99 +0,0 @@
|
|||
from functools import singledispatch
|
||||
|
||||
from django import forms
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
|
||||
|
||||
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
def get_form_field_description(field):
|
||||
return str(field.help_text) if field.help_text else None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def convert_form_field(field):
|
||||
raise ImproperlyConfigured(
|
||||
"Don't know how to convert the Django form field %s (%s) "
|
||||
"to Graphene type" % (field, field.__class__)
|
||||
)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.fields.BaseTemporalField)
|
||||
@convert_form_field.register(forms.CharField)
|
||||
@convert_form_field.register(forms.EmailField)
|
||||
@convert_form_field.register(forms.SlugField)
|
||||
@convert_form_field.register(forms.URLField)
|
||||
@convert_form_field.register(forms.ChoiceField)
|
||||
@convert_form_field.register(forms.RegexField)
|
||||
@convert_form_field.register(forms.Field)
|
||||
def convert_form_field_to_string(field):
|
||||
return String(
|
||||
description=get_form_field_description(field), required=field.required
|
||||
)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.UUIDField)
|
||||
def convert_form_field_to_uuid(field):
|
||||
return UUID(description=get_form_field_description(field), required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.IntegerField)
|
||||
@convert_form_field.register(forms.NumberInput)
|
||||
def convert_form_field_to_int(field):
|
||||
return Int(description=get_form_field_description(field), required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.BooleanField)
|
||||
def convert_form_field_to_boolean(field):
|
||||
return Boolean(
|
||||
description=get_form_field_description(field), required=field.required
|
||||
)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.NullBooleanField)
|
||||
def convert_form_field_to_nullboolean(field):
|
||||
return Boolean(description=get_form_field_description(field))
|
||||
|
||||
|
||||
@convert_form_field.register(forms.DecimalField)
|
||||
@convert_form_field.register(forms.FloatField)
|
||||
def convert_form_field_to_float(field):
|
||||
return Float(description=get_form_field_description(field), required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.MultipleChoiceField)
|
||||
def convert_form_field_to_string_list(field):
|
||||
return List(
|
||||
String, description=get_form_field_description(field), required=field.required
|
||||
)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.ModelMultipleChoiceField)
|
||||
@convert_form_field.register(GlobalIDMultipleChoiceField)
|
||||
def convert_form_field_to_id_list(field):
|
||||
return List(ID, required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.DateField)
|
||||
def convert_form_field_to_date(field):
|
||||
return Date(description=get_form_field_description(field), required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.DateTimeField)
|
||||
def convert_form_field_to_datetime(field):
|
||||
return DateTime(
|
||||
description=get_form_field_description(field), required=field.required
|
||||
)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.TimeField)
|
||||
def convert_form_field_to_time(field):
|
||||
return Time(description=get_form_field_description(field), required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.ModelChoiceField)
|
||||
@convert_form_field.register(GlobalIDFormField)
|
||||
def convert_form_field_to_id(field):
|
||||
return ID(required=field.required)
|
|
@ -1,40 +0,0 @@
|
|||
import binascii
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.forms import CharField, Field, MultipleChoiceField
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from graphql_relay import from_global_id
|
||||
|
||||
|
||||
class GlobalIDFormField(Field):
|
||||
default_error_messages = {"invalid": _("Invalid ID specified.")}
|
||||
|
||||
def clean(self, value):
|
||||
if not value and not self.required:
|
||||
return None
|
||||
|
||||
try:
|
||||
_type, _id = from_global_id(value)
|
||||
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise ValidationError(self.error_messages["invalid"])
|
||||
|
||||
try:
|
||||
CharField().clean(_id)
|
||||
CharField().clean(_type)
|
||||
except ValidationError:
|
||||
raise ValidationError(self.error_messages["invalid"])
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class GlobalIDMultipleChoiceField(MultipleChoiceField):
|
||||
default_error_messages = {
|
||||
"invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
|
||||
"invalid_list": _("Enter a list of values."),
|
||||
}
|
||||
|
||||
def valid_value(self, value):
|
||||
# Clean will raise a validation error if there is a problem
|
||||
GlobalIDFormField().clean(value)
|
||||
return True
|
|
@ -1,192 +0,0 @@
|
|||
# from django import forms
|
||||
from collections import OrderedDict
|
||||
|
||||
import graphene
|
||||
from graphene import Field, InputField
|
||||
from graphene.relay.mutation import ClientIDMutation
|
||||
from graphene.types.mutation import MutationOptions
|
||||
|
||||
# from graphene.types.inputobjecttype import (
|
||||
# InputObjectTypeOptions,
|
||||
# InputObjectType,
|
||||
# )
|
||||
from graphene.types.utils import yank_fields_from_attrs
|
||||
from graphene_django.constants import MUTATION_ERRORS_FLAG
|
||||
from graphene_django.registry import get_global_registry
|
||||
|
||||
from ..types import ErrorType
|
||||
from .converter import convert_form_field
|
||||
|
||||
|
||||
def fields_for_form(form, only_fields, exclude_fields):
|
||||
fields = OrderedDict()
|
||||
for name, field in form.fields.items():
|
||||
is_not_in_only = only_fields and name not in only_fields
|
||||
is_excluded = (
|
||||
name
|
||||
in exclude_fields # or
|
||||
# name in already_created_fields
|
||||
)
|
||||
|
||||
if is_not_in_only or is_excluded:
|
||||
continue
|
||||
|
||||
fields[name] = convert_form_field(field)
|
||||
return fields
|
||||
|
||||
|
||||
class BaseDjangoFormMutation(ClientIDMutation):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
@classmethod
|
||||
def mutate_and_get_payload(cls, root, info, **input):
|
||||
form = cls.get_form(root, info, **input)
|
||||
|
||||
if form.is_valid():
|
||||
return cls.perform_mutate(form, info)
|
||||
else:
|
||||
errors = ErrorType.from_errors(form.errors)
|
||||
_set_errors_flag_to_context(info)
|
||||
|
||||
return cls(errors=errors, **form.data)
|
||||
|
||||
@classmethod
|
||||
def get_form(cls, root, info, **input):
|
||||
form_kwargs = cls.get_form_kwargs(root, info, **input)
|
||||
return cls._meta.form_class(**form_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_form_kwargs(cls, root, info, **input):
|
||||
kwargs = {"data": input}
|
||||
|
||||
pk = input.pop("id", None)
|
||||
if pk:
|
||||
instance = cls._meta.model._default_manager.get(pk=pk)
|
||||
kwargs["instance"] = instance
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
class DjangoFormMutationOptions(MutationOptions):
|
||||
form_class = None
|
||||
|
||||
|
||||
class DjangoFormMutation(BaseDjangoFormMutation):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
errors = graphene.List(ErrorType)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass_with_meta__(
|
||||
cls, form_class=None, only_fields=(), exclude_fields=(), **options
|
||||
):
|
||||
|
||||
if not form_class:
|
||||
raise Exception("form_class is required for DjangoFormMutation")
|
||||
|
||||
form = form_class()
|
||||
input_fields = fields_for_form(form, only_fields, exclude_fields)
|
||||
output_fields = fields_for_form(form, only_fields, exclude_fields)
|
||||
|
||||
_meta = DjangoFormMutationOptions(cls)
|
||||
_meta.form_class = form_class
|
||||
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
|
||||
|
||||
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
|
||||
super(DjangoFormMutation, cls).__init_subclass_with_meta__(
|
||||
_meta=_meta, input_fields=input_fields, **options
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def perform_mutate(cls, form, info):
|
||||
if hasattr(form, "save"):
|
||||
# `save` method won't exist on plain Django forms, but this mutation can
|
||||
# in theory be used with `ModelForm`s as well and we do want to save them.
|
||||
form.save()
|
||||
return cls(errors=[], **form.cleaned_data)
|
||||
|
||||
|
||||
class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions):
|
||||
model = None
|
||||
return_field_name = None
|
||||
|
||||
|
||||
class DjangoModelFormMutation(BaseDjangoFormMutation):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
errors = graphene.List(ErrorType)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass_with_meta__(
|
||||
cls,
|
||||
form_class=None,
|
||||
model=None,
|
||||
return_field_name=None,
|
||||
only_fields=(),
|
||||
exclude_fields=(),
|
||||
**options
|
||||
):
|
||||
|
||||
if not form_class:
|
||||
raise Exception("form_class is required for DjangoModelFormMutation")
|
||||
|
||||
if not model:
|
||||
model = form_class._meta.model
|
||||
|
||||
if not model:
|
||||
raise Exception("model is required for DjangoModelFormMutation")
|
||||
|
||||
form = form_class()
|
||||
input_fields = fields_for_form(form, only_fields, exclude_fields)
|
||||
if "id" not in exclude_fields:
|
||||
input_fields["id"] = graphene.ID()
|
||||
|
||||
registry = get_global_registry()
|
||||
model_type = registry.get_type_for_model(model)
|
||||
if not model_type:
|
||||
raise Exception("No type registered for model: {}".format(model.__name__))
|
||||
|
||||
if not return_field_name:
|
||||
model_name = model.__name__
|
||||
return_field_name = model_name[:1].lower() + model_name[1:]
|
||||
|
||||
output_fields = OrderedDict()
|
||||
output_fields[return_field_name] = graphene.Field(model_type)
|
||||
|
||||
_meta = DjangoModelDjangoFormMutationOptions(cls)
|
||||
_meta.form_class = form_class
|
||||
_meta.model = model
|
||||
_meta.return_field_name = return_field_name
|
||||
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
|
||||
|
||||
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
|
||||
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
|
||||
_meta=_meta, input_fields=input_fields, **options
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def mutate_and_get_payload(cls, root, info, **input):
|
||||
form = cls.get_form(root, info, **input)
|
||||
|
||||
if form.is_valid():
|
||||
return cls.perform_mutate(form, info)
|
||||
else:
|
||||
errors = ErrorType.from_errors(form.errors)
|
||||
_set_errors_flag_to_context(info)
|
||||
|
||||
return cls(errors=errors)
|
||||
|
||||
@classmethod
|
||||
def perform_mutate(cls, form, info):
|
||||
obj = form.save()
|
||||
kwargs = {cls._meta.return_field_name: obj}
|
||||
return cls(errors=[], **kwargs)
|
||||
|
||||
|
||||
def _set_errors_flag_to_context(info):
|
||||
# This is not ideal but necessary to keep the response errors empty
|
||||
if info and info.context:
|
||||
setattr(info.context, MUTATION_ERRORS_FLAG, True)
|
|
@ -1,121 +0,0 @@
|
|||
from django import forms
|
||||
from py.test import raises
|
||||
|
||||
import graphene
|
||||
from graphene import (
|
||||
String,
|
||||
Int,
|
||||
Boolean,
|
||||
Float,
|
||||
ID,
|
||||
UUID,
|
||||
List,
|
||||
NonNull,
|
||||
DateTime,
|
||||
Date,
|
||||
Time,
|
||||
)
|
||||
|
||||
from ..converter import convert_form_field
|
||||
|
||||
|
||||
def assert_conversion(django_field, graphene_field, *args):
|
||||
field = django_field(*args, help_text="Custom Help Text")
|
||||
graphene_type = convert_form_field(field)
|
||||
assert isinstance(graphene_type, graphene_field)
|
||||
field = graphene_type.Field()
|
||||
assert field.description == "Custom Help Text"
|
||||
return field
|
||||
|
||||
|
||||
def test_should_unknown_django_field_raise_exception():
|
||||
with raises(Exception) as excinfo:
|
||||
convert_form_field(None)
|
||||
assert "Don't know how to convert the Django form field" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_date_convert_date():
|
||||
assert_conversion(forms.DateField, Date)
|
||||
|
||||
|
||||
def test_should_time_convert_time():
|
||||
assert_conversion(forms.TimeField, Time)
|
||||
|
||||
|
||||
def test_should_date_time_convert_date_time():
|
||||
assert_conversion(forms.DateTimeField, DateTime)
|
||||
|
||||
|
||||
def test_should_char_convert_string():
|
||||
assert_conversion(forms.CharField, String)
|
||||
|
||||
|
||||
def test_should_email_convert_string():
|
||||
assert_conversion(forms.EmailField, String)
|
||||
|
||||
|
||||
def test_should_slug_convert_string():
|
||||
assert_conversion(forms.SlugField, String)
|
||||
|
||||
|
||||
def test_should_url_convert_string():
|
||||
assert_conversion(forms.URLField, String)
|
||||
|
||||
|
||||
def test_should_choice_convert_string():
|
||||
assert_conversion(forms.ChoiceField, String)
|
||||
|
||||
|
||||
def test_should_base_field_convert_string():
|
||||
assert_conversion(forms.Field, String)
|
||||
|
||||
|
||||
def test_should_regex_convert_string():
|
||||
assert_conversion(forms.RegexField, String, "[0-9]+")
|
||||
|
||||
|
||||
def test_should_uuid_convert_string():
|
||||
if hasattr(forms, "UUIDField"):
|
||||
assert_conversion(forms.UUIDField, UUID)
|
||||
|
||||
|
||||
def test_should_integer_convert_int():
|
||||
assert_conversion(forms.IntegerField, Int)
|
||||
|
||||
|
||||
def test_should_boolean_convert_boolean():
|
||||
field = assert_conversion(forms.BooleanField, Boolean)
|
||||
assert isinstance(field.type, NonNull)
|
||||
|
||||
|
||||
def test_should_nullboolean_convert_boolean():
|
||||
field = assert_conversion(forms.NullBooleanField, Boolean)
|
||||
assert not isinstance(field.type, NonNull)
|
||||
|
||||
|
||||
def test_should_float_convert_float():
|
||||
assert_conversion(forms.FloatField, Float)
|
||||
|
||||
|
||||
def test_should_decimal_convert_float():
|
||||
assert_conversion(forms.DecimalField, Float)
|
||||
|
||||
|
||||
def test_should_multiple_choice_convert_list():
|
||||
field = forms.MultipleChoiceField()
|
||||
graphene_type = convert_form_field(field)
|
||||
assert isinstance(graphene_type, List)
|
||||
assert graphene_type.of_type == String
|
||||
|
||||
|
||||
def test_should_model_multiple_choice_convert_connectionorlist():
|
||||
field = forms.ModelMultipleChoiceField(queryset=None)
|
||||
graphene_type = convert_form_field(field)
|
||||
assert isinstance(graphene_type, List)
|
||||
assert graphene_type.of_type == ID
|
||||
|
||||
|
||||
def test_should_manytoone_convert_connectionorlist():
|
||||
field = forms.ModelChoiceField(queryset=None)
|
||||
graphene_type = convert_form_field(field)
|
||||
assert isinstance(graphene_type, ID)
|
|
@ -1,385 +0,0 @@
|
|||
import pytest
|
||||
from django import forms
|
||||
from django.core.exceptions import ValidationError
|
||||
from py.test import raises
|
||||
|
||||
from graphene import Field, ObjectType, Schema, String
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.forms import PetForm
|
||||
from graphene_django.tests.models import Pet
|
||||
from graphene_django.tests.mutations import PetMutation
|
||||
|
||||
from ..mutation import DjangoFormMutation, DjangoModelFormMutation
|
||||
|
||||
|
||||
class MyForm(forms.Form):
|
||||
text = forms.CharField()
|
||||
|
||||
def clean_text(self):
|
||||
text = self.cleaned_data["text"]
|
||||
if text == "INVALID_INPUT":
|
||||
raise ValidationError("Invalid input")
|
||||
return text
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_needs_form_class():
|
||||
with raises(Exception) as exc:
|
||||
|
||||
class MyMutation(DjangoFormMutation):
|
||||
pass
|
||||
|
||||
assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
|
||||
|
||||
|
||||
def test_has_output_fields():
|
||||
class MyMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = MyForm
|
||||
|
||||
assert "errors" in MyMutation._meta.fields
|
||||
|
||||
|
||||
def test_has_input_fields():
|
||||
class MyMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = MyForm
|
||||
|
||||
assert "text" in MyMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_mutation_error_camelcased(graphene_settings):
|
||||
class ExtraPetForm(PetForm):
|
||||
test_field = forms.CharField(required=True)
|
||||
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = ExtraPetForm
|
||||
|
||||
result = PetMutation.mutate_and_get_payload(None, None)
|
||||
assert {f.field for f in result.errors} == {"name", "age", "testField"}
|
||||
graphene_settings.CAMELCASE_ERRORS = False
|
||||
result = PetMutation.mutate_and_get_payload(None, None)
|
||||
assert {f.field for f in result.errors} == {"name", "age", "test_field"}
|
||||
|
||||
|
||||
class MockQuery(ObjectType):
|
||||
a = String()
|
||||
|
||||
|
||||
def test_form_invalid_form():
|
||||
class MyMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = MyForm
|
||||
|
||||
class Mutation(ObjectType):
|
||||
my_mutation = MyMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
""" mutation MyMutation {
|
||||
myMutation(input: { text: "INVALID_INPUT" }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
text
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data["myMutation"]["errors"] == [
|
||||
{"field": "text", "messages": ["Invalid input"]}
|
||||
]
|
||||
|
||||
|
||||
def test_form_valid_input():
|
||||
class MyMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = MyForm
|
||||
|
||||
class Mutation(ObjectType):
|
||||
my_mutation = MyMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
""" mutation MyMutation {
|
||||
myMutation(input: { text: "VALID_INPUT" }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
text
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data["myMutation"]["errors"] == []
|
||||
assert result.data["myMutation"]["text"] == "VALID_INPUT"
|
||||
|
||||
|
||||
def test_default_meta_fields():
|
||||
assert PetMutation._meta.model is Pet
|
||||
assert PetMutation._meta.return_field_name == "pet"
|
||||
assert "pet" in PetMutation._meta.fields
|
||||
|
||||
|
||||
def test_default_input_meta_fields():
|
||||
assert PetMutation._meta.model is Pet
|
||||
assert PetMutation._meta.return_field_name == "pet"
|
||||
assert "name" in PetMutation.Input._meta.fields
|
||||
assert "client_mutation_id" in PetMutation.Input._meta.fields
|
||||
assert "id" in PetMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_exclude_fields_input_meta_fields():
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
exclude_fields = ["id"]
|
||||
|
||||
assert PetMutation._meta.model is Pet
|
||||
assert PetMutation._meta.return_field_name == "pet"
|
||||
assert "name" in PetMutation.Input._meta.fields
|
||||
assert "age" in PetMutation.Input._meta.fields
|
||||
assert "client_mutation_id" in PetMutation.Input._meta.fields
|
||||
assert "id" not in PetMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_custom_return_field_name():
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
model = Pet
|
||||
return_field_name = "animal"
|
||||
|
||||
assert PetMutation._meta.model is Pet
|
||||
assert PetMutation._meta.return_field_name == "animal"
|
||||
assert "animal" in PetMutation._meta.fields
|
||||
|
||||
|
||||
def test_model_form_mutation_mutate_existing():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
pet = Pet.objects.create(name="Axel", age=10)
|
||||
|
||||
result = schema.execute(
|
||||
""" mutation PetMutation($pk: ID!) {
|
||||
petMutation(input: { id: $pk, name: "Mia", age: 10 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
}
|
||||
}
|
||||
""",
|
||||
variable_values={"pk": pet.pk},
|
||||
)
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data["petMutation"]["pet"] == {"name": "Mia", "age": 10}
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
pet.refresh_from_db()
|
||||
assert pet.name == "Mia"
|
||||
|
||||
|
||||
def test_model_form_mutation_creates_new():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
""" mutation PetMutation {
|
||||
petMutation(input: { name: "Mia", age: 10 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.data["petMutation"]["pet"] == {"name": "Mia", "age": 10}
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Mia"
|
||||
assert pet.age == 10
|
||||
|
||||
|
||||
def test_model_form_mutation_invalid_input():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
""" mutation PetMutation {
|
||||
petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.data["petMutation"]["pet"] is None
|
||||
assert result.data["petMutation"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
def test_model_form_mutation_mutate_invalid_form():
|
||||
result = PetMutation.mutate_and_get_payload(None, None)
|
||||
|
||||
# A pet was not created
|
||||
Pet.objects.count() == 0
|
||||
|
||||
fields_w_error = [e.field for e in result.errors]
|
||||
assert len(result.errors) == 2
|
||||
assert result.errors[0].messages == ["This field is required."]
|
||||
assert result.errors[1].messages == ["This field is required."]
|
||||
assert "age" in fields_w_error
|
||||
assert "name" in fields_w_error
|
||||
|
||||
|
||||
def test_model_form_mutation_multiple_creation_valid():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
"""
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 10 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.data["petMutation1"]["pet"] == {"name": "Mia", "age": 10}
|
||||
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 2
|
||||
|
||||
pet1 = Pet.objects.first()
|
||||
assert pet1.name == "Mia"
|
||||
assert pet1.age == 10
|
||||
|
||||
pet2 = Pet.objects.last()
|
||||
assert pet2.name == "Enzo"
|
||||
assert pet2.age == 0
|
||||
|
||||
|
||||
def test_model_form_mutation_multiple_creation_invalid():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
"""
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert result.errors is None
|
||||
|
||||
assert result.data["petMutation1"]["pet"] is None
|
||||
assert result.data["petMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Enzo"
|
||||
assert pet.age == 0
|
|
@ -1 +0,0 @@
|
|||
from ..types import ErrorType # noqa Import ErrorType for backwards compatability
|
|
@ -1,115 +0,0 @@
|
|||
import os
|
||||
import importlib
|
||||
import json
|
||||
import functools
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.utils import autoreload
|
||||
|
||||
from graphql import print_schema
|
||||
from graphene_django.settings import graphene_settings
|
||||
|
||||
|
||||
class CommandArguments(BaseCommand):
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"--schema",
|
||||
type=str,
|
||||
dest="schema",
|
||||
default=graphene_settings.SCHEMA,
|
||||
help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out",
|
||||
type=str,
|
||||
dest="out",
|
||||
default=graphene_settings.SCHEMA_OUTPUT,
|
||||
help="Output file, --out=- prints to stdout (default: schema.json)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--indent",
|
||||
type=int,
|
||||
dest="indent",
|
||||
default=graphene_settings.SCHEMA_INDENT,
|
||||
help="Output file indent (default: None)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
dest="watch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Updates the schema on file changes (default: False)",
|
||||
)
|
||||
|
||||
|
||||
class Command(CommandArguments):
|
||||
help = "Dump Graphene schema as a JSON or GraphQL file"
|
||||
can_import_settings = True
|
||||
requires_system_checks = False
|
||||
|
||||
def save_json_file(self, out, schema_dict, indent):
|
||||
with open(out, "w") as outfile:
|
||||
json.dump(schema_dict, outfile, indent=indent, sort_keys=True)
|
||||
|
||||
def save_graphql_file(self, out, schema):
|
||||
with open(out, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(print_schema(schema.graphql_schema))
|
||||
|
||||
def get_schema(self, schema, out, indent):
|
||||
schema_dict = {"data": schema.introspect()}
|
||||
if out == "-" or out == "-.json":
|
||||
self.stdout.write(json.dumps(schema_dict, indent=indent, sort_keys=True))
|
||||
elif out == "-.graphql":
|
||||
self.stdout.write(print_schema(schema))
|
||||
else:
|
||||
# Determine format
|
||||
_, file_extension = os.path.splitext(out)
|
||||
|
||||
if file_extension == ".graphql":
|
||||
self.save_graphql_file(out, schema)
|
||||
elif file_extension == ".json":
|
||||
self.save_json_file(out, schema_dict, indent)
|
||||
else:
|
||||
raise CommandError(
|
||||
'Unrecognised file format "{}"'.format(file_extension)
|
||||
)
|
||||
|
||||
style = getattr(self, "style", None)
|
||||
success = getattr(style, "SUCCESS", lambda x: x)
|
||||
|
||||
self.stdout.write(
|
||||
success("Successfully dumped GraphQL schema to {}".format(out))
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
options_schema = options.get("schema")
|
||||
|
||||
if options_schema and type(options_schema) is str:
|
||||
module_str, schema_name = options_schema.rsplit(".", 1)
|
||||
mod = importlib.import_module(module_str)
|
||||
schema = getattr(mod, schema_name)
|
||||
|
||||
elif options_schema:
|
||||
schema = options_schema
|
||||
|
||||
else:
|
||||
schema = graphene_settings.SCHEMA
|
||||
|
||||
out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
|
||||
|
||||
if not schema:
|
||||
raise CommandError(
|
||||
"Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
|
||||
)
|
||||
|
||||
indent = options.get("indent")
|
||||
watch = options.get("watch")
|
||||
if watch:
|
||||
autoreload.run_with_reloader(
|
||||
functools.partial(self.get_schema, schema, out, indent)
|
||||
)
|
||||
else:
|
||||
self.get_schema(schema, out, indent)
|
|
@ -1,43 +0,0 @@
|
|||
class Registry(object):
|
||||
def __init__(self):
|
||||
self._registry = {}
|
||||
self._field_registry = {}
|
||||
|
||||
def register(self, cls):
|
||||
from .types import DjangoObjectType
|
||||
|
||||
assert issubclass(
|
||||
cls, DjangoObjectType
|
||||
), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
|
||||
cls.__name__
|
||||
)
|
||||
assert cls._meta.registry == self, "Registry for a Model have to match."
|
||||
# assert self.get_type_for_model(cls._meta.model) == cls, (
|
||||
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
|
||||
# )
|
||||
if not getattr(cls._meta, "skip_registry", False):
|
||||
self._registry[cls._meta.model] = cls
|
||||
|
||||
def get_type_for_model(self, model):
|
||||
return self._registry.get(model)
|
||||
|
||||
def register_converted_field(self, field, converted):
|
||||
self._field_registry[field] = converted
|
||||
|
||||
def get_converted_field(self, field):
|
||||
return self._field_registry.get(field)
|
||||
|
||||
|
||||
registry = None
|
||||
|
||||
|
||||
def get_global_registry():
|
||||
global registry
|
||||
if not registry:
|
||||
registry = Registry()
|
||||
return registry
|
||||
|
||||
|
||||
def reset_global_registry():
|
||||
global registry
|
||||
registry = None
|
|
@ -1,16 +0,0 @@
|
|||
from django.db import models
|
||||
|
||||
|
||||
class MyFakeModel(models.Model):
|
||||
cool_name = models.CharField(max_length=50)
|
||||
created = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
|
||||
class MyFakeModelWithPassword(models.Model):
|
||||
cool_name = models.CharField(max_length=50)
|
||||
password = models.CharField(max_length=50)
|
||||
|
||||
|
||||
class MyFakeModelWithDate(models.Model):
|
||||
cool_name = models.CharField(max_length=50)
|
||||
last_edited = models.DateField()
|
|
@ -1,175 +0,0 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import serializers
|
||||
|
||||
import graphene
|
||||
from graphene.relay.mutation import ClientIDMutation
|
||||
from graphene.types import Field, InputField
|
||||
from graphene.types.mutation import MutationOptions
|
||||
from graphene.types.objecttype import yank_fields_from_attrs
|
||||
|
||||
from ..types import ErrorType
|
||||
from .serializer_converter import convert_serializer_field
|
||||
|
||||
|
||||
class SerializerMutationOptions(MutationOptions):
|
||||
lookup_field = None
|
||||
model_class = None
|
||||
model_operations = ["create", "update"]
|
||||
serializer_class = None
|
||||
|
||||
|
||||
def fields_for_serializer(
|
||||
serializer,
|
||||
only_fields,
|
||||
exclude_fields,
|
||||
is_input=False,
|
||||
convert_choices_to_enum=True,
|
||||
lookup_field=None,
|
||||
):
|
||||
fields = OrderedDict()
|
||||
for name, field in serializer.fields.items():
|
||||
is_not_in_only = only_fields and name not in only_fields
|
||||
is_excluded = any(
|
||||
[
|
||||
name in exclude_fields,
|
||||
field.write_only
|
||||
and not is_input, # don't show write_only fields in Query
|
||||
field.read_only
|
||||
and is_input
|
||||
and lookup_field != name, # don't show read_only fields in Input
|
||||
]
|
||||
)
|
||||
|
||||
if is_not_in_only or is_excluded:
|
||||
continue
|
||||
|
||||
fields[name] = convert_serializer_field(
|
||||
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum
|
||||
)
|
||||
return fields
|
||||
|
||||
|
||||
class SerializerMutation(ClientIDMutation):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
errors = graphene.List(
|
||||
ErrorType, description="May contain more than one error for same field."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass_with_meta__(
|
||||
cls,
|
||||
lookup_field=None,
|
||||
serializer_class=None,
|
||||
model_class=None,
|
||||
model_operations=("create", "update"),
|
||||
only_fields=(),
|
||||
exclude_fields=(),
|
||||
convert_choices_to_enum=True,
|
||||
_meta=None,
|
||||
**options
|
||||
):
|
||||
|
||||
if not serializer_class:
|
||||
raise Exception("serializer_class is required for the SerializerMutation")
|
||||
|
||||
if "update" not in model_operations and "create" not in model_operations:
|
||||
raise Exception('model_operations must contain "create" and/or "update"')
|
||||
|
||||
serializer = serializer_class()
|
||||
if model_class is None:
|
||||
serializer_meta = getattr(serializer_class, "Meta", None)
|
||||
if serializer_meta:
|
||||
model_class = getattr(serializer_meta, "model", None)
|
||||
|
||||
if lookup_field is None and model_class:
|
||||
lookup_field = model_class._meta.pk.name
|
||||
|
||||
input_fields = fields_for_serializer(
|
||||
serializer,
|
||||
only_fields,
|
||||
exclude_fields,
|
||||
is_input=True,
|
||||
convert_choices_to_enum=convert_choices_to_enum,
|
||||
lookup_field=lookup_field,
|
||||
)
|
||||
output_fields = fields_for_serializer(
|
||||
serializer,
|
||||
only_fields,
|
||||
exclude_fields,
|
||||
is_input=False,
|
||||
convert_choices_to_enum=convert_choices_to_enum,
|
||||
lookup_field=lookup_field,
|
||||
)
|
||||
|
||||
if not _meta:
|
||||
_meta = SerializerMutationOptions(cls)
|
||||
_meta.lookup_field = lookup_field
|
||||
_meta.model_operations = model_operations
|
||||
_meta.serializer_class = serializer_class
|
||||
_meta.model_class = model_class
|
||||
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
|
||||
|
||||
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
|
||||
super(SerializerMutation, cls).__init_subclass_with_meta__(
|
||||
_meta=_meta, input_fields=input_fields, **options
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_serializer_kwargs(cls, root, info, **input):
|
||||
lookup_field = cls._meta.lookup_field
|
||||
model_class = cls._meta.model_class
|
||||
|
||||
if model_class:
|
||||
if "update" in cls._meta.model_operations and lookup_field in input:
|
||||
instance = get_object_or_404(
|
||||
model_class, **{lookup_field: input[lookup_field]}
|
||||
)
|
||||
partial = True
|
||||
elif "create" in cls._meta.model_operations:
|
||||
instance = None
|
||||
partial = False
|
||||
else:
|
||||
raise Exception(
|
||||
'Invalid update operation. Input parameter "{}" required.'.format(
|
||||
lookup_field
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"instance": instance,
|
||||
"data": input,
|
||||
"context": {"request": info.context},
|
||||
"partial": partial,
|
||||
}
|
||||
|
||||
return {"data": input, "context": {"request": info.context}}
|
||||
|
||||
@classmethod
|
||||
def mutate_and_get_payload(cls, root, info, **input):
|
||||
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
||||
serializer = cls._meta.serializer_class(**kwargs)
|
||||
|
||||
if serializer.is_valid():
|
||||
return cls.perform_mutate(serializer, info)
|
||||
else:
|
||||
errors = ErrorType.from_errors(serializer.errors)
|
||||
|
||||
return cls(errors=errors)
|
||||
|
||||
@classmethod
|
||||
def perform_mutate(cls, serializer, info):
|
||||
obj = serializer.save()
|
||||
|
||||
kwargs = {}
|
||||
for f, field in serializer.fields.items():
|
||||
if not field.write_only:
|
||||
if isinstance(field, serializers.SerializerMethodField):
|
||||
kwargs[f] = field.to_representation(obj)
|
||||
else:
|
||||
kwargs[f] = field.get_attribute(obj)
|
||||
|
||||
return cls(errors=None, **kwargs)
|
|
@ -1,159 +0,0 @@
|
|||
from functools import singledispatch
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from rest_framework import serializers
|
||||
|
||||
import graphene
|
||||
|
||||
from ..registry import get_global_registry
|
||||
from ..converter import convert_choices_to_named_enum_with_descriptions
|
||||
from .types import DictType
|
||||
|
||||
|
||||
@singledispatch
|
||||
def get_graphene_type_from_serializer_field(field):
|
||||
raise ImproperlyConfigured(
|
||||
"Don't know how to convert the serializer field %s (%s) "
|
||||
"to Graphene type" % (field, field.__class__)
|
||||
)
|
||||
|
||||
|
||||
def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True):
|
||||
"""
|
||||
Converts a django rest frameworks field to a graphql field
|
||||
and marks the field as required if we are creating an input type
|
||||
and the field itself is required
|
||||
"""
|
||||
|
||||
if isinstance(field, serializers.ChoiceField) and not convert_choices_to_enum:
|
||||
graphql_type = graphene.String
|
||||
else:
|
||||
graphql_type = get_graphene_type_from_serializer_field(field)
|
||||
|
||||
args = []
|
||||
kwargs = {"description": field.help_text, "required": is_input and field.required}
|
||||
|
||||
# if it is a tuple or a list it means that we are returning
|
||||
# the graphql type and the child type
|
||||
if isinstance(graphql_type, (list, tuple)):
|
||||
kwargs["of_type"] = graphql_type[1]
|
||||
graphql_type = graphql_type[0]
|
||||
|
||||
if isinstance(field, serializers.ModelSerializer):
|
||||
if is_input:
|
||||
graphql_type = convert_serializer_to_input_type(field.__class__)
|
||||
else:
|
||||
global_registry = get_global_registry()
|
||||
field_model = field.Meta.model
|
||||
args = [global_registry.get_type_for_model(field_model)]
|
||||
elif isinstance(field, serializers.ListSerializer):
|
||||
field = field.child
|
||||
if is_input:
|
||||
kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
|
||||
else:
|
||||
del kwargs["of_type"]
|
||||
global_registry = get_global_registry()
|
||||
field_model = field.Meta.model
|
||||
args = [global_registry.get_type_for_model(field_model)]
|
||||
|
||||
return graphql_type(*args, **kwargs)
|
||||
|
||||
|
||||
def convert_serializer_to_input_type(serializer_class):
|
||||
cached_type = convert_serializer_to_input_type.cache.get(
|
||||
serializer_class.__name__, None
|
||||
)
|
||||
if cached_type:
|
||||
return cached_type
|
||||
serializer = serializer_class()
|
||||
|
||||
items = {
|
||||
name: convert_serializer_field(field)
|
||||
for name, field in serializer.fields.items()
|
||||
}
|
||||
ret_type = type(
|
||||
"{}Input".format(serializer.__class__.__name__),
|
||||
(graphene.InputObjectType,),
|
||||
items,
|
||||
)
|
||||
convert_serializer_to_input_type.cache[serializer_class.__name__] = ret_type
|
||||
return ret_type
|
||||
|
||||
|
||||
convert_serializer_to_input_type.cache = {}
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.Field)
|
||||
def convert_serializer_field_to_string(field):
|
||||
return graphene.String
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.ModelSerializer)
|
||||
def convert_serializer_to_field(field):
|
||||
return graphene.Field
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.ListSerializer)
|
||||
def convert_list_serializer_to_field(field):
|
||||
child_type = get_graphene_type_from_serializer_field(field.child)
|
||||
return (graphene.List, child_type)
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.IntegerField)
|
||||
def convert_serializer_field_to_int(field):
|
||||
return graphene.Int
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.BooleanField)
|
||||
def convert_serializer_field_to_bool(field):
|
||||
return graphene.Boolean
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.FloatField)
|
||||
@get_graphene_type_from_serializer_field.register(serializers.DecimalField)
|
||||
def convert_serializer_field_to_float(field):
|
||||
return graphene.Float
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
|
||||
def convert_serializer_field_to_datetime_time(field):
|
||||
return graphene.types.datetime.DateTime
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.DateField)
|
||||
def convert_serializer_field_to_date_time(field):
|
||||
return graphene.types.datetime.Date
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.TimeField)
|
||||
def convert_serializer_field_to_time(field):
|
||||
return graphene.types.datetime.Time
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.ListField)
|
||||
def convert_serializer_field_to_list(field, is_input=True):
|
||||
child_type = get_graphene_type_from_serializer_field(field.child)
|
||||
return (graphene.List, child_type)
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.DictField)
|
||||
def convert_serializer_field_to_dict(field):
|
||||
return DictType
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.JSONField)
|
||||
def convert_serializer_field_to_jsonstring(field):
|
||||
return graphene.types.json.JSONString
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.MultipleChoiceField)
|
||||
def convert_serializer_field_to_list_of_enum(field):
|
||||
child_type = convert_serializer_field_to_enum(field)
|
||||
return (graphene.List, child_type)
|
||||
|
||||
|
||||
@get_graphene_type_from_serializer_field.register(serializers.ChoiceField)
|
||||
def convert_serializer_field_to_enum(field):
|
||||
# enums require a name
|
||||
name = field.field_name or field.source or "Choices"
|
||||
return convert_choices_to_named_enum_with_descriptions(name, field.choices)
|
|
@ -1,220 +0,0 @@
|
|||
import copy
|
||||
|
||||
import graphene
|
||||
from django.db import models
|
||||
from graphene import InputObjectType
|
||||
from py.test import raises
|
||||
from rest_framework import serializers
|
||||
|
||||
from ..serializer_converter import convert_serializer_field
|
||||
from ..types import DictType
|
||||
|
||||
|
||||
def _get_type(
|
||||
rest_framework_field, is_input=True, convert_choices_to_enum=True, **kwargs
|
||||
):
|
||||
# prevents the following error:
|
||||
# AssertionError: The `source` argument is not meaningful when applied to a `child=` field.
|
||||
# Remove `source=` from the field declaration.
|
||||
# since we are reusing the same child in when testing the required attribute
|
||||
|
||||
if "child" in kwargs:
|
||||
kwargs["child"] = copy.deepcopy(kwargs["child"])
|
||||
|
||||
field = rest_framework_field(**kwargs)
|
||||
|
||||
return convert_serializer_field(
|
||||
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum
|
||||
)
|
||||
|
||||
|
||||
def assert_conversion(rest_framework_field, graphene_field, **kwargs):
|
||||
graphene_type = _get_type(
|
||||
rest_framework_field, help_text="Custom Help Text", **kwargs
|
||||
)
|
||||
assert isinstance(graphene_type, graphene_field)
|
||||
|
||||
graphene_type_required = _get_type(
|
||||
rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
|
||||
)
|
||||
assert isinstance(graphene_type_required, graphene_field)
|
||||
|
||||
return graphene_type
|
||||
|
||||
|
||||
def test_should_unknown_rest_framework_field_raise_exception():
|
||||
with raises(Exception) as excinfo:
|
||||
convert_serializer_field(None)
|
||||
assert "Don't know how to convert the serializer field" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_char_convert_string():
|
||||
assert_conversion(serializers.CharField, graphene.String)
|
||||
|
||||
|
||||
def test_should_email_convert_string():
|
||||
assert_conversion(serializers.EmailField, graphene.String)
|
||||
|
||||
|
||||
def test_should_slug_convert_string():
|
||||
assert_conversion(serializers.SlugField, graphene.String)
|
||||
|
||||
|
||||
def test_should_url_convert_string():
|
||||
assert_conversion(serializers.URLField, graphene.String)
|
||||
|
||||
|
||||
def test_should_choice_convert_enum():
|
||||
field = assert_conversion(
|
||||
serializers.ChoiceField,
|
||||
graphene.Enum,
|
||||
choices=[("h", "Hello"), ("w", "World")],
|
||||
source="word",
|
||||
)
|
||||
assert field._meta.enum.__members__["H"].value == "h"
|
||||
assert field._meta.enum.__members__["H"].description == "Hello"
|
||||
assert field._meta.enum.__members__["W"].value == "w"
|
||||
assert field._meta.enum.__members__["W"].description == "World"
|
||||
|
||||
|
||||
def test_should_choice_convert_string_if_enum_disabled():
|
||||
assert_conversion(
|
||||
serializers.ChoiceField,
|
||||
graphene.String,
|
||||
choices=[("h", "Hello"), ("w", "World")],
|
||||
source="word",
|
||||
convert_choices_to_enum=False,
|
||||
)
|
||||
|
||||
|
||||
def test_should_base_field_convert_string():
|
||||
assert_conversion(serializers.Field, graphene.String)
|
||||
|
||||
|
||||
def test_should_regex_convert_string():
|
||||
assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
|
||||
|
||||
|
||||
def test_should_uuid_convert_string():
|
||||
if hasattr(serializers, "UUIDField"):
|
||||
assert_conversion(serializers.UUIDField, graphene.String)
|
||||
|
||||
|
||||
def test_should_model_convert_field():
|
||||
class MyModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = None
|
||||
fields = "__all__"
|
||||
|
||||
assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
|
||||
|
||||
|
||||
def test_should_date_time_convert_datetime():
|
||||
assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime)
|
||||
|
||||
|
||||
def test_should_date_convert_date():
|
||||
assert_conversion(serializers.DateField, graphene.types.datetime.Date)
|
||||
|
||||
|
||||
def test_should_time_convert_time():
|
||||
assert_conversion(serializers.TimeField, graphene.types.datetime.Time)
|
||||
|
||||
|
||||
def test_should_integer_convert_int():
|
||||
assert_conversion(serializers.IntegerField, graphene.Int)
|
||||
|
||||
|
||||
def test_should_boolean_convert_boolean():
|
||||
assert_conversion(serializers.BooleanField, graphene.Boolean)
|
||||
|
||||
|
||||
def test_should_float_convert_float():
|
||||
assert_conversion(serializers.FloatField, graphene.Float)
|
||||
|
||||
|
||||
def test_should_decimal_convert_float():
|
||||
assert_conversion(
|
||||
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
|
||||
)
|
||||
|
||||
|
||||
def test_should_list_convert_to_list():
|
||||
class StringListField(serializers.ListField):
|
||||
child = serializers.CharField()
|
||||
|
||||
field_a = assert_conversion(
|
||||
serializers.ListField,
|
||||
graphene.List,
|
||||
child=serializers.IntegerField(min_value=0, max_value=100),
|
||||
)
|
||||
|
||||
assert field_a.of_type == graphene.Int
|
||||
|
||||
field_b = assert_conversion(StringListField, graphene.List)
|
||||
|
||||
assert field_b.of_type == graphene.String
|
||||
|
||||
|
||||
def test_should_list_serializer_convert_to_list():
|
||||
class FooModel(models.Model):
|
||||
pass
|
||||
|
||||
class ChildSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = FooModel
|
||||
fields = "__all__"
|
||||
|
||||
class ParentSerializer(serializers.ModelSerializer):
|
||||
child = ChildSerializer(many=True)
|
||||
|
||||
class Meta:
|
||||
model = FooModel
|
||||
fields = "__all__"
|
||||
|
||||
converted_type = convert_serializer_field(
|
||||
ParentSerializer().get_fields()["child"], is_input=True
|
||||
)
|
||||
assert isinstance(converted_type, graphene.List)
|
||||
|
||||
converted_type = convert_serializer_field(
|
||||
ParentSerializer().get_fields()["child"], is_input=False
|
||||
)
|
||||
assert isinstance(converted_type, graphene.List)
|
||||
assert converted_type.of_type is None
|
||||
|
||||
|
||||
def test_should_dict_convert_dict():
|
||||
assert_conversion(serializers.DictField, DictType)
|
||||
|
||||
|
||||
def test_should_duration_convert_string():
|
||||
assert_conversion(serializers.DurationField, graphene.String)
|
||||
|
||||
|
||||
def test_should_file_convert_string():
|
||||
assert_conversion(serializers.FileField, graphene.String)
|
||||
|
||||
|
||||
def test_should_filepath_convert_string():
|
||||
assert_conversion(serializers.FilePathField, graphene.Enum, path="/")
|
||||
|
||||
|
||||
def test_should_ip_convert_string():
|
||||
assert_conversion(serializers.IPAddressField, graphene.String)
|
||||
|
||||
|
||||
def test_should_image_convert_string():
|
||||
assert_conversion(serializers.ImageField, graphene.String)
|
||||
|
||||
|
||||
def test_should_json_convert_jsonstring():
|
||||
assert_conversion(serializers.JSONField, graphene.types.json.JSONString)
|
||||
|
||||
|
||||
def test_should_multiplechoicefield_convert_to_list_of_enum():
|
||||
field = assert_conversion(
|
||||
serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
|
||||
)
|
||||
|
||||
assert issubclass(field.of_type, graphene.Enum)
|
|
@ -1,66 +0,0 @@
|
|||
from django.db import models
|
||||
from rest_framework import serializers
|
||||
|
||||
import graphene
|
||||
from graphene import Schema
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.rest_framework.mutation import SerializerMutation
|
||||
|
||||
|
||||
class MyFakeChildModel(models.Model):
|
||||
name = models.CharField(max_length=50)
|
||||
created = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
|
||||
class MyFakeParentModel(models.Model):
|
||||
name = models.CharField(max_length=50)
|
||||
created = models.DateTimeField(auto_now_add=True)
|
||||
child1 = models.OneToOneField(
|
||||
MyFakeChildModel, related_name="parent1", on_delete=models.CASCADE
|
||||
)
|
||||
child2 = models.OneToOneField(
|
||||
MyFakeChildModel, related_name="parent2", on_delete=models.CASCADE
|
||||
)
|
||||
|
||||
|
||||
class ParentType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = MyFakeParentModel
|
||||
interfaces = (graphene.relay.Node,)
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class ChildType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = MyFakeChildModel
|
||||
interfaces = (graphene.relay.Node,)
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class MyModelChildSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = MyFakeChildModel
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class MyModelParentSerializer(serializers.ModelSerializer):
|
||||
child1 = MyModelChildSerializer()
|
||||
child2 = MyModelChildSerializer()
|
||||
|
||||
class Meta:
|
||||
model = MyFakeParentModel
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class MyParentModelMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelParentSerializer
|
||||
|
||||
|
||||
class Mutation(graphene.ObjectType):
|
||||
createParentWithChild = MyParentModelMutation.Field()
|
||||
|
||||
|
||||
def test_create_schema():
|
||||
schema = Schema(mutation=Mutation, types=[ParentType, ChildType])
|
||||
assert schema
|
|
@ -1,286 +0,0 @@
|
|||
import datetime
|
||||
|
||||
from py.test import raises
|
||||
from rest_framework import serializers
|
||||
|
||||
from graphene import Field, ResolveInfo
|
||||
from graphene.types.inputobjecttype import InputObjectType
|
||||
|
||||
from ...types import DjangoObjectType
|
||||
from ..models import MyFakeModel, MyFakeModelWithDate, MyFakeModelWithPassword
|
||||
from ..mutation import SerializerMutation
|
||||
|
||||
|
||||
def mock_info():
|
||||
return ResolveInfo(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
path=None,
|
||||
schema=None,
|
||||
fragments=None,
|
||||
root_value=None,
|
||||
operation=None,
|
||||
variable_values=None,
|
||||
context=None,
|
||||
is_awaitable=None,
|
||||
)
|
||||
|
||||
|
||||
class MyModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = MyFakeModel
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class MyModelSerializerWithMethod(serializers.ModelSerializer):
|
||||
days_since_last_edit = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = MyFakeModelWithDate
|
||||
fields = "__all__"
|
||||
|
||||
def get_days_since_last_edit(self, obj):
|
||||
now = datetime.date(2020, 1, 8)
|
||||
return (now - obj.last_edited).days
|
||||
|
||||
|
||||
class MyModelMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
|
||||
|
||||
class MySerializer(serializers.Serializer):
|
||||
text = serializers.CharField()
|
||||
model = MyModelSerializer()
|
||||
|
||||
def create(self, validated_data):
|
||||
return validated_data
|
||||
|
||||
|
||||
def test_needs_serializer_class():
|
||||
with raises(Exception) as exc:
|
||||
|
||||
class MyMutation(SerializerMutation):
|
||||
pass
|
||||
|
||||
assert str(exc.value) == "serializer_class is required for the SerializerMutation"
|
||||
|
||||
|
||||
def test_has_fields():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MySerializer
|
||||
|
||||
assert "text" in MyMutation._meta.fields
|
||||
assert "model" in MyMutation._meta.fields
|
||||
assert "errors" in MyMutation._meta.fields
|
||||
|
||||
|
||||
def test_has_input_fields():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MySerializer
|
||||
|
||||
assert "text" in MyMutation.Input._meta.fields
|
||||
assert "model" in MyMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_exclude_fields():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
exclude_fields = ["created"]
|
||||
|
||||
assert "cool_name" in MyMutation._meta.fields
|
||||
assert "created" not in MyMutation._meta.fields
|
||||
assert "errors" in MyMutation._meta.fields
|
||||
assert "cool_name" in MyMutation.Input._meta.fields
|
||||
assert "created" not in MyMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_write_only_field():
|
||||
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
|
||||
password = serializers.CharField(write_only=True)
|
||||
|
||||
class Meta:
|
||||
model = MyFakeModelWithPassword
|
||||
fields = ["cool_name", "password"]
|
||||
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = WriteOnlyFieldModelSerializer
|
||||
|
||||
result = MyMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"cool_name": "New Narf", "password": "admin"}
|
||||
)
|
||||
|
||||
assert hasattr(result, "cool_name")
|
||||
assert not hasattr(
|
||||
result, "password"
|
||||
), "'password' is write_only field and shouldn't be visible"
|
||||
|
||||
|
||||
def test_write_only_field_using_extra_kwargs():
|
||||
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = MyFakeModelWithPassword
|
||||
fields = ["cool_name", "password"]
|
||||
extra_kwargs = {"password": {"write_only": True}}
|
||||
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = WriteOnlyFieldModelSerializer
|
||||
|
||||
result = MyMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"cool_name": "New Narf", "password": "admin"}
|
||||
)
|
||||
|
||||
assert hasattr(result, "cool_name")
|
||||
assert not hasattr(
|
||||
result, "password"
|
||||
), "'password' is write_only field and shouldn't be visible"
|
||||
|
||||
|
||||
def test_read_only_fields():
|
||||
class ReadOnlyFieldModelSerializer(serializers.ModelSerializer):
|
||||
id = serializers.CharField(read_only=True)
|
||||
cool_name = serializers.CharField(read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = MyFakeModelWithPassword
|
||||
lookup_field = "id"
|
||||
fields = ["id", "cool_name", "password"]
|
||||
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = ReadOnlyFieldModelSerializer
|
||||
|
||||
assert "password" in MyMutation.Input._meta.fields
|
||||
assert "id" in MyMutation.Input._meta.fields
|
||||
assert (
|
||||
"cool_name" not in MyMutation.Input._meta.fields
|
||||
), "'cool_name' is read_only field and shouldn't be on arguments"
|
||||
|
||||
|
||||
def test_nested_model():
|
||||
class MyFakeModelGrapheneType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = MyFakeModel
|
||||
fields = "__all__"
|
||||
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MySerializer
|
||||
|
||||
model_field = MyMutation._meta.fields["model"]
|
||||
assert isinstance(model_field, Field)
|
||||
assert model_field.type == MyFakeModelGrapheneType
|
||||
|
||||
model_input = MyMutation.Input._meta.fields["model"]
|
||||
model_input_type = model_input._type.of_type
|
||||
assert issubclass(model_input_type, InputObjectType)
|
||||
assert "cool_name" in model_input_type._meta.fields
|
||||
assert "created" in model_input_type._meta.fields
|
||||
|
||||
|
||||
def test_mutate_and_get_payload_success():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MySerializer
|
||||
|
||||
result = MyMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
|
||||
)
|
||||
assert result.errors is None
|
||||
|
||||
|
||||
def test_model_add_mutate_and_get_payload_success():
|
||||
result = MyModelMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"cool_name": "Narf"}
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.cool_name == "Narf"
|
||||
assert isinstance(result.created, datetime.datetime)
|
||||
|
||||
|
||||
def test_model_update_mutate_and_get_payload_success():
|
||||
instance = MyFakeModel.objects.create(cool_name="Narf")
|
||||
result = MyModelMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.cool_name == "New Narf"
|
||||
|
||||
|
||||
def test_model_partial_update_mutate_and_get_payload_success():
|
||||
instance = MyFakeModel.objects.create(cool_name="Narf")
|
||||
result = MyModelMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"id": instance.id}
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.cool_name == "Narf"
|
||||
|
||||
|
||||
def test_model_invalid_update_mutate_and_get_payload_success():
|
||||
class InvalidModelMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
model_operations = ["update"]
|
||||
|
||||
with raises(Exception) as exc:
|
||||
result = InvalidModelMutation.mutate_and_get_payload(
|
||||
None, mock_info(), **{"cool_name": "Narf"}
|
||||
)
|
||||
|
||||
assert '"id" required' in str(exc.value)
|
||||
|
||||
|
||||
def test_perform_mutate_success():
|
||||
class MyMethodMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializerWithMethod
|
||||
|
||||
result = MyMethodMutation.mutate_and_get_payload(
|
||||
None,
|
||||
mock_info(),
|
||||
**{"cool_name": "Narf", "last_edited": datetime.date(2020, 1, 4)}
|
||||
)
|
||||
|
||||
assert result.errors is None
|
||||
assert result.cool_name == "Narf"
|
||||
assert result.days_since_last_edit == 4
|
||||
|
||||
|
||||
def test_mutate_and_get_payload_error():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MySerializer
|
||||
|
||||
# missing required fields
|
||||
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
|
||||
assert len(result.errors) > 0
|
||||
|
||||
|
||||
def test_model_mutate_and_get_payload_error():
|
||||
# missing required fields
|
||||
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
|
||||
assert len(result.errors) > 0
|
||||
|
||||
|
||||
def test_mutation_error_camelcased(graphene_settings):
|
||||
graphene_settings.CAMELCASE_ERRORS = True
|
||||
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
|
||||
assert result.errors[0].field == "coolName"
|
||||
|
||||
|
||||
def test_invalid_serializer_operations():
|
||||
with raises(Exception) as exc:
|
||||
|
||||
class MyModelMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
model_operations = ["Add"]
|
||||
|
||||
assert "model_operations" in str(exc.value)
|
|
@ -1,7 +0,0 @@
|
|||
import graphene
|
||||
from graphene.types.unmountedtype import UnmountedType
|
||||
|
||||
|
||||
class DictType(UnmountedType):
|
||||
key = graphene.String()
|
||||
value = graphene.String()
|
|
@ -1,140 +0,0 @@
|
|||
"""
|
||||
Settings for Graphene are all namespaced in the GRAPHENE setting.
|
||||
For example your project's `settings.py` file might look like this:
|
||||
GRAPHENE = {
|
||||
'SCHEMA': 'my_app.schema.schema'
|
||||
'MIDDLEWARE': (
|
||||
'graphene_django.debug.DjangoDebugMiddleware',
|
||||
)
|
||||
}
|
||||
This module provides the `graphene_settings` object, that is used to access
|
||||
Graphene settings, checking for user settings first, then falling
|
||||
back to the defaults.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.conf import settings
|
||||
from django.test.signals import setting_changed
|
||||
|
||||
import importlib # Available in Python 3.1+
|
||||
|
||||
|
||||
# Copied shamelessly from Django REST Framework
|
||||
|
||||
DEFAULTS = {
|
||||
"SCHEMA": None,
|
||||
"SCHEMA_OUTPUT": "schema.json",
|
||||
"SCHEMA_INDENT": 2,
|
||||
"MIDDLEWARE": (),
|
||||
# Set to True if the connection fields must have
|
||||
# either the first or last argument
|
||||
"RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
|
||||
# Max items returned in ConnectionFields / FilterConnectionFields
|
||||
"RELAY_CONNECTION_MAX_LIMIT": 100,
|
||||
"CAMELCASE_ERRORS": True,
|
||||
# Set to True to enable v2 naming convention for choice field Enum's
|
||||
"DJANGO_CHOICE_FIELD_ENUM_V2_NAMING": False,
|
||||
"DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME": None,
|
||||
# Use a separate path for handling subscriptions.
|
||||
"SUBSCRIPTION_PATH": None,
|
||||
# By default GraphiQL headers editor tab is enabled, set to False to hide it
|
||||
# This sets headerEditorEnabled GraphiQL option, for details go to
|
||||
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
|
||||
"GRAPHIQL_HEADER_EDITOR_ENABLED": True,
|
||||
"ATOMIC_MUTATIONS": False,
|
||||
}
|
||||
|
||||
if settings.DEBUG:
|
||||
DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
|
||||
|
||||
# List of settings that may be in string import notation.
|
||||
IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
|
||||
|
||||
|
||||
def perform_import(val, setting_name):
|
||||
"""
|
||||
If the given setting is a string import notation,
|
||||
then perform the necessary import or imports.
|
||||
"""
|
||||
if val is None:
|
||||
return None
|
||||
elif isinstance(val, str):
|
||||
return import_from_string(val, setting_name)
|
||||
elif isinstance(val, (list, tuple)):
|
||||
return [import_from_string(item, setting_name) for item in val]
|
||||
return val
|
||||
|
||||
|
||||
def import_from_string(val, setting_name):
|
||||
"""
|
||||
Attempt to import a class from a string representation.
|
||||
"""
|
||||
try:
|
||||
# Nod to tastypie's use of importlib.
|
||||
parts = val.split(".")
|
||||
module_path, class_name = ".".join(parts[:-1]), parts[-1]
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
except (ImportError, AttributeError) as e:
|
||||
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
|
||||
val,
|
||||
setting_name,
|
||||
e.__class__.__name__,
|
||||
e,
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
class GrapheneSettings(object):
|
||||
"""
|
||||
A settings object, that allows API settings to be accessed as properties.
|
||||
For example:
|
||||
from graphene_django.settings import settings
|
||||
print(settings.SCHEMA)
|
||||
Any setting with string import paths will be automatically resolved
|
||||
and return the class, rather than the string literal.
|
||||
"""
|
||||
|
||||
def __init__(self, user_settings=None, defaults=None, import_strings=None):
|
||||
if user_settings:
|
||||
self._user_settings = user_settings
|
||||
self.defaults = defaults or DEFAULTS
|
||||
self.import_strings = import_strings or IMPORT_STRINGS
|
||||
|
||||
@property
|
||||
def user_settings(self):
|
||||
if not hasattr(self, "_user_settings"):
|
||||
self._user_settings = getattr(settings, "GRAPHENE", {})
|
||||
return self._user_settings
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr not in self.defaults:
|
||||
raise AttributeError("Invalid Graphene setting: '%s'" % attr)
|
||||
|
||||
try:
|
||||
# Check if present in user settings
|
||||
val = self.user_settings[attr]
|
||||
except KeyError:
|
||||
# Fall back to defaults
|
||||
val = self.defaults[attr]
|
||||
|
||||
# Coerce import strings into classes
|
||||
if attr in self.import_strings:
|
||||
val = perform_import(val, attr)
|
||||
|
||||
# Cache the result
|
||||
setattr(self, attr, val)
|
||||
return val
|
||||
|
||||
|
||||
graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
|
||||
|
||||
|
||||
def reload_graphene_settings(*args, **kwargs):
|
||||
global graphene_settings
|
||||
setting, value = kwargs["setting"], kwargs["value"]
|
||||
if setting == "GRAPHENE":
|
||||
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)
|
||||
|
||||
|
||||
setting_changed.connect(reload_graphene_settings)
|
|
@ -1,203 +0,0 @@
|
|||
(function (
|
||||
document,
|
||||
|
||||
GRAPHENE_SETTINGS,
|
||||
GraphiQL,
|
||||
React,
|
||||
ReactDOM,
|
||||
SubscriptionsTransportWs,
|
||||
fetch,
|
||||
history,
|
||||
location,
|
||||
) {
|
||||
// Parse the cookie value for a CSRF token
|
||||
var csrftoken;
|
||||
var cookies = ("; " + document.cookie).split("; csrftoken=");
|
||||
if (cookies.length == 2) {
|
||||
csrftoken = cookies.pop().split(";").shift();
|
||||
} else {
|
||||
csrftoken = document.querySelector("[name=csrfmiddlewaretoken]").value;
|
||||
}
|
||||
|
||||
// Collect the URL parameters
|
||||
var parameters = {};
|
||||
location.hash
|
||||
.substr(1)
|
||||
.split("&")
|
||||
.forEach(function (entry) {
|
||||
var eq = entry.indexOf("=");
|
||||
if (eq >= 0) {
|
||||
parameters[decodeURIComponent(entry.slice(0, eq))] = decodeURIComponent(
|
||||
entry.slice(eq + 1),
|
||||
);
|
||||
}
|
||||
});
|
||||
// Produce a Location fragment string from a parameter object.
|
||||
function locationQuery(params) {
|
||||
return (
|
||||
"#" +
|
||||
Object.keys(params)
|
||||
.map(function (key) {
|
||||
return (
|
||||
encodeURIComponent(key) + "=" + encodeURIComponent(params[key])
|
||||
);
|
||||
})
|
||||
.join("&")
|
||||
);
|
||||
}
|
||||
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
|
||||
var graphqlParamNames = {
|
||||
query: true,
|
||||
variables: true,
|
||||
operationName: true,
|
||||
};
|
||||
var otherParams = {};
|
||||
for (var k in parameters) {
|
||||
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
|
||||
otherParams[k] = parameters[k];
|
||||
}
|
||||
}
|
||||
|
||||
var fetchURL = locationQuery(otherParams);
|
||||
|
||||
// Defines a GraphQL fetcher using the fetch API.
|
||||
function httpClient(graphQLParams, opts) {
|
||||
if (typeof opts === 'undefined') {
|
||||
opts = {};
|
||||
}
|
||||
var headers = opts.headers || {};
|
||||
headers['Accept'] = headers['Accept'] || 'application/json';
|
||||
headers['Content-Type'] = headers['Content-Type'] || 'application/json';
|
||||
if (csrftoken) {
|
||||
headers['X-CSRFToken'] = csrftoken
|
||||
}
|
||||
return fetch(fetchURL, {
|
||||
method: "post",
|
||||
headers: headers,
|
||||
body: JSON.stringify(graphQLParams),
|
||||
credentials: "include",
|
||||
})
|
||||
.then(function (response) {
|
||||
return response.text();
|
||||
})
|
||||
.then(function (responseBody) {
|
||||
try {
|
||||
return JSON.parse(responseBody);
|
||||
} catch (error) {
|
||||
return responseBody;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Derive the subscription URL. If the SUBSCRIPTION_URL setting is specified, uses that value. Otherwise
|
||||
// assumes the current window location with an appropriate websocket protocol.
|
||||
var subscribeURL =
|
||||
location.origin.replace(/^http/, "ws") +
|
||||
(GRAPHENE_SETTINGS.subscriptionPath || location.pathname);
|
||||
|
||||
// Create a subscription client.
|
||||
var subscriptionClient = new SubscriptionsTransportWs.SubscriptionClient(
|
||||
subscribeURL,
|
||||
{
|
||||
// Reconnect after any interruptions.
|
||||
reconnect: true,
|
||||
// Delay socket initialization until the first subscription is started.
|
||||
lazy: true,
|
||||
},
|
||||
);
|
||||
|
||||
// Keep a reference to the currently-active subscription, if available.
|
||||
var activeSubscription = null;
|
||||
|
||||
// Define a GraphQL fetcher that can intelligently route queries based on the operation type.
|
||||
function graphQLFetcher(graphQLParams, opts) {
|
||||
var operationType = getOperationType(graphQLParams);
|
||||
|
||||
// If we're about to execute a new operation, and we have an active subscription,
|
||||
// unsubscribe before continuing.
|
||||
if (activeSubscription) {
|
||||
activeSubscription.unsubscribe();
|
||||
activeSubscription = null;
|
||||
}
|
||||
|
||||
if (operationType === "subscription") {
|
||||
return {
|
||||
subscribe: function (observer) {
|
||||
activeSubscription = subscriptionClient;
|
||||
return subscriptionClient.request(graphQLParams, opts).subscribe(observer);
|
||||
},
|
||||
};
|
||||
} else {
|
||||
return httpClient(graphQLParams, opts);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the type of operation being executed for a given set of GraphQL parameters.
|
||||
function getOperationType(graphQLParams) {
|
||||
// Run a regex against the query to determine the operation type (query, mutation, subscription).
|
||||
var operationRegex = new RegExp(
|
||||
// Look for lines that start with an operation keyword, ignoring whitespace.
|
||||
"^\\s*(query|mutation|subscription)\\s*" +
|
||||
// The operation keyword should be followed by whitespace and the operationName in the GraphQL parameters (if available).
|
||||
(graphQLParams.operationName ? ("\\s+" + graphQLParams.operationName) : "") +
|
||||
// The line should eventually encounter an opening curly brace.
|
||||
"[^\\{]*\\{",
|
||||
// Enable multiline matching.
|
||||
"m",
|
||||
);
|
||||
var match = operationRegex.exec(graphQLParams.query);
|
||||
if (!match) {
|
||||
return "query";
|
||||
}
|
||||
|
||||
return match[1];
|
||||
}
|
||||
|
||||
// When the query and variables string is edited, update the URL bar so
|
||||
// that it can be easily shared.
|
||||
function onEditQuery(newQuery) {
|
||||
parameters.query = newQuery;
|
||||
updateURL();
|
||||
}
|
||||
function onEditVariables(newVariables) {
|
||||
parameters.variables = newVariables;
|
||||
updateURL();
|
||||
}
|
||||
function onEditOperationName(newOperationName) {
|
||||
parameters.operationName = newOperationName;
|
||||
updateURL();
|
||||
}
|
||||
function updateURL() {
|
||||
history.replaceState(null, null, locationQuery(parameters));
|
||||
}
|
||||
var options = {
|
||||
fetcher: graphQLFetcher,
|
||||
onEditQuery: onEditQuery,
|
||||
onEditVariables: onEditVariables,
|
||||
onEditOperationName: onEditOperationName,
|
||||
headerEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled,
|
||||
query: parameters.query,
|
||||
};
|
||||
if (parameters.variables) {
|
||||
options.variables = parameters.variables;
|
||||
}
|
||||
if (parameters.operation_name) {
|
||||
options.operationName = parameters.operation_name;
|
||||
}
|
||||
// Render <GraphiQL /> into the body.
|
||||
ReactDOM.render(
|
||||
React.createElement(GraphiQL, options),
|
||||
document.getElementById("editor"),
|
||||
);
|
||||
})(
|
||||
document,
|
||||
|
||||
window.GRAPHENE_SETTINGS,
|
||||
window.GraphiQL,
|
||||
window.React,
|
||||
window.ReactDOM,
|
||||
window.SubscriptionsTransportWs,
|
||||
window.fetch,
|
||||
window.history,
|
||||
window.location,
|
||||
);
|
|
@ -1,53 +0,0 @@
|
|||
<!--
|
||||
The request to this GraphQL server provided the header "Accept: text/html"
|
||||
and as a result has been presented GraphiQL - an in-browser IDE for
|
||||
exploring GraphQL.
|
||||
If you wish to receive JSON, provide the header "Accept: application/json" or
|
||||
add "&raw" to the end of the URL within a browser.
|
||||
-->
|
||||
{% load static %}
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
html, body, #editor {
|
||||
height: 100%;
|
||||
margin: 0;
|
||||
overflow: hidden;
|
||||
width: 100%;
|
||||
}
|
||||
</style>
|
||||
<link href="https://cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.css"
|
||||
integrity="{{graphiql_css_sri}}"
|
||||
rel="stylesheet"
|
||||
crossorigin="anonymous" />
|
||||
<script src="https://cdn.jsdelivr.net/npm/whatwg-fetch@{{whatwg_fetch_version}}/dist/fetch.umd.js"
|
||||
integrity="{{whatwg_fetch_sri}}"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/react@{{react_version}}/umd/react.production.min.js"
|
||||
integrity="{{react_sri}}"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/react-dom@{{react_version}}/umd/react-dom.production.min.js"
|
||||
integrity="{{react_dom_sri}}"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.js"
|
||||
integrity="{{graphiql_sri}}"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/subscriptions-transport-ws@{{subscriptions_transport_ws_version}}/browser/client.js"
|
||||
integrity="{{subscriptions_transport_ws_sri}}"
|
||||
crossorigin="anonymous"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div id="editor"></div>
|
||||
{% csrf_token %}
|
||||
<script type="application/javascript">
|
||||
window.GRAPHENE_SETTINGS = {
|
||||
{% if subscription_path %}
|
||||
subscriptionPath: "{{subscription_path}}",
|
||||
{% endif %}
|
||||
graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }},
|
||||
};
|
||||
</script>
|
||||
<script src="{% static 'graphene_django/graphiql.js' %}"></script>
|
||||
</body>
|
||||
</html>
|
|
@ -1,16 +0,0 @@
|
|||
from django import forms
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from .models import Pet
|
||||
|
||||
|
||||
class PetForm(forms.ModelForm):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
def clean_age(self):
|
||||
age = self.cleaned_data["age"]
|
||||
if age >= 99:
|
||||
raise ValidationError("Too old")
|
||||
return age
|
|
@ -1,44 +0,0 @@
|
|||
# https://github.com/graphql-python/graphene-django/issues/520
|
||||
|
||||
import datetime
|
||||
|
||||
from django import forms
|
||||
|
||||
import graphene
|
||||
|
||||
from graphene import Field, ResolveInfo
|
||||
from graphene.types.inputobjecttype import InputObjectType
|
||||
from py.test import raises
|
||||
from py.test import mark
|
||||
from rest_framework import serializers
|
||||
|
||||
from ...types import DjangoObjectType
|
||||
from ...rest_framework.models import MyFakeModel
|
||||
from ...rest_framework.mutation import SerializerMutation
|
||||
from ...forms.mutation import DjangoFormMutation
|
||||
|
||||
|
||||
class MyModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = MyFakeModel
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class MyForm(forms.Form):
|
||||
text = forms.CharField()
|
||||
|
||||
|
||||
def test_can_use_form_and_serializer_mutations():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
|
||||
class MyFormMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = MyForm
|
||||
|
||||
class Mutation(graphene.ObjectType):
|
||||
my_mutation = MyMutation.Field()
|
||||
my_form_mutation = MyFormMutation.Field()
|
||||
|
||||
graphene.Schema(mutation=Mutation)
|
|
@ -1,119 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
CHOICES = ((1, "this"), (2, _("that")))
|
||||
|
||||
|
||||
class Person(models.Model):
|
||||
name = models.CharField(max_length=30)
|
||||
|
||||
|
||||
class Pet(models.Model):
|
||||
name = models.CharField(max_length=30)
|
||||
age = models.PositiveIntegerField()
|
||||
|
||||
|
||||
class FilmDetails(models.Model):
|
||||
location = models.CharField(max_length=30)
|
||||
film = models.OneToOneField(
|
||||
"Film", on_delete=models.CASCADE, related_name="details"
|
||||
)
|
||||
|
||||
|
||||
class Film(models.Model):
|
||||
genre = models.CharField(
|
||||
max_length=2,
|
||||
help_text="Genre",
|
||||
choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")],
|
||||
default="ot",
|
||||
)
|
||||
reporters = models.ManyToManyField("Reporter", related_name="films")
|
||||
|
||||
|
||||
class DoeReporterManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
|
||||
|
||||
|
||||
class Reporter(models.Model):
|
||||
first_name = models.CharField(max_length=30)
|
||||
last_name = models.CharField(max_length=30)
|
||||
email = models.EmailField()
|
||||
pets = models.ManyToManyField("self")
|
||||
a_choice = models.CharField(max_length=30, choices=CHOICES, blank=True)
|
||||
objects = models.Manager()
|
||||
doe_objects = DoeReporterManager()
|
||||
|
||||
reporter_type = models.IntegerField(
|
||||
"Reporter Type",
|
||||
null=True,
|
||||
blank=True,
|
||||
choices=[(1, "Regular"), (2, "CNN Reporter")],
|
||||
)
|
||||
|
||||
def __str__(self): # __unicode__ on Python 2
|
||||
return "%s %s" % (self.first_name, self.last_name)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Override the init method so that during runtime, Django
|
||||
can know that this object can be a CNNReporter by casting
|
||||
it to the proxy model. Otherwise, as far as Django knows,
|
||||
when a CNNReporter is pulled from the database, it is still
|
||||
of type Reporter. This was added to test proxy model support.
|
||||
"""
|
||||
super(Reporter, self).__init__(*args, **kwargs)
|
||||
if self.reporter_type == 2: # quick and dirty way without enums
|
||||
self.__class__ = CNNReporter
|
||||
|
||||
def some_method(self):
|
||||
return 123
|
||||
|
||||
|
||||
class CNNReporterManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
return super(CNNReporterManager, self).get_queryset().filter(reporter_type=2)
|
||||
|
||||
|
||||
class CNNReporter(Reporter):
|
||||
"""
|
||||
This class is a proxy model for Reporter, used for testing
|
||||
proxy model support
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
proxy = True
|
||||
|
||||
objects = CNNReporterManager()
|
||||
|
||||
|
||||
class Article(models.Model):
|
||||
headline = models.CharField(max_length=100)
|
||||
pub_date = models.DateField(auto_now_add=True)
|
||||
pub_date_time = models.DateTimeField(auto_now_add=True)
|
||||
reporter = models.ForeignKey(
|
||||
Reporter, on_delete=models.CASCADE, related_name="articles"
|
||||
)
|
||||
editor = models.ForeignKey(
|
||||
Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
|
||||
)
|
||||
lang = models.CharField(
|
||||
max_length=2,
|
||||
help_text="Language",
|
||||
choices=[("es", "Spanish"), ("en", "English")],
|
||||
default="es",
|
||||
)
|
||||
importance = models.IntegerField(
|
||||
"Importance",
|
||||
null=True,
|
||||
blank=True,
|
||||
choices=[(1, "Very important"), (2, "Not as important")],
|
||||
)
|
||||
|
||||
def __str__(self): # __unicode__ on Python 2
|
||||
return self.headline
|
||||
|
||||
class Meta:
|
||||
ordering = ("headline",)
|
|
@ -1,18 +0,0 @@
|
|||
from graphene import Field
|
||||
|
||||
from graphene_django.forms.mutation import DjangoFormMutation, DjangoModelFormMutation
|
||||
|
||||
from .forms import PetForm
|
||||
from .types import PetType
|
||||
|
||||
|
||||
class PetFormMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
|
@ -1,40 +0,0 @@
|
|||
import graphene
|
||||
from graphene import Schema, relay
|
||||
|
||||
from ..types import DjangoObjectType
|
||||
from .models import Article, Reporter
|
||||
|
||||
|
||||
class Character(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (relay.Node,)
|
||||
fields = "__all__"
|
||||
|
||||
def get_node(self, info, id):
|
||||
pass
|
||||
|
||||
|
||||
class Human(DjangoObjectType):
|
||||
raises = graphene.String()
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (relay.Node,)
|
||||
fields = "__all__"
|
||||
|
||||
def resolve_raises(self, info):
|
||||
raise Exception("This field should raise exception")
|
||||
|
||||
def get_node(self, info, id):
|
||||
pass
|
||||
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
human = graphene.Field(Human)
|
||||
|
||||
def resolve_human(self, info):
|
||||
return Human()
|
||||
|
||||
|
||||
schema = Schema(query=Query)
|
|
@ -1,32 +0,0 @@
|
|||
import graphene
|
||||
from graphene import ObjectType, Schema
|
||||
|
||||
from .mutations import PetFormMutation, PetMutation
|
||||
|
||||
|
||||
class QueryRoot(ObjectType):
|
||||
|
||||
thrower = graphene.String(required=True)
|
||||
request = graphene.String(required=True)
|
||||
test = graphene.String(who=graphene.String())
|
||||
|
||||
def resolve_thrower(self, info):
|
||||
raise Exception("Throws!")
|
||||
|
||||
def resolve_request(self, info):
|
||||
return info.context.GET.get("q")
|
||||
|
||||
def resolve_test(self, info, who=None):
|
||||
return "Hello %s" % (who or "World")
|
||||
|
||||
|
||||
class MutationRoot(ObjectType):
|
||||
pet_form_mutation = PetFormMutation.Field()
|
||||
pet_mutation = PetMutation.Field()
|
||||
write_test = graphene.Field(QueryRoot)
|
||||
|
||||
def resolve_write_test(self, info):
|
||||
return QueryRoot()
|
||||
|
||||
|
||||
schema = Schema(query=QueryRoot, mutation=MutationRoot)
|
|
@ -1,58 +0,0 @@
|
|||
from textwrap import dedent
|
||||
|
||||
from django.core import management
|
||||
from io import StringIO
|
||||
from mock import mock_open, patch
|
||||
|
||||
from graphene import ObjectType, Schema, String
|
||||
|
||||
|
||||
@patch("graphene_django.management.commands.graphql_schema.Command.save_json_file")
|
||||
def test_generate_json_file_on_call_graphql_schema(savefile_mock):
|
||||
out = StringIO()
|
||||
management.call_command("graphql_schema", schema="", stdout=out)
|
||||
assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()
|
||||
|
||||
|
||||
@patch("json.dump")
|
||||
def test_json_files_are_canonical(dump_mock):
|
||||
open_mock = mock_open()
|
||||
with patch("graphene_django.management.commands.graphql_schema.open", open_mock):
|
||||
management.call_command("graphql_schema", schema="")
|
||||
|
||||
open_mock.assert_called_once()
|
||||
|
||||
dump_mock.assert_called_once()
|
||||
assert dump_mock.call_args[1][
|
||||
"sort_keys"
|
||||
], "json.mock() should be used to sort the output"
|
||||
assert (
|
||||
dump_mock.call_args[1]["indent"] > 0
|
||||
), "output should be pretty-printed by default"
|
||||
|
||||
|
||||
def test_generate_graphql_file_on_call_graphql_schema():
|
||||
class Query(ObjectType):
|
||||
hi = String()
|
||||
|
||||
mock_schema = Schema(query=Query)
|
||||
|
||||
open_mock = mock_open()
|
||||
with patch("graphene_django.management.commands.graphql_schema.open", open_mock):
|
||||
management.call_command(
|
||||
"graphql_schema", schema=mock_schema, out="schema.graphql"
|
||||
)
|
||||
|
||||
open_mock.assert_called_once()
|
||||
|
||||
handle = open_mock()
|
||||
assert handle.write.called_once()
|
||||
|
||||
schema_output = handle.write.call_args[0][0]
|
||||
assert schema_output == dedent(
|
||||
"""\
|
||||
type Query {
|
||||
hi: String
|
||||
}
|
||||
"""
|
||||
)
|
|
@ -1,468 +0,0 @@
|
|||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from py.test import raises
|
||||
|
||||
import graphene
|
||||
from graphene import NonNull
|
||||
from graphene.relay import ConnectionField, Node
|
||||
from graphene.types.datetime import Date, DateTime, Time
|
||||
from graphene.types.json import JSONString
|
||||
|
||||
from ..compat import (
|
||||
ArrayField,
|
||||
HStoreField,
|
||||
JSONField,
|
||||
PGJSONField,
|
||||
MissingType,
|
||||
RangeField,
|
||||
)
|
||||
from ..converter import (
|
||||
convert_django_field,
|
||||
convert_django_field_with_choices,
|
||||
generate_enum_name,
|
||||
)
|
||||
from ..registry import Registry
|
||||
from ..types import DjangoObjectType
|
||||
from .models import Article, Film, FilmDetails, Reporter
|
||||
|
||||
# from graphene.core.types.custom_scalars import DateTime, Time, JSONString
|
||||
|
||||
|
||||
def assert_conversion(django_field, graphene_field, *args, **kwargs):
|
||||
_kwargs = kwargs.copy()
|
||||
if "null" not in kwargs:
|
||||
_kwargs["null"] = True
|
||||
field = django_field(help_text="Custom Help Text", *args, **_kwargs)
|
||||
graphene_type = convert_django_field(field)
|
||||
assert isinstance(graphene_type, graphene_field)
|
||||
field = graphene_type.Field()
|
||||
assert field.description == "Custom Help Text"
|
||||
|
||||
_kwargs = kwargs.copy()
|
||||
if "null" not in kwargs:
|
||||
_kwargs["null"] = False
|
||||
nonnull_field = django_field(*args, **_kwargs)
|
||||
if not nonnull_field.null:
|
||||
nonnull_graphene_type = convert_django_field(nonnull_field)
|
||||
nonnull_field = nonnull_graphene_type.Field()
|
||||
assert isinstance(nonnull_field.type, graphene.NonNull)
|
||||
return nonnull_field
|
||||
return field
|
||||
|
||||
|
||||
def test_should_unknown_django_field_raise_exception():
|
||||
with raises(Exception) as excinfo:
|
||||
convert_django_field(None)
|
||||
assert "Don't know how to convert the Django field" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_date_time_convert_string():
|
||||
assert_conversion(models.DateTimeField, DateTime)
|
||||
|
||||
|
||||
def test_should_date_convert_string():
|
||||
assert_conversion(models.DateField, Date)
|
||||
|
||||
|
||||
def test_should_time_convert_string():
|
||||
assert_conversion(models.TimeField, Time)
|
||||
|
||||
|
||||
def test_should_char_convert_string():
|
||||
assert_conversion(models.CharField, graphene.String)
|
||||
|
||||
|
||||
def test_should_text_convert_string():
|
||||
assert_conversion(models.TextField, graphene.String)
|
||||
|
||||
|
||||
def test_should_email_convert_string():
|
||||
assert_conversion(models.EmailField, graphene.String)
|
||||
|
||||
|
||||
def test_should_slug_convert_string():
|
||||
assert_conversion(models.SlugField, graphene.String)
|
||||
|
||||
|
||||
def test_should_url_convert_string():
|
||||
assert_conversion(models.URLField, graphene.String)
|
||||
|
||||
|
||||
def test_should_ipaddress_convert_string():
|
||||
assert_conversion(models.GenericIPAddressField, graphene.String)
|
||||
|
||||
|
||||
def test_should_file_convert_string():
|
||||
assert_conversion(models.FileField, graphene.String)
|
||||
|
||||
|
||||
def test_should_image_convert_string():
|
||||
assert_conversion(models.ImageField, graphene.String)
|
||||
|
||||
|
||||
def test_should_file_path_field_convert_string():
|
||||
assert_conversion(models.FilePathField, graphene.String)
|
||||
|
||||
|
||||
def test_should_auto_convert_id():
|
||||
assert_conversion(models.AutoField, graphene.ID, primary_key=True)
|
||||
|
||||
|
||||
def test_should_big_auto_convert_id():
|
||||
assert_conversion(models.BigAutoField, graphene.ID, primary_key=True)
|
||||
|
||||
|
||||
def test_should_small_auto_convert_id():
|
||||
if hasattr(models, "SmallAutoField"):
|
||||
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True)
|
||||
|
||||
|
||||
def test_should_uuid_convert_id():
|
||||
assert_conversion(models.UUIDField, graphene.UUID)
|
||||
|
||||
|
||||
def test_should_auto_convert_duration():
|
||||
assert_conversion(models.DurationField, graphene.Float)
|
||||
|
||||
|
||||
def test_should_positive_integer_convert_int():
|
||||
assert_conversion(models.PositiveIntegerField, graphene.Int)
|
||||
|
||||
|
||||
def test_should_positive_small_convert_int():
|
||||
assert_conversion(models.PositiveSmallIntegerField, graphene.Int)
|
||||
|
||||
|
||||
def test_should_small_integer_convert_int():
|
||||
assert_conversion(models.SmallIntegerField, graphene.Int)
|
||||
|
||||
|
||||
def test_should_big_integer_convert_int():
|
||||
assert_conversion(models.BigIntegerField, graphene.Int)
|
||||
|
||||
|
||||
def test_should_integer_convert_int():
|
||||
assert_conversion(models.IntegerField, graphene.Int)
|
||||
|
||||
|
||||
def test_should_boolean_convert_boolean():
|
||||
assert_conversion(models.BooleanField, graphene.Boolean, null=True)
|
||||
|
||||
|
||||
def test_should_boolean_convert_non_null_boolean():
|
||||
field = assert_conversion(models.BooleanField, graphene.Boolean, null=False)
|
||||
assert isinstance(field.type, graphene.NonNull)
|
||||
assert field.type.of_type == graphene.Boolean
|
||||
|
||||
|
||||
def test_should_nullboolean_convert_boolean():
|
||||
assert_conversion(models.NullBooleanField, graphene.Boolean)
|
||||
|
||||
|
||||
def test_field_with_choices_convert_enum():
|
||||
field = models.CharField(
|
||||
help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
|
||||
)
|
||||
|
||||
class TranslatedModel(models.Model):
|
||||
language = field
|
||||
|
||||
class Meta:
|
||||
app_label = "test"
|
||||
|
||||
graphene_type = convert_django_field_with_choices(field).type.of_type
|
||||
assert graphene_type._meta.name == "TestTranslatedModelLanguageChoices"
|
||||
assert graphene_type._meta.enum.__members__["ES"].value == "es"
|
||||
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
|
||||
assert graphene_type._meta.enum.__members__["EN"].value == "en"
|
||||
assert graphene_type._meta.enum.__members__["EN"].description == "English"
|
||||
|
||||
|
||||
def test_field_with_grouped_choices():
|
||||
field = models.CharField(
|
||||
help_text="Language",
|
||||
choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
|
||||
)
|
||||
|
||||
class GroupedChoicesModel(models.Model):
|
||||
language = field
|
||||
|
||||
class Meta:
|
||||
app_label = "test"
|
||||
|
||||
convert_django_field_with_choices(field)
|
||||
|
||||
|
||||
def test_field_with_choices_gettext():
|
||||
field = models.CharField(
|
||||
help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
|
||||
)
|
||||
|
||||
class TranslatedChoicesModel(models.Model):
|
||||
language = field
|
||||
|
||||
class Meta:
|
||||
app_label = "test"
|
||||
|
||||
convert_django_field_with_choices(field)
|
||||
|
||||
|
||||
def test_field_with_choices_collision():
|
||||
field = models.CharField(
|
||||
help_text="Timezone",
|
||||
choices=(
|
||||
("Etc/GMT+1+2", "Fake choice to produce double collision"),
|
||||
("Etc/GMT+1", "Greenwich Mean Time +1"),
|
||||
("Etc/GMT-1", "Greenwich Mean Time -1"),
|
||||
),
|
||||
)
|
||||
|
||||
class CollisionChoicesModel(models.Model):
|
||||
timezone = field
|
||||
|
||||
class Meta:
|
||||
app_label = "test"
|
||||
|
||||
convert_django_field_with_choices(field)
|
||||
|
||||
|
||||
def test_field_with_choices_convert_enum_false():
|
||||
field = models.CharField(
|
||||
help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
|
||||
)
|
||||
|
||||
class TranslatedModel(models.Model):
|
||||
language = field
|
||||
|
||||
class Meta:
|
||||
app_label = "test"
|
||||
|
||||
graphene_type = convert_django_field_with_choices(
|
||||
field, convert_choices_to_enum=False
|
||||
)
|
||||
assert isinstance(graphene_type, graphene.String)
|
||||
|
||||
|
||||
def test_should_float_convert_float():
|
||||
assert_conversion(models.FloatField, graphene.Float)
|
||||
|
||||
|
||||
def test_should_float_convert_decimal():
|
||||
assert_conversion(models.DecimalField, graphene.Decimal)
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist():
|
||||
registry = Registry()
|
||||
dynamic_field = convert_django_field(Reporter._meta.local_many_to_many[0], registry)
|
||||
assert not dynamic_field.get_type()
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist_list():
|
||||
class A(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = "__all__"
|
||||
|
||||
graphene_field = convert_django_field(
|
||||
Reporter._meta.local_many_to_many[0], A._meta.registry
|
||||
)
|
||||
assert isinstance(graphene_field, graphene.Dynamic)
|
||||
dynamic_field = graphene_field.get_type()
|
||||
assert isinstance(dynamic_field, graphene.Field)
|
||||
# A NonNull List of NonNull A ([A!]!)
|
||||
# https://github.com/graphql-python/graphene-django/issues/448
|
||||
assert isinstance(dynamic_field.type, NonNull)
|
||||
assert isinstance(dynamic_field.type.of_type, graphene.List)
|
||||
assert isinstance(dynamic_field.type.of_type.of_type, NonNull)
|
||||
assert dynamic_field.type.of_type.of_type.of_type == A
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist_connection():
|
||||
class A(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
graphene_field = convert_django_field(
|
||||
Reporter._meta.local_many_to_many[0], A._meta.registry
|
||||
)
|
||||
assert isinstance(graphene_field, graphene.Dynamic)
|
||||
dynamic_field = graphene_field.get_type()
|
||||
assert isinstance(dynamic_field, ConnectionField)
|
||||
assert dynamic_field.type.of_type == A._meta.connection
|
||||
|
||||
|
||||
def test_should_manytoone_convert_connectionorlist():
|
||||
class A(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = "__all__"
|
||||
|
||||
graphene_field = convert_django_field(Reporter.articles.rel, A._meta.registry)
|
||||
assert isinstance(graphene_field, graphene.Dynamic)
|
||||
dynamic_field = graphene_field.get_type()
|
||||
assert isinstance(dynamic_field, graphene.Field)
|
||||
# a NonNull List of NonNull A ([A!]!)
|
||||
assert isinstance(dynamic_field.type, NonNull)
|
||||
assert isinstance(dynamic_field.type.of_type, graphene.List)
|
||||
assert isinstance(dynamic_field.type.of_type.of_type, NonNull)
|
||||
assert dynamic_field.type.of_type.of_type.of_type == A
|
||||
|
||||
|
||||
def test_should_onetoone_reverse_convert_model():
|
||||
class A(DjangoObjectType):
|
||||
class Meta:
|
||||
model = FilmDetails
|
||||
fields = "__all__"
|
||||
|
||||
graphene_field = convert_django_field(Film.details.related, A._meta.registry)
|
||||
assert isinstance(graphene_field, graphene.Dynamic)
|
||||
dynamic_field = graphene_field.get_type()
|
||||
assert isinstance(dynamic_field, graphene.Field)
|
||||
assert dynamic_field.type == A
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_should_postgres_array_convert_list():
|
||||
field = assert_conversion(
|
||||
ArrayField, graphene.List, models.CharField(max_length=100)
|
||||
)
|
||||
assert isinstance(field.type, graphene.NonNull)
|
||||
assert isinstance(field.type.of_type, graphene.List)
|
||||
assert isinstance(field.type.of_type.of_type, graphene.NonNull)
|
||||
assert field.type.of_type.of_type.of_type == graphene.String
|
||||
|
||||
field = assert_conversion(
|
||||
ArrayField, graphene.List, models.CharField(max_length=100, null=True)
|
||||
)
|
||||
assert isinstance(field.type, graphene.NonNull)
|
||||
assert isinstance(field.type.of_type, graphene.List)
|
||||
assert field.type.of_type.of_type == graphene.String
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||
def test_should_postgres_array_multiple_convert_list():
|
||||
field = assert_conversion(
|
||||
ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
|
||||
)
|
||||
assert isinstance(field.type, graphene.NonNull)
|
||||
assert isinstance(field.type.of_type, graphene.List)
|
||||
assert isinstance(field.type.of_type.of_type, graphene.List)
|
||||
assert isinstance(field.type.of_type.of_type.of_type, graphene.NonNull)
|
||||
assert field.type.of_type.of_type.of_type.of_type == graphene.String
|
||||
|
||||
field = assert_conversion(
|
||||
ArrayField,
|
||||
graphene.List,
|
||||
ArrayField(models.CharField(max_length=100, null=True)),
|
||||
)
|
||||
assert isinstance(field.type, graphene.NonNull)
|
||||
assert isinstance(field.type.of_type, graphene.List)
|
||||
assert isinstance(field.type.of_type.of_type, graphene.List)
|
||||
assert field.type.of_type.of_type.of_type == graphene.String
|
||||
|
||||
|
||||
@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
|
||||
def test_should_postgres_hstore_convert_string():
|
||||
assert_conversion(HStoreField, JSONString)
|
||||
|
||||
|
||||
@pytest.mark.skipif(PGJSONField is MissingType, reason="PGJSONField should exist")
|
||||
def test_should_postgres_json_convert_string():
|
||||
assert_conversion(PGJSONField, JSONString)
|
||||
|
||||
|
||||
@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
|
||||
def test_should_json_convert_string():
|
||||
assert_conversion(JSONField, JSONString)
|
||||
|
||||
|
||||
@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
|
||||
def test_should_postgres_range_convert_list():
|
||||
from django.contrib.postgres.fields import IntegerRangeField
|
||||
|
||||
field = assert_conversion(IntegerRangeField, graphene.List)
|
||||
assert isinstance(field.type, graphene.NonNull)
|
||||
assert isinstance(field.type.of_type, graphene.List)
|
||||
assert isinstance(field.type.of_type.of_type, graphene.NonNull)
|
||||
assert field.type.of_type.of_type.of_type == graphene.Int
|
||||
|
||||
|
||||
def test_generate_enum_name():
|
||||
MockDjangoModelMeta = namedtuple("DjangoMeta", ["app_label", "object_name"])
|
||||
|
||||
# Simple case
|
||||
field = graphene.Field(graphene.String, name="type")
|
||||
model_meta = MockDjangoModelMeta(app_label="users", object_name="User")
|
||||
assert generate_enum_name(model_meta, field) == "UsersUserTypeChoices"
|
||||
|
||||
# More complicated multiple work case
|
||||
field = graphene.Field(graphene.String, name="fizz_buzz")
|
||||
model_meta = MockDjangoModelMeta(
|
||||
app_label="some_long_app_name", object_name="SomeObject"
|
||||
)
|
||||
assert (
|
||||
generate_enum_name(model_meta, field)
|
||||
== "SomeLongAppNameSomeObjectFizzBuzzChoices"
|
||||
)
|
||||
|
||||
|
||||
def test_generate_v2_enum_name(graphene_settings):
|
||||
MockDjangoModelMeta = namedtuple("DjangoMeta", ["app_label", "object_name"])
|
||||
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING = True
|
||||
|
||||
# Simple case
|
||||
field = graphene.Field(graphene.String, name="type")
|
||||
model_meta = MockDjangoModelMeta(app_label="users", object_name="User")
|
||||
assert generate_enum_name(model_meta, field) == "UserType"
|
||||
|
||||
# More complicated multiple work case
|
||||
field = graphene.Field(graphene.String, name="fizz_buzz")
|
||||
model_meta = MockDjangoModelMeta(
|
||||
app_label="some_long_app_name", object_name="SomeObject"
|
||||
)
|
||||
assert generate_enum_name(model_meta, field) == "SomeObjectFizzBuzz"
|
||||
|
||||
|
||||
def test_choice_enum_blank_value():
|
||||
"""Test that choice fields with blank values work"""
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = (
|
||||
"first_name",
|
||||
"a_choice",
|
||||
)
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
|
||||
def resolve_reporter(root, info):
|
||||
return Reporter.objects.first()
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
|
||||
# Create model with empty choice option
|
||||
Reporter.objects.create(
|
||||
first_name="Bridget", last_name="Jones", email="bridget@example.com"
|
||||
)
|
||||
|
||||
result = schema.execute(
|
||||
"""
|
||||
query {
|
||||
reporter {
|
||||
firstName
|
||||
aChoice
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporter": {"firstName": "Bridget", "aChoice": None},
|
||||
}
|
|
@ -1,502 +0,0 @@
|
|||
import datetime
|
||||
from django.db.models import Count
|
||||
|
||||
import pytest
|
||||
|
||||
from graphene import List, NonNull, ObjectType, Schema, String
|
||||
|
||||
from ..fields import DjangoListField
|
||||
from ..types import DjangoObjectType
|
||||
from .models import Article as ArticleModel
|
||||
from .models import Reporter as ReporterModel
|
||||
|
||||
|
||||
class TestDjangoListField:
|
||||
def test_only_django_object_types(self):
|
||||
class TestType(ObjectType):
|
||||
foo = String()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
list_field = DjangoListField(TestType)
|
||||
|
||||
def test_only_import_paths(self):
|
||||
list_field = DjangoListField("graphene_django.tests.schema.Human")
|
||||
from .schema import Human
|
||||
|
||||
assert list_field._type.of_type.of_type is Human
|
||||
|
||||
def test_non_null_type(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
list_field = DjangoListField(NonNull(Reporter))
|
||||
|
||||
assert isinstance(list_field.type, List)
|
||||
assert isinstance(list_field.type.of_type, NonNull)
|
||||
assert list_field.type.of_type.of_type is Reporter
|
||||
|
||||
def test_get_django_model(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
list_field = DjangoListField(Reporter)
|
||||
assert list_field.model is ReporterModel
|
||||
|
||||
def test_list_field_default_queryset(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
}
|
||||
|
||||
def test_list_field_queryset_is_not_cached(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": []}
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
}
|
||||
|
||||
def test_override_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return ReporterModel.objects.filter(first_name="Tara")
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
|
||||
def test_nested_list_field(self):
|
||||
class Article(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ArticleModel
|
||||
fields = ("headline",)
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
articles {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
ArticleModel.objects.create(
|
||||
headline="Not so good news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [
|
||||
{
|
||||
"firstName": "Tara",
|
||||
"articles": [
|
||||
{"headline": "Amazing news"},
|
||||
{"headline": "Not so good news"},
|
||||
],
|
||||
},
|
||||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
|
||||
def test_override_resolver_nested_list_field(self):
|
||||
class Article(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ArticleModel
|
||||
fields = ("headline",)
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
def resolve_articles(reporter, info):
|
||||
return reporter.articles.filter(headline__contains="Amazing")
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
articles {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
ArticleModel.objects.create(
|
||||
headline="Not so good news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [
|
||||
{"firstName": "Tara", "articles": [{"headline": "Amazing news"}]},
|
||||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
|
||||
def test_get_queryset_filter(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return ReporterModel.objects.all()
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
|
||||
def test_resolve_list(self):
|
||||
"""Resolving a plain list should work (and not call get_queryset)"""
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return [ReporterModel.objects.get(first_name="Debra")]
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_get_queryset_foreign_key(self):
|
||||
class Article(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ArticleModel
|
||||
fields = ("headline",)
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Rose tinted glasses
|
||||
return queryset.exclude(headline__contains="Not so good")
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
articles {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
ArticleModel.objects.create(
|
||||
headline="Not so good news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [
|
||||
{"firstName": "Tara", "articles": [{"headline": "Amazing news"}]},
|
||||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
|
||||
def test_resolve_list_external_resolver(self):
|
||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return [ReporterModel.objects.get(first_name="Debra")]
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_get_queryset_filter_external_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return ReporterModel.objects.all()
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
|
@ -1,40 +0,0 @@
|
|||
from django.core.exceptions import ValidationError
|
||||
from py.test import raises
|
||||
|
||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
|
||||
|
||||
|
||||
def test_global_id_valid():
|
||||
field = GlobalIDFormField()
|
||||
field.clean("TXlUeXBlOmFiYw==")
|
||||
|
||||
|
||||
def test_global_id_invalid():
|
||||
field = GlobalIDFormField()
|
||||
with raises(ValidationError):
|
||||
field.clean("badvalue")
|
||||
|
||||
|
||||
def test_global_id_multiple_valid():
|
||||
field = GlobalIDMultipleChoiceField()
|
||||
field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
|
||||
|
||||
|
||||
def test_global_id_multiple_invalid():
|
||||
field = GlobalIDMultipleChoiceField()
|
||||
with raises(ValidationError):
|
||||
field.clean(["badvalue", "another bad avue"])
|
||||
|
||||
|
||||
def test_global_id_none():
|
||||
field = GlobalIDFormField()
|
||||
with raises(ValidationError):
|
||||
field.clean(None)
|
||||
|
||||
|
||||
def test_global_id_none_optional():
|
||||
field = GlobalIDFormField(required=False)
|
||||
field.clean(None)
|
File diff suppressed because it is too large
Load Diff
|
@ -1,55 +0,0 @@
|
|||
from py.test import raises
|
||||
|
||||
from ..registry import Registry
|
||||
from ..types import DjangoObjectType
|
||||
from .models import Reporter
|
||||
|
||||
|
||||
def test_should_raise_if_no_model():
|
||||
with raises(Exception) as excinfo:
|
||||
|
||||
class Character1(DjangoObjectType):
|
||||
fields = "__all__"
|
||||
|
||||
assert "valid Django Model" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_raise_if_model_is_invalid():
|
||||
with raises(Exception) as excinfo:
|
||||
|
||||
class Character2(DjangoObjectType):
|
||||
class Meta:
|
||||
model = 1
|
||||
fields = "__all__"
|
||||
|
||||
assert "valid Django Model" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_map_fields_correctly():
|
||||
class ReporterType2(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
registry = Registry()
|
||||
fields = "__all__"
|
||||
|
||||
fields = list(ReporterType2._meta.fields.keys())
|
||||
assert fields[:-2] == [
|
||||
"id",
|
||||
"first_name",
|
||||
"last_name",
|
||||
"email",
|
||||
"pets",
|
||||
"a_choice",
|
||||
"reporter_type",
|
||||
]
|
||||
|
||||
assert sorted(fields[-2:]) == ["articles", "films"]
|
||||
|
||||
|
||||
def test_should_map_only_few_fields():
|
||||
class Reporter2(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = ("id", "email")
|
||||
|
||||
assert list(Reporter2._meta.fields.keys()) == ["id", "email"]
|
|
@ -1,691 +0,0 @@
|
|||
from collections import OrderedDict, defaultdict
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from django.db import models
|
||||
from mock import patch
|
||||
|
||||
from graphene import Connection, Field, Interface, ObjectType, Schema, String
|
||||
from graphene.relay import Node
|
||||
|
||||
from .. import registry
|
||||
from ..filter import DjangoFilterConnectionField
|
||||
from ..types import DjangoObjectType, DjangoObjectTypeOptions
|
||||
from .models import Article as ArticleModel
|
||||
from .models import Reporter as ReporterModel
|
||||
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
"""Reporter description"""
|
||||
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class ArticleConnection(Connection):
|
||||
"""Article Connection"""
|
||||
|
||||
test = String()
|
||||
|
||||
def resolve_test():
|
||||
return "test"
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class Article(DjangoObjectType):
|
||||
"""Article description"""
|
||||
|
||||
class Meta:
|
||||
model = ArticleModel
|
||||
interfaces = (Node,)
|
||||
connection_class = ArticleConnection
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class RootQuery(ObjectType):
|
||||
node = Node.Field()
|
||||
|
||||
|
||||
schema = Schema(query=RootQuery, types=[Article, Reporter])
|
||||
|
||||
|
||||
def test_django_interface():
|
||||
assert issubclass(Node, Interface)
|
||||
assert issubclass(Node, Node)
|
||||
|
||||
|
||||
@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
|
||||
def test_django_get_node(get):
|
||||
article = Article.get_node(None, 1)
|
||||
get.assert_called_with(pk=1)
|
||||
assert article.id == 1
|
||||
|
||||
|
||||
def test_django_objecttype_map_correct_fields():
|
||||
fields = Reporter._meta.fields
|
||||
fields = list(fields.keys())
|
||||
assert fields[:-2] == [
|
||||
"id",
|
||||
"first_name",
|
||||
"last_name",
|
||||
"email",
|
||||
"pets",
|
||||
"a_choice",
|
||||
"reporter_type",
|
||||
]
|
||||
assert sorted(fields[-2:]) == ["articles", "films"]
|
||||
|
||||
|
||||
def test_django_objecttype_with_node_have_correct_fields():
|
||||
fields = Article._meta.fields
|
||||
assert list(fields.keys()) == [
|
||||
"id",
|
||||
"headline",
|
||||
"pub_date",
|
||||
"pub_date_time",
|
||||
"reporter",
|
||||
"editor",
|
||||
"lang",
|
||||
"importance",
|
||||
]
|
||||
|
||||
|
||||
def test_django_objecttype_with_custom_meta():
|
||||
class ArticleTypeOptions(DjangoObjectTypeOptions):
|
||||
"""Article Type Options"""
|
||||
|
||||
class ArticleType(DjangoObjectType):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
@classmethod
|
||||
def __init_subclass_with_meta__(cls, **options):
|
||||
options.setdefault("_meta", ArticleTypeOptions(cls))
|
||||
super(ArticleType, cls).__init_subclass_with_meta__(**options)
|
||||
|
||||
class Article(ArticleType):
|
||||
class Meta:
|
||||
model = ArticleModel
|
||||
fields = "__all__"
|
||||
|
||||
assert isinstance(Article._meta, ArticleTypeOptions)
|
||||
|
||||
|
||||
def test_schema_representation():
|
||||
expected = dedent(
|
||||
"""\
|
||||
schema {
|
||||
query: RootQuery
|
||||
}
|
||||
|
||||
\"""Article description\"""
|
||||
type Article implements Node {
|
||||
\"""The ID of the object\"""
|
||||
id: ID!
|
||||
headline: String!
|
||||
pubDate: Date!
|
||||
pubDateTime: DateTime!
|
||||
reporter: Reporter!
|
||||
editor: Reporter!
|
||||
|
||||
\"""Language\"""
|
||||
lang: TestsArticleLangChoices!
|
||||
importance: TestsArticleImportanceChoices
|
||||
}
|
||||
|
||||
\"""An object with an ID\"""
|
||||
interface Node {
|
||||
\"""The ID of the object\"""
|
||||
id: ID!
|
||||
}
|
||||
|
||||
\"""
|
||||
The `Date` scalar type represents a Date
|
||||
value as specified by
|
||||
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
|
||||
\"""
|
||||
scalar Date
|
||||
|
||||
\"""
|
||||
The `DateTime` scalar type represents a DateTime
|
||||
value as specified by
|
||||
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
|
||||
\"""
|
||||
scalar DateTime
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum TestsArticleLangChoices {
|
||||
\"""Spanish\"""
|
||||
ES
|
||||
|
||||
\"""English\"""
|
||||
EN
|
||||
}
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum TestsArticleImportanceChoices {
|
||||
\"""Very important\"""
|
||||
A_1
|
||||
|
||||
\"""Not as important\"""
|
||||
A_2
|
||||
}
|
||||
|
||||
\"""Reporter description\"""
|
||||
type Reporter {
|
||||
id: ID!
|
||||
firstName: String!
|
||||
lastName: String!
|
||||
email: String!
|
||||
pets: [Reporter!]!
|
||||
aChoice: TestsReporterAChoiceChoices
|
||||
reporterType: TestsReporterReporterTypeChoices
|
||||
articles(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null): ArticleConnection!
|
||||
}
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum TestsReporterAChoiceChoices {
|
||||
\"""this\"""
|
||||
A_1
|
||||
|
||||
\"""that\"""
|
||||
A_2
|
||||
}
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum TestsReporterReporterTypeChoices {
|
||||
\"""Regular\"""
|
||||
A_1
|
||||
|
||||
\"""CNN Reporter\"""
|
||||
A_2
|
||||
}
|
||||
|
||||
type ArticleConnection {
|
||||
\"""Pagination data for this connection.\"""
|
||||
pageInfo: PageInfo!
|
||||
|
||||
\"""Contains the nodes in this connection.\"""
|
||||
edges: [ArticleEdge]!
|
||||
test: String
|
||||
}
|
||||
|
||||
\"""
|
||||
The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.
|
||||
\"""
|
||||
type PageInfo {
|
||||
\"""When paginating forwards, are there more items?\"""
|
||||
hasNextPage: Boolean!
|
||||
|
||||
\"""When paginating backwards, are there more items?\"""
|
||||
hasPreviousPage: Boolean!
|
||||
|
||||
\"""When paginating backwards, the cursor to continue.\"""
|
||||
startCursor: String
|
||||
|
||||
\"""When paginating forwards, the cursor to continue.\"""
|
||||
endCursor: String
|
||||
}
|
||||
|
||||
\"""A Relay edge containing a `Article` and its cursor.\"""
|
||||
type ArticleEdge {
|
||||
\"""The item at the end of the edge\"""
|
||||
node: Article
|
||||
|
||||
\"""A cursor for use in pagination\"""
|
||||
cursor: String!
|
||||
}
|
||||
|
||||
type RootQuery {
|
||||
node(
|
||||
\"""The ID of the object\"""
|
||||
id: ID!
|
||||
): Node
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert str(schema) == expected
|
||||
|
||||
|
||||
def with_local_registry(func):
|
||||
def inner(*args, **kwargs):
|
||||
old = registry.get_global_registry()
|
||||
try:
|
||||
retval = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
registry.registry = old
|
||||
raise e
|
||||
else:
|
||||
registry.registry = old
|
||||
return retval
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_only_fields():
|
||||
with pytest.warns(DeprecationWarning):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
only_fields = ("id", "email", "films")
|
||||
|
||||
fields = list(Reporter._meta.fields.keys())
|
||||
assert fields == ["id", "email", "films"]
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_fields():
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("id", "email", "films")
|
||||
|
||||
fields = list(Reporter._meta.fields.keys())
|
||||
assert fields == ["id", "email", "films"]
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_fields_empty():
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ()
|
||||
|
||||
fields = list(Reporter._meta.fields.keys())
|
||||
assert fields == []
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_only_fields_and_fields():
|
||||
with pytest.raises(Exception):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
only_fields = ("id", "email", "films")
|
||||
fields = ("id", "email", "films")
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_all_fields():
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = "__all__"
|
||||
|
||||
fields = list(Reporter._meta.fields.keys())
|
||||
assert len(fields) == len(ReporterModel._meta.get_fields())
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_exclude_fields():
|
||||
with pytest.warns(DeprecationWarning):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude_fields = ["email"]
|
||||
|
||||
fields = list(Reporter._meta.fields.keys())
|
||||
assert "email" not in fields
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_exclude():
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["email"]
|
||||
|
||||
fields = list(Reporter._meta.fields.keys())
|
||||
assert "email" not in fields
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_exclude_fields_and_exclude():
|
||||
with pytest.raises(Exception):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["email"]
|
||||
exclude_fields = ["email"]
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_exclude_and_only():
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["email"]
|
||||
fields = ["id"]
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_fields_exclude_type_checking():
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = "foo"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class Reporter2(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = "foo"
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_fields_exist_on_model():
|
||||
with pytest.warns(UserWarning, match=r"Field name .* doesn't exist"):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ["first_name", "foo", "email"]
|
||||
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match=r"Field name .* matches an attribute on Django model .* but it's not a model field",
|
||||
) as record:
|
||||
|
||||
class Reporter2(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ["first_name", "some_method", "email"]
|
||||
|
||||
# Don't warn if selecting a custom field
|
||||
with pytest.warns(None) as record:
|
||||
|
||||
class Reporter3(DjangoObjectType):
|
||||
custom_field = String()
|
||||
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ["first_name", "custom_field", "email"]
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_exclude_fields_exist_on_model():
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match=r"Django model .* does not have a field or attribute named .*",
|
||||
):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["foo"]
|
||||
|
||||
# Don't warn if selecting a custom field
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match=r"Excluding the custom field .* on DjangoObjectType .* has no effect.",
|
||||
):
|
||||
|
||||
class Reporter3(DjangoObjectType):
|
||||
custom_field = String()
|
||||
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["custom_field"]
|
||||
|
||||
# Don't warn on exclude fields
|
||||
with pytest.warns(None) as record:
|
||||
|
||||
class Reporter4(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["email", "first_name"]
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_neither_fields_nor_exclude():
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match=r"Creating a DjangoObjectType without either the `fields` "
|
||||
"or the `exclude` option is deprecated.",
|
||||
):
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
|
||||
class Reporter2(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ["email"]
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
|
||||
class Reporter3(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
exclude = ["email"]
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
|
||||
def custom_enum_name(field):
|
||||
return "CustomEnum{}".format(field.name.title())
|
||||
|
||||
|
||||
class TestDjangoObjectType:
|
||||
@pytest.fixture
|
||||
def PetModel(self):
|
||||
class PetModel(models.Model):
|
||||
kind = models.CharField(choices=(("cat", "Cat"), ("dog", "Dog")))
|
||||
cuteness = models.IntegerField(
|
||||
choices=((1, "Kind of cute"), (2, "Pretty cute"), (3, "OMG SO CUTE!!!"))
|
||||
)
|
||||
|
||||
yield PetModel
|
||||
|
||||
# Clear Django model cache so we don't get warnings when creating the
|
||||
# model multiple times
|
||||
PetModel._meta.apps.all_models = defaultdict(OrderedDict)
|
||||
|
||||
def test_django_objecttype_convert_choices_enum_false(self, PetModel):
|
||||
class Pet(DjangoObjectType):
|
||||
class Meta:
|
||||
model = PetModel
|
||||
convert_choices_to_enum = False
|
||||
fields = "__all__"
|
||||
|
||||
class Query(ObjectType):
|
||||
pet = Field(Pet)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
assert str(schema) == dedent(
|
||||
"""\
|
||||
type Query {
|
||||
pet: Pet
|
||||
}
|
||||
|
||||
type Pet {
|
||||
id: ID!
|
||||
kind: String!
|
||||
cuteness: Int!
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_django_objecttype_convert_choices_enum_list(self, PetModel):
|
||||
class Pet(DjangoObjectType):
|
||||
class Meta:
|
||||
model = PetModel
|
||||
convert_choices_to_enum = ["kind"]
|
||||
fields = "__all__"
|
||||
|
||||
class Query(ObjectType):
|
||||
pet = Field(Pet)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
assert str(schema) == dedent(
|
||||
"""\
|
||||
type Query {
|
||||
pet: Pet
|
||||
}
|
||||
|
||||
type Pet {
|
||||
id: ID!
|
||||
kind: TestsPetModelKindChoices!
|
||||
cuteness: Int!
|
||||
}
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum TestsPetModelKindChoices {
|
||||
\"""Cat\"""
|
||||
CAT
|
||||
|
||||
\"""Dog\"""
|
||||
DOG
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_django_objecttype_convert_choices_enum_empty_list(self, PetModel):
|
||||
class Pet(DjangoObjectType):
|
||||
class Meta:
|
||||
model = PetModel
|
||||
convert_choices_to_enum = []
|
||||
fields = "__all__"
|
||||
|
||||
class Query(ObjectType):
|
||||
pet = Field(Pet)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
assert str(schema) == dedent(
|
||||
"""\
|
||||
type Query {
|
||||
pet: Pet
|
||||
}
|
||||
|
||||
type Pet {
|
||||
id: ID!
|
||||
kind: String!
|
||||
cuteness: Int!
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_django_objecttype_convert_choices_enum_naming_collisions(
|
||||
self, PetModel, graphene_settings
|
||||
):
|
||||
class PetModelKind(DjangoObjectType):
|
||||
class Meta:
|
||||
model = PetModel
|
||||
fields = ["id", "kind"]
|
||||
|
||||
class Query(ObjectType):
|
||||
pet = Field(PetModelKind)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
assert str(schema) == dedent(
|
||||
"""\
|
||||
type Query {
|
||||
pet: PetModelKind
|
||||
}
|
||||
|
||||
type PetModelKind {
|
||||
id: ID!
|
||||
kind: TestsPetModelKindChoices!
|
||||
}
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum TestsPetModelKindChoices {
|
||||
\"""Cat\"""
|
||||
CAT
|
||||
|
||||
\"""Dog\"""
|
||||
DOG
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_django_objecttype_choices_custom_enum_name(
|
||||
self, PetModel, graphene_settings
|
||||
):
|
||||
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME = (
|
||||
"graphene_django.tests.test_types.custom_enum_name"
|
||||
)
|
||||
|
||||
class PetModelKind(DjangoObjectType):
|
||||
class Meta:
|
||||
model = PetModel
|
||||
fields = ["id", "kind"]
|
||||
|
||||
class Query(ObjectType):
|
||||
pet = Field(PetModelKind)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
assert str(schema) == dedent(
|
||||
"""\
|
||||
type Query {
|
||||
pet: PetModelKind
|
||||
}
|
||||
|
||||
type PetModelKind {
|
||||
id: ID!
|
||||
kind: CustomEnumKind!
|
||||
}
|
||||
|
||||
\"""An enumeration.\"""
|
||||
enum CustomEnumKind {
|
||||
\"""Cat\"""
|
||||
CAT
|
||||
|
||||
\"""Dog\"""
|
||||
DOG
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_django_objecttype_name_connection_propagation():
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
name = "CustomReporterName"
|
||||
fields = "__all__"
|
||||
filter_fields = ["email"]
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporter = Node.Field(Reporter)
|
||||
reporters = DjangoFilterConnectionField(Reporter)
|
||||
|
||||
assert Reporter._meta.name == "CustomReporterName"
|
||||
schema = str(Schema(query=Query))
|
||||
|
||||
assert "type CustomReporterName implements Node {" in schema
|
||||
assert "type CustomReporterNameConnection {" in schema
|
||||
assert "type CustomReporterNameEdge {" in schema
|
||||
|
||||
assert "type Reporter implements Node {" not in schema
|
||||
assert "type ReporterConnection {" not in schema
|
||||
assert "type ReporterEdge {" not in schema
|
|
@ -1,88 +0,0 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from django.utils.translation import gettext_lazy
|
||||
from mock import patch
|
||||
|
||||
from ..utils import camelize, get_model_fields, GraphQLTestCase
|
||||
from .models import Film, Reporter
|
||||
from ..utils.testing import graphql_query
|
||||
|
||||
|
||||
def test_get_model_fields_no_duplication():
|
||||
reporter_fields = get_model_fields(Reporter)
|
||||
reporter_name_set = set([field[0] for field in reporter_fields])
|
||||
assert len(reporter_fields) == len(reporter_name_set)
|
||||
|
||||
film_fields = get_model_fields(Film)
|
||||
film_name_set = set([field[0] for field in film_fields])
|
||||
assert len(film_fields) == len(film_name_set)
|
||||
|
||||
|
||||
def test_camelize():
|
||||
assert camelize({}) == {}
|
||||
assert camelize("value_a") == "value_a"
|
||||
assert camelize({"value_a": "value_b"}) == {"valueA": "value_b"}
|
||||
assert camelize({"value_a": ["value_b"]}) == {"valueA": ["value_b"]}
|
||||
assert camelize({"value_a": ["value_b"]}) == {"valueA": ["value_b"]}
|
||||
assert camelize({"nested_field": {"value_a": ["error"], "value_b": ["error"]}}) == {
|
||||
"nestedField": {"valueA": ["error"], "valueB": ["error"]}
|
||||
}
|
||||
assert camelize({"value_a": gettext_lazy("value_b")}) == {"valueA": "value_b"}
|
||||
assert camelize({"value_a": [gettext_lazy("value_b")]}) == {"valueA": ["value_b"]}
|
||||
assert camelize(gettext_lazy("value_a")) == "value_a"
|
||||
assert camelize({gettext_lazy("value_a"): gettext_lazy("value_b")}) == {
|
||||
"valueA": "value_b"
|
||||
}
|
||||
assert camelize({0: {"field_a": ["errors"]}}) == {0: {"fieldA": ["errors"]}}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("graphene_django.utils.testing.Client.post")
|
||||
def test_graphql_test_case_operation_name(post_mock):
|
||||
"""
|
||||
Test that `GraphQLTestCase.query()`'s `operation_name` argument produces an `operationName` field.
|
||||
"""
|
||||
|
||||
class TestClass(GraphQLTestCase):
|
||||
GRAPHQL_SCHEMA = True
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
tc = TestClass()
|
||||
tc._pre_setup()
|
||||
tc.setUpClass()
|
||||
tc.query("query { }", operation_name="QueryName")
|
||||
body = json.loads(post_mock.call_args.args[1])
|
||||
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
|
||||
assert (
|
||||
"operationName",
|
||||
"QueryName",
|
||||
) in body.items(), "Field 'operationName' is not present in the final request."
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("graphene_django.utils.testing.Client.post")
|
||||
def test_graphql_query_case_operation_name(post_mock):
|
||||
graphql_query("query { }", operation_name="QueryName")
|
||||
body = json.loads(post_mock.call_args.args[1])
|
||||
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
|
||||
assert (
|
||||
"operationName",
|
||||
"QueryName",
|
||||
) in body.items(), "Field 'operationName' is not present in the final request."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_query(client):
|
||||
def func(*args, **kwargs):
|
||||
return graphql_query(*args, client=client, **kwargs)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def test_pytest_fixture_usage(client_query):
|
||||
response = client_query("query { test }")
|
||||
content = json.loads(response.content)
|
||||
assert content == {"data": {"test": "Hello World"}}
|
|
@ -1,834 +0,0 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from mock import patch
|
||||
|
||||
from django.db import connection
|
||||
|
||||
from graphene_django.settings import graphene_settings
|
||||
|
||||
from .models import Pet
|
||||
|
||||
try:
|
||||
from urllib import urlencode
|
||||
except ImportError:
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
def url_string(string="/graphql", **url_params):
|
||||
if url_params:
|
||||
string += "?" + urlencode(url_params)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def batch_url_string(**url_params):
|
||||
return url_string("/graphql/batch", **url_params)
|
||||
|
||||
|
||||
def response_json(response):
|
||||
return json.loads(response.content.decode())
|
||||
|
||||
|
||||
j = lambda **kwargs: json.dumps(kwargs)
|
||||
jl = lambda **kwargs: json.dumps([kwargs])
|
||||
|
||||
|
||||
def test_graphiql_is_enabled(client):
|
||||
response = client.get(url_string(), HTTP_ACCEPT="text/html")
|
||||
assert response.status_code == 200
|
||||
assert response["Content-Type"].split(";")[0] == "text/html"
|
||||
|
||||
|
||||
def test_qfactor_graphiql(client):
|
||||
response = client.get(
|
||||
url_string(query="{test}"),
|
||||
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response["Content-Type"].split(";")[0] == "text/html"
|
||||
|
||||
|
||||
def test_qfactor_json(client):
|
||||
response = client.get(
|
||||
url_string(query="{test}"),
|
||||
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response["Content-Type"].split(";")[0] == "application/json"
|
||||
assert response_json(response) == {"data": {"test": "Hello World"}}
|
||||
|
||||
|
||||
def test_allows_get_with_query_param(client):
|
||||
response = client.get(url_string(query="{test}"))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello World"}}
|
||||
|
||||
|
||||
def test_allows_get_with_variable_values(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="query helloWho($who: String){ test(who: $who) }",
|
||||
variables=json.dumps({"who": "Dolly"}),
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_allows_get_with_operation_name(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="""
|
||||
query helloYou { test(who: "You"), ...shared }
|
||||
query helloWorld { test(who: "World"), ...shared }
|
||||
query helloDolly { test(who: "Dolly"), ...shared }
|
||||
fragment shared on QueryRoot {
|
||||
shared: test(who: "Everyone")
|
||||
}
|
||||
""",
|
||||
operationName="helloWorld",
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {
|
||||
"data": {"test": "Hello World", "shared": "Hello Everyone"}
|
||||
}
|
||||
|
||||
|
||||
def test_reports_validation_errors(client):
|
||||
response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [
|
||||
{
|
||||
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
|
||||
"locations": [{"line": 1, "column": 9}],
|
||||
"path": None,
|
||||
},
|
||||
{
|
||||
"message": "Cannot query field 'unknownTwo' on type 'QueryRoot'.",
|
||||
"locations": [{"line": 1, "column": 21}],
|
||||
"path": None,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_errors_when_missing_operation_name(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="""
|
||||
query TestQuery { test }
|
||||
mutation TestMutation { writeTest { test } }
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [
|
||||
{
|
||||
"message": "Must provide operation name if query contains multiple operations.",
|
||||
"locations": None,
|
||||
"path": None,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_errors_when_sending_a_mutation_via_get(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="""
|
||||
mutation TestMutation { writeTest { test } }
|
||||
"""
|
||||
)
|
||||
)
|
||||
assert response.status_code == 405
|
||||
assert response_json(response) == {
|
||||
"errors": [
|
||||
{"message": "Can only perform a mutation operation from a POST request."}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_errors_when_selecting_a_mutation_within_a_get(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="""
|
||||
query TestQuery { test }
|
||||
mutation TestMutation { writeTest { test } }
|
||||
""",
|
||||
operationName="TestMutation",
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 405
|
||||
assert response_json(response) == {
|
||||
"errors": [
|
||||
{"message": "Can only perform a mutation operation from a POST request."}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_allows_mutation_to_exist_within_a_get(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="""
|
||||
query TestQuery { test }
|
||||
mutation TestMutation { writeTest { test } }
|
||||
""",
|
||||
operationName="TestQuery",
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello World"}}
|
||||
|
||||
|
||||
def test_allows_post_with_json_encoding(client):
|
||||
response = client.post(url_string(), j(query="{test}"), "application/json")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello World"}}
|
||||
|
||||
|
||||
def test_batch_allows_post_with_json_encoding(client):
|
||||
response = client.post(
|
||||
batch_url_string(), jl(id=1, query="{test}"), "application/json"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == [
|
||||
{"id": 1, "data": {"test": "Hello World"}, "status": 200}
|
||||
]
|
||||
|
||||
|
||||
def test_batch_fails_if_is_empty(client):
|
||||
response = client.post(batch_url_string(), "[]", "application/json")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "Received an empty list in the batch request."}]
|
||||
}
|
||||
|
||||
|
||||
def test_allows_sending_a_mutation_via_post(client):
|
||||
response = client.post(
|
||||
url_string(),
|
||||
j(query="mutation TestMutation { writeTest { test } }"),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
|
||||
|
||||
|
||||
def test_allows_post_with_url_encoding(client):
|
||||
response = client.post(
|
||||
url_string(),
|
||||
urlencode(dict(query="{test}")),
|
||||
"application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello World"}}
|
||||
|
||||
|
||||
def test_supports_post_json_query_with_string_variables(client):
|
||||
response = client.post(
|
||||
url_string(),
|
||||
j(
|
||||
query="query helloWho($who: String){ test(who: $who) }",
|
||||
variables=json.dumps({"who": "Dolly"}),
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_batch_supports_post_json_query_with_string_variables(client):
|
||||
response = client.post(
|
||||
batch_url_string(),
|
||||
jl(
|
||||
id=1,
|
||||
query="query helloWho($who: String){ test(who: $who) }",
|
||||
variables=json.dumps({"who": "Dolly"}),
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == [
|
||||
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
|
||||
]
|
||||
|
||||
|
||||
def test_supports_post_json_query_with_json_variables(client):
|
||||
response = client.post(
|
||||
url_string(),
|
||||
j(
|
||||
query="query helloWho($who: String){ test(who: $who) }",
|
||||
variables={"who": "Dolly"},
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_batch_supports_post_json_query_with_json_variables(client):
|
||||
response = client.post(
|
||||
batch_url_string(),
|
||||
jl(
|
||||
id=1,
|
||||
query="query helloWho($who: String){ test(who: $who) }",
|
||||
variables={"who": "Dolly"},
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == [
|
||||
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
|
||||
]
|
||||
|
||||
|
||||
def test_supports_post_url_encoded_query_with_string_variables(client):
|
||||
response = client.post(
|
||||
url_string(),
|
||||
urlencode(
|
||||
dict(
|
||||
query="query helloWho($who: String){ test(who: $who) }",
|
||||
variables=json.dumps({"who": "Dolly"}),
|
||||
)
|
||||
),
|
||||
"application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_supports_post_json_quey_with_get_variable_values(client):
|
||||
response = client.post(
|
||||
url_string(variables=json.dumps({"who": "Dolly"})),
|
||||
j(query="query helloWho($who: String){ test(who: $who) }"),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_post_url_encoded_query_with_get_variable_values(client):
|
||||
response = client.post(
|
||||
url_string(variables=json.dumps({"who": "Dolly"})),
|
||||
urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
|
||||
"application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_supports_post_raw_text_query_with_get_variable_values(client):
|
||||
response = client.post(
|
||||
url_string(variables=json.dumps({"who": "Dolly"})),
|
||||
"query helloWho($who: String){ test(who: $who) }",
|
||||
"application/graphql",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
|
||||
|
||||
|
||||
def test_allows_post_with_operation_name(client):
|
||||
response = client.post(
|
||||
url_string(),
|
||||
j(
|
||||
query="""
|
||||
query helloYou { test(who: "You"), ...shared }
|
||||
query helloWorld { test(who: "World"), ...shared }
|
||||
query helloDolly { test(who: "Dolly"), ...shared }
|
||||
fragment shared on QueryRoot {
|
||||
shared: test(who: "Everyone")
|
||||
}
|
||||
""",
|
||||
operationName="helloWorld",
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {
|
||||
"data": {"test": "Hello World", "shared": "Hello Everyone"}
|
||||
}
|
||||
|
||||
|
||||
def test_batch_allows_post_with_operation_name(client):
|
||||
response = client.post(
|
||||
batch_url_string(),
|
||||
jl(
|
||||
id=1,
|
||||
query="""
|
||||
query helloYou { test(who: "You"), ...shared }
|
||||
query helloWorld { test(who: "World"), ...shared }
|
||||
query helloDolly { test(who: "Dolly"), ...shared }
|
||||
fragment shared on QueryRoot {
|
||||
shared: test(who: "Everyone")
|
||||
}
|
||||
""",
|
||||
operationName="helloWorld",
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == [
|
||||
{
|
||||
"id": 1,
|
||||
"data": {"test": "Hello World", "shared": "Hello Everyone"},
|
||||
"status": 200,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_allows_post_with_get_operation_name(client):
|
||||
response = client.post(
|
||||
url_string(operationName="helloWorld"),
|
||||
"""
|
||||
query helloYou { test(who: "You"), ...shared }
|
||||
query helloWorld { test(who: "World"), ...shared }
|
||||
query helloDolly { test(who: "Dolly"), ...shared }
|
||||
fragment shared on QueryRoot {
|
||||
shared: test(who: "Everyone")
|
||||
}
|
||||
""",
|
||||
"application/graphql",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {
|
||||
"data": {"test": "Hello World", "shared": "Hello Everyone"}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.urls("graphene_django.tests.urls_inherited")
|
||||
def test_inherited_class_with_attributes_works(client):
|
||||
inherited_url = "/graphql/inherited/"
|
||||
# Check schema and pretty attributes work
|
||||
response = client.post(url_string(inherited_url, query="{test}"))
|
||||
assert response.content.decode() == (
|
||||
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
|
||||
)
|
||||
|
||||
# Check graphiql works
|
||||
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.urls("graphene_django.tests.urls_pretty")
|
||||
def test_supports_pretty_printing(client):
|
||||
response = client.get(url_string(query="{test}"))
|
||||
|
||||
assert response.content.decode() == (
|
||||
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
|
||||
)
|
||||
|
||||
|
||||
def test_supports_pretty_printing_by_request(client):
|
||||
response = client.get(url_string(query="{test}", pretty="1"))
|
||||
|
||||
assert response.content.decode() == (
|
||||
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
|
||||
)
|
||||
|
||||
|
||||
def test_handles_field_errors_caught_by_graphql(client):
|
||||
response = client.get(url_string(query="{thrower}"))
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {
|
||||
"data": None,
|
||||
"errors": [
|
||||
{
|
||||
"locations": [{"column": 2, "line": 1}],
|
||||
"path": ["thrower"],
|
||||
"message": "Throws!",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_handles_syntax_errors_caught_by_graphql(client):
|
||||
response = client.get(url_string(query="syntaxerror"))
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [
|
||||
{
|
||||
"locations": [{"column": 1, "line": 1}],
|
||||
"message": "Syntax Error: Unexpected Name 'syntaxerror'.",
|
||||
"path": None,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_errors_caused_by_a_lack_of_query(client):
|
||||
response = client.get(url_string())
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "Must provide query string."}]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_not_expected_json_bodies(client):
|
||||
response = client.post(url_string(), "[]", "application/json")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "The received data is not a valid JSON query."}]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_invalid_json_bodies(client):
|
||||
response = client.post(url_string(), "[oh}", "application/json")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "POST body sent invalid JSON."}]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_django_request_error(client, monkeypatch):
|
||||
def mocked_read(*args):
|
||||
raise IOError("foo-bar")
|
||||
|
||||
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
|
||||
|
||||
valid_json = json.dumps(dict(foo="bar"))
|
||||
response = client.post(url_string(), valid_json, "application/json")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
|
||||
|
||||
|
||||
def test_handles_incomplete_json_bodies(client):
|
||||
response = client.post(url_string(), '{"query":', "application/json")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "POST body sent invalid JSON."}]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_plain_post_text(client):
|
||||
response = client.post(
|
||||
url_string(variables=json.dumps({"who": "Dolly"})),
|
||||
"query helloWho($who: String){ test(who: $who) }",
|
||||
"text/plain",
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "Must provide query string."}]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_poorly_formed_variables(client):
|
||||
response = client.get(
|
||||
url_string(
|
||||
query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
|
||||
)
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "Variables are invalid JSON."}]
|
||||
}
|
||||
|
||||
|
||||
def test_handles_unsupported_http_methods(client):
|
||||
response = client.put(url_string(query="{test}"))
|
||||
assert response.status_code == 405
|
||||
assert response["Allow"] == "GET, POST"
|
||||
assert response_json(response) == {
|
||||
"errors": [{"message": "GraphQL only supports GET and POST requests."}]
|
||||
}
|
||||
|
||||
|
||||
def test_passes_request_into_context_request(client):
|
||||
response = client.get(url_string(query="{request}", q="testing"))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"request": "testing"}}
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_atomic_request(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": True, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_atomic_mutation_1(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", True)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_atomic_mutation_2(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_non_atomic(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Enzo"
|
||||
assert pet.age == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
|
||||
)
|
||||
def test_model_form_mutation_multiple_creation_invalid_atomic_request(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petMutation1"]["pet"] is None
|
||||
assert content["data"]["petMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_model_form_mutation_multiple_creation_invalid_non_atomic(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petMutation1"]["pet"] is None
|
||||
assert content["data"]["petMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Enzo"
|
||||
assert pet.age == 0
|
||||
|
||||
|
||||
@patch("graphene_django.utils.utils.transaction.set_rollback")
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
|
||||
)
|
||||
def test_query_errors_atomic_request(set_rollback_mock, client):
|
||||
client.get(url_string(query="force error"))
|
||||
set_rollback_mock.assert_called_once_with(True)
|
||||
|
||||
|
||||
@patch("graphene_django.utils.utils.transaction.set_rollback")
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_query_errors_non_atomic(set_rollback_mock, client):
|
||||
client.get(url_string(query="force error"))
|
||||
set_rollback_mock.assert_not_called()
|
|
@ -1,9 +0,0 @@
|
|||
from graphene_django.types import DjangoObjectType
|
||||
|
||||
from .models import Pet
|
||||
|
||||
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
|
@ -1,8 +0,0 @@
|
|||
from django.conf.urls import url
|
||||
|
||||
from ..views import GraphQLView
|
||||
|
||||
urlpatterns = [
|
||||
url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
|
||||
url(r"^graphql", GraphQLView.as_view(graphiql=True)),
|
||||
]
|
|
@ -1,13 +0,0 @@
|
|||
from django.conf.urls import url
|
||||
|
||||
from ..views import GraphQLView
|
||||
from .schema_view import schema
|
||||
|
||||
|
||||
class CustomGraphQLView(GraphQLView):
|
||||
schema = schema
|
||||
graphiql = True
|
||||
pretty = True
|
||||
|
||||
|
||||
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]
|
|
@ -1,6 +0,0 @@
|
|||
from django.conf.urls import url
|
||||
|
||||
from ..views import GraphQLView
|
||||
from .schema_view import schema
|
||||
|
||||
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]
|
|
@ -1,305 +0,0 @@
|
|||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Type
|
||||
|
||||
import graphene
|
||||
from django.db.models import Model
|
||||
from graphene.relay import Connection, Node
|
||||
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
|
||||
from graphene.types.utils import yank_fields_from_attrs
|
||||
|
||||
from .converter import convert_django_field_with_choices
|
||||
from .registry import Registry, get_global_registry
|
||||
from .settings import graphene_settings
|
||||
from .utils import (
|
||||
DJANGO_FILTER_INSTALLED,
|
||||
camelize,
|
||||
get_model_fields,
|
||||
is_valid_django_model,
|
||||
)
|
||||
|
||||
ALL_FIELDS = "__all__"
|
||||
|
||||
|
||||
def construct_fields(
|
||||
model, registry, only_fields, exclude_fields, convert_choices_to_enum
|
||||
):
|
||||
_model_fields = get_model_fields(model)
|
||||
|
||||
fields = OrderedDict()
|
||||
for name, field in _model_fields:
|
||||
is_not_in_only = (
|
||||
only_fields is not None
|
||||
and only_fields != ALL_FIELDS
|
||||
and name not in only_fields
|
||||
)
|
||||
# is_already_created = name in options.fields
|
||||
is_excluded = (
|
||||
exclude_fields is not None and name in exclude_fields
|
||||
) # or is_already_created
|
||||
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
|
||||
is_no_backref = str(name).endswith("+")
|
||||
if is_not_in_only or is_excluded or is_no_backref:
|
||||
# We skip this field if we specify only_fields and is not
|
||||
# in there. Or when we exclude this field in exclude_fields.
|
||||
# Or when there is no back reference.
|
||||
continue
|
||||
|
||||
_convert_choices_to_enum = convert_choices_to_enum
|
||||
if not isinstance(_convert_choices_to_enum, bool):
|
||||
# then `convert_choices_to_enum` is a list of field names to convert
|
||||
if name in _convert_choices_to_enum:
|
||||
_convert_choices_to_enum = True
|
||||
else:
|
||||
_convert_choices_to_enum = False
|
||||
|
||||
converted = convert_django_field_with_choices(
|
||||
field, registry, convert_choices_to_enum=_convert_choices_to_enum
|
||||
)
|
||||
fields[name] = converted
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def validate_fields(type_, model, fields, only_fields, exclude_fields):
|
||||
# Validate the given fields against the model's fields and custom fields
|
||||
all_field_names = set(fields.keys())
|
||||
only_fields = only_fields if only_fields is not ALL_FIELDS else ()
|
||||
for name in only_fields or ():
|
||||
if name in all_field_names:
|
||||
continue
|
||||
|
||||
if hasattr(model, name):
|
||||
warnings.warn(
|
||||
(
|
||||
'Field name "{field_name}" matches an attribute on Django model "{app_label}.{object_name}" '
|
||||
"but it's not a model field so Graphene cannot determine what type it should be. "
|
||||
'Either define the type of the field on DjangoObjectType "{type_}" or remove it from the "fields" list.'
|
||||
).format(
|
||||
field_name=name,
|
||||
app_label=model._meta.app_label,
|
||||
object_name=model._meta.object_name,
|
||||
type_=type_,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
warnings.warn(
|
||||
(
|
||||
'Field name "{field_name}" doesn\'t exist on Django model "{app_label}.{object_name}". '
|
||||
'Consider removing the field from the "fields" list of DjangoObjectType "{type_}" because it has no effect.'
|
||||
).format(
|
||||
field_name=name,
|
||||
app_label=model._meta.app_label,
|
||||
object_name=model._meta.object_name,
|
||||
type_=type_,
|
||||
)
|
||||
)
|
||||
|
||||
# Validate exclude fields
|
||||
for name in exclude_fields or ():
|
||||
if name in all_field_names:
|
||||
# Field is a custom field
|
||||
warnings.warn(
|
||||
(
|
||||
'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. '
|
||||
'Either remove the custom field or remove the field from the "exclude" list.'
|
||||
).format(field_name=name, type_=type_)
|
||||
)
|
||||
else:
|
||||
if not hasattr(model, name):
|
||||
warnings.warn(
|
||||
(
|
||||
'Django model "{app_label}.{object_name}" does not have a field or attribute named "{field_name}". '
|
||||
'Consider removing the field from the "exclude" list of DjangoObjectType "{type_}" because it has no effect'
|
||||
).format(
|
||||
field_name=name,
|
||||
app_label=model._meta.app_label,
|
||||
object_name=model._meta.object_name,
|
||||
type_=type_,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class DjangoObjectTypeOptions(ObjectTypeOptions):
|
||||
model = None # type: Model
|
||||
registry = None # type: Registry
|
||||
connection = None # type: Type[Connection]
|
||||
|
||||
filter_fields = ()
|
||||
filterset_class = None
|
||||
|
||||
|
||||
class DjangoObjectType(ObjectType):
|
||||
@classmethod
|
||||
def __init_subclass_with_meta__(
|
||||
cls,
|
||||
model=None,
|
||||
registry=None,
|
||||
skip_registry=False,
|
||||
only_fields=None, # deprecated in favour of `fields`
|
||||
fields=None,
|
||||
exclude_fields=None, # deprecated in favour of `exclude`
|
||||
exclude=None,
|
||||
filter_fields=None,
|
||||
filterset_class=None,
|
||||
connection=None,
|
||||
connection_class=None,
|
||||
use_connection=None,
|
||||
interfaces=(),
|
||||
convert_choices_to_enum=True,
|
||||
_meta=None,
|
||||
**options
|
||||
):
|
||||
assert is_valid_django_model(model), (
|
||||
'You need to pass a valid Django Model in {}.Meta, received "{}".'
|
||||
).format(cls.__name__, model)
|
||||
|
||||
if not registry:
|
||||
registry = get_global_registry()
|
||||
|
||||
assert isinstance(registry, Registry), (
|
||||
"The attribute registry in {} needs to be an instance of "
|
||||
'Registry, received "{}".'
|
||||
).format(cls.__name__, registry)
|
||||
|
||||
if filter_fields and filterset_class:
|
||||
raise Exception("Can't set both filter_fields and filterset_class")
|
||||
|
||||
if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class):
|
||||
raise Exception(
|
||||
(
|
||||
"Can only set filter_fields or filterset_class if "
|
||||
"Django-Filter is installed"
|
||||
)
|
||||
)
|
||||
|
||||
assert not (fields and exclude), (
|
||||
"Cannot set both 'fields' and 'exclude' options on "
|
||||
"DjangoObjectType {class_name}.".format(class_name=cls.__name__)
|
||||
)
|
||||
|
||||
# Alias only_fields -> fields
|
||||
if only_fields and fields:
|
||||
raise Exception("Can't set both only_fields and fields")
|
||||
if only_fields:
|
||||
warnings.warn(
|
||||
"Defining `only_fields` is deprecated in favour of `fields`.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
fields = only_fields
|
||||
if fields and fields != ALL_FIELDS and not isinstance(fields, (list, tuple)):
|
||||
raise TypeError(
|
||||
'The `fields` option must be a list or tuple or "__all__". '
|
||||
"Got %s." % type(fields).__name__
|
||||
)
|
||||
|
||||
# Alias exclude_fields -> exclude
|
||||
if exclude_fields and exclude:
|
||||
raise Exception("Can't set both exclude_fields and exclude")
|
||||
if exclude_fields:
|
||||
warnings.warn(
|
||||
"Defining `exclude_fields` is deprecated in favour of `exclude`.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
exclude = exclude_fields
|
||||
if exclude and not isinstance(exclude, (list, tuple)):
|
||||
raise TypeError(
|
||||
"The `exclude` option must be a list or tuple. Got %s."
|
||||
% type(exclude).__name__
|
||||
)
|
||||
|
||||
if fields is None and exclude is None:
|
||||
warnings.warn(
|
||||
"Creating a DjangoObjectType without either the `fields` "
|
||||
"or the `exclude` option is deprecated. Add an explicit `fields "
|
||||
"= '__all__'` option on DjangoObjectType {class_name} to use all "
|
||||
"fields".format(class_name=cls.__name__,),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
django_fields = yank_fields_from_attrs(
|
||||
construct_fields(model, registry, fields, exclude, convert_choices_to_enum),
|
||||
_as=graphene.Field,
|
||||
)
|
||||
|
||||
if use_connection is None and interfaces:
|
||||
use_connection = any(
|
||||
(issubclass(interface, Node) for interface in interfaces)
|
||||
)
|
||||
|
||||
if use_connection and not connection:
|
||||
# We create the connection automatically
|
||||
if not connection_class:
|
||||
connection_class = Connection
|
||||
|
||||
connection = connection_class.create_type(
|
||||
"{}Connection".format(options.get("name") or cls.__name__), node=cls
|
||||
)
|
||||
|
||||
if connection is not None:
|
||||
assert issubclass(connection, Connection), (
|
||||
"The connection must be a Connection. Received {}"
|
||||
).format(connection.__name__)
|
||||
|
||||
if not _meta:
|
||||
_meta = DjangoObjectTypeOptions(cls)
|
||||
|
||||
_meta.model = model
|
||||
_meta.registry = registry
|
||||
_meta.filter_fields = filter_fields
|
||||
_meta.filterset_class = filterset_class
|
||||
_meta.fields = django_fields
|
||||
_meta.connection = connection
|
||||
|
||||
super(DjangoObjectType, cls).__init_subclass_with_meta__(
|
||||
_meta=_meta, interfaces=interfaces, **options
|
||||
)
|
||||
|
||||
# Validate fields
|
||||
validate_fields(cls, model, _meta.fields, fields, exclude)
|
||||
|
||||
if not skip_registry:
|
||||
registry.register(cls)
|
||||
|
||||
def resolve_id(self, info):
|
||||
return self.pk
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, root, info):
|
||||
if isinstance(root, cls):
|
||||
return True
|
||||
if not is_valid_django_model(root.__class__):
|
||||
raise Exception(('Received incompatible instance "{}".').format(root))
|
||||
|
||||
if cls._meta.model._meta.proxy:
|
||||
model = root._meta.model
|
||||
else:
|
||||
model = root._meta.model._meta.concrete_model
|
||||
|
||||
return model == cls._meta.model
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
return queryset
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, info, id):
|
||||
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
||||
try:
|
||||
return queryset.get(pk=id)
|
||||
except cls._meta.model.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
class ErrorType(ObjectType):
|
||||
field = graphene.String(required=True)
|
||||
messages = graphene.List(graphene.NonNull(graphene.String), required=True)
|
||||
|
||||
@classmethod
|
||||
def from_errors(cls, errors):
|
||||
data = camelize(errors) if graphene_settings.CAMELCASE_ERRORS else errors
|
||||
return [cls(field=key, messages=value) for key, value in data.items()]
|
|
@ -1,19 +0,0 @@
|
|||
from .testing import GraphQLTestCase
|
||||
from .utils import (
|
||||
DJANGO_FILTER_INSTALLED,
|
||||
camelize,
|
||||
get_model_fields,
|
||||
get_reverse_fields,
|
||||
is_valid_django_model,
|
||||
maybe_queryset,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DJANGO_FILTER_INSTALLED",
|
||||
"get_reverse_fields",
|
||||
"maybe_queryset",
|
||||
"get_model_fields",
|
||||
"camelize",
|
||||
"is_valid_django_model",
|
||||
"GraphQLTestCase",
|
||||
]
|
|
@ -1,6 +0,0 @@
|
|||
import re
|
||||
from text_unidecode import unidecode
|
||||
|
||||
|
||||
def to_const(string):
|
||||
return re.sub(r"[\W|^]+", "_", unidecode(string)).upper()
|
|
@ -1,153 +0,0 @@
|
|||
import json
|
||||
import warnings
|
||||
|
||||
from django.test import Client, TestCase, TransactionTestCase
|
||||
|
||||
DEFAULT_GRAPHQL_URL = "/graphql/"
|
||||
|
||||
|
||||
def graphql_query(
|
||||
query,
|
||||
operation_name=None,
|
||||
input_data=None,
|
||||
variables=None,
|
||||
headers=None,
|
||||
client=None,
|
||||
graphql_url=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
query (string) - GraphQL query to run
|
||||
operation_name (string) - If the query is a mutation or named query, you must
|
||||
supply the op_name. For annon queries ("{ ... }"),
|
||||
should be None (default).
|
||||
input_data (dict) - If provided, the $input variable in GraphQL will be set
|
||||
to this value. If both ``input_data`` and ``variables``,
|
||||
are provided, the ``input`` field in the ``variables``
|
||||
dict will be overwritten with this value.
|
||||
variables (dict) - If provided, the "variables" field in GraphQL will be
|
||||
set to this value.
|
||||
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
|
||||
will be set to this value. Keys should be prepended with
|
||||
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
|
||||
use "HTTP_AUTHORIZATION" as the key).
|
||||
client (django.test.Client) - Test client. Defaults to django.test.Client.
|
||||
graphql_url (string) - URL to graphql endpoint. Defaults to "/graphql".
|
||||
|
||||
Returns:
|
||||
Response object from client
|
||||
"""
|
||||
if client is None:
|
||||
client = Client()
|
||||
if not graphql_url:
|
||||
graphql_url = DEFAULT_GRAPHQL_URL
|
||||
|
||||
body = {"query": query}
|
||||
if operation_name:
|
||||
body["operationName"] = operation_name
|
||||
if variables:
|
||||
body["variables"] = variables
|
||||
if input_data:
|
||||
if "variables" in body:
|
||||
body["variables"]["input"] = input_data
|
||||
else:
|
||||
body["variables"] = {"input": input_data}
|
||||
if headers:
|
||||
resp = client.post(
|
||||
graphql_url, json.dumps(body), content_type="application/json", **headers
|
||||
)
|
||||
else:
|
||||
resp = client.post(
|
||||
graphql_url, json.dumps(body), content_type="application/json"
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class GraphQLTestMixin(object):
|
||||
"""
|
||||
Based on: https://www.sam.today/blog/testing-graphql-with-graphene-django/
|
||||
"""
|
||||
|
||||
# URL to graphql endpoint
|
||||
GRAPHQL_URL = DEFAULT_GRAPHQL_URL
|
||||
|
||||
def query(
|
||||
self, query, operation_name=None, input_data=None, variables=None, headers=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
query (string) - GraphQL query to run
|
||||
operation_name (string) - If the query is a mutation or named query, you must
|
||||
supply the op_name. For annon queries ("{ ... }"),
|
||||
should be None (default).
|
||||
input_data (dict) - If provided, the $input variable in GraphQL will be set
|
||||
to this value. If both ``input_data`` and ``variables``,
|
||||
are provided, the ``input`` field in the ``variables``
|
||||
dict will be overwritten with this value.
|
||||
variables (dict) - If provided, the "variables" field in GraphQL will be
|
||||
set to this value.
|
||||
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
|
||||
will be set to this value. Keys should be prepended with
|
||||
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
|
||||
use "HTTP_AUTHORIZATION" as the key).
|
||||
|
||||
Returns:
|
||||
Response object from client
|
||||
"""
|
||||
return graphql_query(
|
||||
query,
|
||||
operation_name=operation_name,
|
||||
input_data=input_data,
|
||||
variables=variables,
|
||||
headers=headers,
|
||||
client=self.client,
|
||||
graphql_url=self.GRAPHQL_URL,
|
||||
)
|
||||
|
||||
@property
|
||||
def _client(self):
|
||||
pass
|
||||
|
||||
@_client.getter
|
||||
def _client(self):
|
||||
warnings.warn(
|
||||
"Using `_client` is deprecated in favour of `client`.",
|
||||
PendingDeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.client
|
||||
|
||||
@_client.setter
|
||||
def _client(self, client):
|
||||
warnings.warn(
|
||||
"Using `_client` is deprecated in favour of `client`.",
|
||||
PendingDeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.client = client
|
||||
|
||||
def assertResponseNoErrors(self, resp, msg=None):
|
||||
"""
|
||||
Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`,
|
||||
the call was fine.
|
||||
:resp HttpResponse: Response
|
||||
"""
|
||||
content = json.loads(resp.content)
|
||||
self.assertEqual(resp.status_code, 200, msg or content)
|
||||
self.assertNotIn("errors", list(content.keys()), msg or content)
|
||||
|
||||
def assertResponseHasErrors(self, resp, msg=None):
|
||||
"""
|
||||
Assert that the call was failing. Take care: Even with errors, GraphQL returns status 200!
|
||||
:resp HttpResponse: Response
|
||||
"""
|
||||
content = json.loads(resp.content)
|
||||
self.assertIn("errors", list(content.keys()), msg or content)
|
||||
|
||||
|
||||
class GraphQLTestCase(GraphQLTestMixin, TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class GraphQLTransactionTestCase(GraphQLTestMixin, TransactionTestCase):
|
||||
pass
|
|
@ -1,9 +0,0 @@
|
|||
from ..str_converters import to_const
|
||||
|
||||
|
||||
def test_to_const():
|
||||
assert to_const('snakes $1. on a "#plane') == "SNAKES_1_ON_A_PLANE"
|
||||
|
||||
|
||||
def test_to_const_unicode():
|
||||
assert to_const(u"Skoða þetta unicode stöff") == "SKODA_THETTA_UNICODE_STOFF"
|
|
@ -1,45 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from .. import GraphQLTestCase
|
||||
from ...tests.test_types import with_local_registry
|
||||
from django.test import Client
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_graphql_test_case_deprecated_client_getter():
|
||||
"""
|
||||
`GraphQLTestCase._client`' getter should raise pending deprecation warning.
|
||||
"""
|
||||
|
||||
class TestClass(GraphQLTestCase):
|
||||
GRAPHQL_SCHEMA = True
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
tc = TestClass()
|
||||
tc._pre_setup()
|
||||
tc.setUpClass()
|
||||
|
||||
with pytest.warns(PendingDeprecationWarning):
|
||||
tc._client
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_graphql_test_case_deprecated_client_setter():
|
||||
"""
|
||||
`GraphQLTestCase._client`' setter should raise pending deprecation warning.
|
||||
"""
|
||||
|
||||
class TestClass(GraphQLTestCase):
|
||||
GRAPHQL_SCHEMA = True
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
tc = TestClass()
|
||||
tc._pre_setup()
|
||||
tc.setUpClass()
|
||||
|
||||
with pytest.warns(PendingDeprecationWarning):
|
||||
tc._client = Client()
|
|
@ -1,107 +0,0 @@
|
|||
import inspect
|
||||
|
||||
from django.db import connection, models, transaction
|
||||
from django.db.models.manager import Manager
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import Promise
|
||||
|
||||
from graphene.utils.str_converters import to_camel_case
|
||||
|
||||
try:
|
||||
import django_filters # noqa
|
||||
|
||||
DJANGO_FILTER_INSTALLED = True
|
||||
except ImportError:
|
||||
DJANGO_FILTER_INSTALLED = False
|
||||
|
||||
|
||||
def isiterable(value):
|
||||
try:
|
||||
iter(value)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _camelize_django_str(s):
|
||||
if isinstance(s, Promise):
|
||||
s = force_str(s)
|
||||
return to_camel_case(s) if isinstance(s, str) else s
|
||||
|
||||
|
||||
def camelize(data):
|
||||
if isinstance(data, dict):
|
||||
return {_camelize_django_str(k): camelize(v) for k, v in data.items()}
|
||||
if isiterable(data) and not isinstance(data, (str, Promise)):
|
||||
return [camelize(d) for d in data]
|
||||
return data
|
||||
|
||||
|
||||
def get_reverse_fields(model, local_field_names):
|
||||
for name, attr in model.__dict__.items():
|
||||
# Don't duplicate any local fields
|
||||
if name in local_field_names:
|
||||
continue
|
||||
|
||||
# "rel" for FK and M2M relations and "related" for O2O Relations
|
||||
related = getattr(attr, "rel", None) or getattr(attr, "related", None)
|
||||
if isinstance(related, models.ManyToOneRel):
|
||||
yield (name, related)
|
||||
elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
|
||||
yield (name, related)
|
||||
|
||||
|
||||
def maybe_queryset(value):
|
||||
if isinstance(value, Manager):
|
||||
value = value.get_queryset()
|
||||
return value
|
||||
|
||||
|
||||
def get_model_fields(model):
|
||||
local_fields = [
|
||||
(field.name, field)
|
||||
for field in sorted(
|
||||
list(model._meta.fields) + list(model._meta.local_many_to_many)
|
||||
)
|
||||
]
|
||||
|
||||
# Make sure we don't duplicate local fields with "reverse" version
|
||||
local_field_names = [field[0] for field in local_fields]
|
||||
reverse_fields = get_reverse_fields(model, local_field_names)
|
||||
|
||||
all_fields = local_fields + list(reverse_fields)
|
||||
|
||||
return all_fields
|
||||
|
||||
|
||||
def is_valid_django_model(model):
|
||||
return inspect.isclass(model) and issubclass(model, models.Model)
|
||||
|
||||
|
||||
def import_single_dispatch():
|
||||
try:
|
||||
from functools import singledispatch
|
||||
except ImportError:
|
||||
singledispatch = None
|
||||
|
||||
if not singledispatch:
|
||||
try:
|
||||
from singledispatch import singledispatch
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if not singledispatch:
|
||||
raise Exception(
|
||||
"It seems your python version does not include "
|
||||
"functools.singledispatch. Please install the 'singledispatch' "
|
||||
"package. More information here: "
|
||||
"https://pypi.python.org/pypi/singledispatch"
|
||||
)
|
||||
|
||||
return singledispatch
|
||||
|
||||
|
||||
def set_rollback():
|
||||
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
|
||||
if atomic_requests and connection.in_atomic_block:
|
||||
transaction.set_rollback(True)
|
|
@ -1,398 +0,0 @@
|
|||
import inspect
|
||||
import json
|
||||
import re
|
||||
|
||||
from django.db import connection, transaction
|
||||
from django.http import HttpResponse, HttpResponseNotAllowed
|
||||
from django.http.response import HttpResponseBadRequest
|
||||
from django.shortcuts import render
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
from django.views.generic import View
|
||||
from graphql import OperationType, get_operation_ast, parse, validate
|
||||
from graphql.error import GraphQLError
|
||||
from graphql.error import format_error as format_graphql_error
|
||||
from graphql.execution import ExecutionResult
|
||||
|
||||
from graphene import Schema
|
||||
from graphql.execution.middleware import MiddlewareManager
|
||||
|
||||
from graphene_django.constants import MUTATION_ERRORS_FLAG
|
||||
from graphene_django.utils.utils import set_rollback
|
||||
|
||||
from .settings import graphene_settings
|
||||
|
||||
|
||||
class HttpError(Exception):
|
||||
def __init__(self, response, message=None, *args, **kwargs):
|
||||
self.response = response
|
||||
self.message = message = message or response.content.decode()
|
||||
super(HttpError, self).__init__(message, *args, **kwargs)
|
||||
|
||||
|
||||
def get_accepted_content_types(request):
|
||||
def qualify(x):
|
||||
parts = x.split(";", 1)
|
||||
if len(parts) == 2:
|
||||
match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
|
||||
if match:
|
||||
return parts[0].strip(), float(match.group(2))
|
||||
return parts[0].strip(), 1
|
||||
|
||||
raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
|
||||
qualified_content_types = map(qualify, raw_content_types)
|
||||
return list(
|
||||
x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
|
||||
|
||||
def instantiate_middleware(middlewares):
|
||||
for middleware in middlewares:
|
||||
if inspect.isclass(middleware):
|
||||
yield middleware()
|
||||
continue
|
||||
yield middleware
|
||||
|
||||
|
||||
class GraphQLView(View):
|
||||
graphiql_template = "graphene/graphiql.html"
|
||||
|
||||
# Polyfill for window.fetch.
|
||||
whatwg_fetch_version = "3.6.2"
|
||||
whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c="
|
||||
|
||||
# React and ReactDOM.
|
||||
react_version = "17.0.2"
|
||||
react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8="
|
||||
react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0="
|
||||
|
||||
# The GraphiQL React app.
|
||||
graphiql_version = "1.4.1" # "1.0.3"
|
||||
graphiql_sri = "sha256-JUMkXBQWZMfJ7fGEsTXalxVA10lzKOS9loXdLjwZKi4=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
|
||||
graphiql_css_sri = "sha256-Md3vdR7PDzWyo/aGfsFVF4tvS5/eAUWuIsg9QHUusCY=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="
|
||||
|
||||
# The websocket transport library for subscriptions.
|
||||
subscriptions_transport_ws_version = "0.9.18"
|
||||
subscriptions_transport_ws_sri = (
|
||||
"sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw="
|
||||
)
|
||||
|
||||
schema = None
|
||||
graphiql = False
|
||||
middleware = None
|
||||
root_value = None
|
||||
pretty = False
|
||||
batch = False
|
||||
subscription_path = None
|
||||
execution_context_class = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema=None,
|
||||
middleware=None,
|
||||
root_value=None,
|
||||
graphiql=False,
|
||||
pretty=False,
|
||||
batch=False,
|
||||
subscription_path=None,
|
||||
execution_context_class=None,
|
||||
):
|
||||
if not schema:
|
||||
schema = graphene_settings.SCHEMA
|
||||
|
||||
if middleware is None:
|
||||
middleware = graphene_settings.MIDDLEWARE
|
||||
|
||||
self.schema = self.schema or schema
|
||||
if middleware is not None:
|
||||
if isinstance(middleware, MiddlewareManager):
|
||||
self.middleware = middleware
|
||||
else:
|
||||
self.middleware = list(instantiate_middleware(middleware))
|
||||
self.root_value = root_value
|
||||
self.pretty = self.pretty or pretty
|
||||
self.graphiql = self.graphiql or graphiql
|
||||
self.batch = self.batch or batch
|
||||
self.execution_context_class = execution_context_class
|
||||
if subscription_path is None:
|
||||
self.subscription_path = graphene_settings.SUBSCRIPTION_PATH
|
||||
|
||||
assert isinstance(
|
||||
self.schema, Schema
|
||||
), "A Schema is required to be provided to GraphQLView."
|
||||
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def get_root_value(self, request):
|
||||
return self.root_value
|
||||
|
||||
def get_middleware(self, request):
|
||||
return self.middleware
|
||||
|
||||
def get_context(self, request):
|
||||
return request
|
||||
|
||||
@method_decorator(ensure_csrf_cookie)
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
if request.method.lower() not in ("get", "post"):
|
||||
raise HttpError(
|
||||
HttpResponseNotAllowed(
|
||||
["GET", "POST"], "GraphQL only supports GET and POST requests."
|
||||
)
|
||||
)
|
||||
|
||||
data = self.parse_body(request)
|
||||
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
|
||||
|
||||
if show_graphiql:
|
||||
return self.render_graphiql(
|
||||
request,
|
||||
# Dependency parameters.
|
||||
whatwg_fetch_version=self.whatwg_fetch_version,
|
||||
whatwg_fetch_sri=self.whatwg_fetch_sri,
|
||||
react_version=self.react_version,
|
||||
react_sri=self.react_sri,
|
||||
react_dom_sri=self.react_dom_sri,
|
||||
graphiql_version=self.graphiql_version,
|
||||
graphiql_sri=self.graphiql_sri,
|
||||
graphiql_css_sri=self.graphiql_css_sri,
|
||||
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
|
||||
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
|
||||
# The SUBSCRIPTION_PATH setting.
|
||||
subscription_path=self.subscription_path,
|
||||
# GraphiQL headers tab,
|
||||
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
|
||||
)
|
||||
|
||||
if self.batch:
|
||||
responses = [self.get_response(request, entry) for entry in data]
|
||||
result = "[{}]".format(
|
||||
",".join([response[0] for response in responses])
|
||||
)
|
||||
status_code = (
|
||||
responses
|
||||
and max(responses, key=lambda response: response[1])[1]
|
||||
or 200
|
||||
)
|
||||
else:
|
||||
result, status_code = self.get_response(request, data, show_graphiql)
|
||||
|
||||
return HttpResponse(
|
||||
status=status_code, content=result, content_type="application/json"
|
||||
)
|
||||
|
||||
except HttpError as e:
|
||||
response = e.response
|
||||
response["Content-Type"] = "application/json"
|
||||
response.content = self.json_encode(
|
||||
request, {"errors": [self.format_error(e)]}
|
||||
)
|
||||
return response
|
||||
|
||||
def get_response(self, request, data, show_graphiql=False):
|
||||
query, variables, operation_name, id = self.get_graphql_params(request, data)
|
||||
|
||||
execution_result = self.execute_graphql_request(
|
||||
request, data, query, variables, operation_name, show_graphiql
|
||||
)
|
||||
|
||||
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||
set_rollback()
|
||||
|
||||
status_code = 200
|
||||
if execution_result:
|
||||
response = {}
|
||||
|
||||
if execution_result.errors:
|
||||
set_rollback()
|
||||
response["errors"] = [
|
||||
self.format_error(e) for e in execution_result.errors
|
||||
]
|
||||
|
||||
if execution_result.errors and any(
|
||||
not getattr(e, "path", None) for e in execution_result.errors
|
||||
):
|
||||
status_code = 400
|
||||
else:
|
||||
response["data"] = execution_result.data
|
||||
|
||||
if self.batch:
|
||||
response["id"] = id
|
||||
response["status"] = status_code
|
||||
|
||||
result = self.json_encode(request, response, pretty=show_graphiql)
|
||||
else:
|
||||
result = None
|
||||
|
||||
return result, status_code
|
||||
|
||||
def render_graphiql(self, request, **data):
|
||||
return render(request, self.graphiql_template, data)
|
||||
|
||||
def json_encode(self, request, d, pretty=False):
|
||||
if not (self.pretty or pretty) and not request.GET.get("pretty"):
|
||||
return json.dumps(d, separators=(",", ":"))
|
||||
|
||||
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
|
||||
|
||||
def parse_body(self, request):
|
||||
content_type = self.get_content_type(request)
|
||||
|
||||
if content_type == "application/graphql":
|
||||
return {"query": request.body.decode()}
|
||||
|
||||
elif content_type == "application/json":
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
body = request.body.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise HttpError(HttpResponseBadRequest(str(e)))
|
||||
|
||||
try:
|
||||
request_json = json.loads(body)
|
||||
if self.batch:
|
||||
assert isinstance(request_json, list), (
|
||||
"Batch requests should receive a list, but received {}."
|
||||
).format(repr(request_json))
|
||||
assert (
|
||||
len(request_json) > 0
|
||||
), "Received an empty list in the batch request."
|
||||
else:
|
||||
assert isinstance(
|
||||
request_json, dict
|
||||
), "The received data is not a valid JSON query."
|
||||
return request_json
|
||||
except AssertionError as e:
|
||||
raise HttpError(HttpResponseBadRequest(str(e)))
|
||||
except (TypeError, ValueError):
|
||||
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
|
||||
|
||||
elif content_type in [
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
]:
|
||||
return request.POST
|
||||
|
||||
return {}
|
||||
|
||||
def execute_graphql_request(
|
||||
self, request, data, query, variables, operation_name, show_graphiql=False
|
||||
):
|
||||
if not query:
|
||||
if show_graphiql:
|
||||
return None
|
||||
raise HttpError(HttpResponseBadRequest("Must provide query string."))
|
||||
|
||||
try:
|
||||
document = parse(query)
|
||||
except Exception as e:
|
||||
return ExecutionResult(errors=[e])
|
||||
|
||||
if request.method.lower() == "get":
|
||||
operation_ast = get_operation_ast(document, operation_name)
|
||||
if operation_ast and operation_ast.operation != OperationType.QUERY:
|
||||
if show_graphiql:
|
||||
return None
|
||||
|
||||
raise HttpError(
|
||||
HttpResponseNotAllowed(
|
||||
["POST"],
|
||||
"Can only perform a {} operation from a POST request.".format(
|
||||
operation_ast.operation.value
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
validation_errors = validate(self.schema.graphql_schema, document)
|
||||
if validation_errors:
|
||||
return ExecutionResult(data=None, errors=validation_errors)
|
||||
|
||||
try:
|
||||
extra_options = {}
|
||||
if self.execution_context_class:
|
||||
extra_options["execution_context_class"] = self.execution_context_class
|
||||
|
||||
options = {
|
||||
"source": query,
|
||||
"root_value": self.get_root_value(request),
|
||||
"variable_values": variables,
|
||||
"operation_name": operation_name,
|
||||
"context_value": self.get_context(request),
|
||||
"middleware": self.get_middleware(request),
|
||||
}
|
||||
options.update(extra_options)
|
||||
|
||||
operation_ast = get_operation_ast(document, operation_name)
|
||||
if (
|
||||
operation_ast
|
||||
and operation_ast.operation == OperationType.MUTATION
|
||||
and (
|
||||
graphene_settings.ATOMIC_MUTATIONS is True
|
||||
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
|
||||
)
|
||||
):
|
||||
with transaction.atomic():
|
||||
result = self.schema.execute(**options)
|
||||
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||
transaction.set_rollback(True)
|
||||
return result
|
||||
|
||||
return self.schema.execute(**options)
|
||||
except Exception as e:
|
||||
return ExecutionResult(errors=[e])
|
||||
|
||||
@classmethod
|
||||
def can_display_graphiql(cls, request, data):
|
||||
raw = "raw" in request.GET or "raw" in data
|
||||
return not raw and cls.request_wants_html(request)
|
||||
|
||||
@classmethod
|
||||
def request_wants_html(cls, request):
|
||||
accepted = get_accepted_content_types(request)
|
||||
accepted_length = len(accepted)
|
||||
# the list will be ordered in preferred first - so we have to make
|
||||
# sure the most preferred gets the highest number
|
||||
html_priority = (
|
||||
accepted_length - accepted.index("text/html")
|
||||
if "text/html" in accepted
|
||||
else 0
|
||||
)
|
||||
json_priority = (
|
||||
accepted_length - accepted.index("application/json")
|
||||
if "application/json" in accepted
|
||||
else 0
|
||||
)
|
||||
|
||||
return html_priority > json_priority
|
||||
|
||||
@staticmethod
|
||||
def get_graphql_params(request, data):
|
||||
query = request.GET.get("query") or data.get("query")
|
||||
variables = request.GET.get("variables") or data.get("variables")
|
||||
id = request.GET.get("id") or data.get("id")
|
||||
|
||||
if variables and isinstance(variables, str):
|
||||
try:
|
||||
variables = json.loads(variables)
|
||||
except Exception:
|
||||
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
|
||||
|
||||
operation_name = request.GET.get("operationName") or data.get("operationName")
|
||||
if operation_name == "null":
|
||||
operation_name = None
|
||||
|
||||
return query, variables, operation_name, id
|
||||
|
||||
@staticmethod
|
||||
def format_error(error):
|
||||
if isinstance(error, GraphQLError):
|
||||
return format_graphql_error(error)
|
||||
|
||||
return {"message": str(error)}
|
||||
|
||||
@staticmethod
|
||||
def get_content_type(request):
|
||||
meta = request.META
|
||||
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
||||
return content_type.split(";", 1)[0].lower()
|
Loading…
Reference in New Issue
Block a user