Delete graphene_django directory

This commit is contained in:
removedporn 2021-08-03 18:20:18 +08:00 committed by GitHub
parent e7f7d8da07
commit 00e27d6a66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
92 changed files with 0 additions and 12458 deletions

View File

@ -1,11 +0,0 @@
from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType
__version__ = "3.0.0b7"
__all__ = [
"__version__",
"DjangoObjectType",
"DjangoListField",
"DjangoConnectionField",
]

View File

@ -1,25 +0,0 @@
class MissingType(object):
def __init__(self, *args, **kwargs):
pass
try:
# Postgres fields are only available in Django with psycopg2 installed
# and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import (
IntegerRangeField,
ArrayField,
HStoreField,
JSONField as PGJSONField,
RangeField,
)
except ImportError:
IntegerRangeField, ArrayField, HStoreField, PGJSONField, RangeField = (
MissingType,
) * 5
try:
# JSONField is only available from Django 3.1
from django.db.models import JSONField
except ImportError:
JSONField = MissingType

View File

@ -1,18 +0,0 @@
import pytest
from graphene_django.settings import graphene_settings as gsettings
from .registry import reset_global_registry
@pytest.fixture(autouse=True)
def reset_registry_fixture(db):
yield None
reset_global_registry()
@pytest.fixture()
def graphene_settings():
settings = dict(gsettings.__dict__)
yield gsettings
gsettings.__dict__ = settings

View File

@ -1 +0,0 @@
MUTATION_ERRORS_FLAG = "graphene_mutation_has_errors"

View File

@ -1,356 +0,0 @@
from collections import OrderedDict
from functools import singledispatch, wraps
from django.db import models
from django.utils.encoding import force_str
from django.utils.functional import Promise
from django.utils.module_loading import import_string
from graphene import (
ID,
UUID,
Boolean,
Date,
DateTime,
Dynamic,
Enum,
Field,
Float,
Int,
List,
NonNull,
String,
Time,
Decimal,
)
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case
from graphql import GraphQLError, assert_valid_name
from graphql.pyutils import register_description
from .compat import ArrayField, HStoreField, JSONField, PGJSONField, RangeField
from .fields import DjangoListField, DjangoConnectionField
from .settings import graphene_settings
from .utils.str_converters import to_const
class BlankValueField(Field):
def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver
# create custom resolver
def blank_field_wrapper(func):
@wraps(func)
def wrapped_resolver(*args, **kwargs):
return_value = func(*args, **kwargs)
if return_value == "":
return None
return return_value
return wrapped_resolver
return blank_field_wrapper(resolver)
def convert_choice_name(name):
name = to_const(force_str(name))
try:
assert_valid_name(name)
except GraphQLError:
name = "A_%s" % name
return name
def get_choices(choices):
converted_names = []
if isinstance(choices, OrderedDict):
choices = choices.items()
for value, help_text in choices:
if isinstance(help_text, (tuple, list)):
for choice in get_choices(help_text):
yield choice
else:
name = convert_choice_name(value)
while name in converted_names:
name += "_" + str(len(converted_names))
converted_names.append(name)
description = str(
help_text
) # TODO: translatable description: https://github.com/graphql-python/graphql-core-next/issues/58
yield name, value, description
def convert_choices_to_named_enum_with_descriptions(name, choices):
choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object):
@property
def description(self):
return str(named_choices_descriptions[self.name])
return_type = Enum(name, list(named_choices), type=EnumWithDescriptionsType)
return return_type
def generate_enum_name(django_model_meta, field):
if graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME:
# Try and import custom function
custom_func = import_string(
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME
)
name = custom_func(field)
elif graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING is True:
name = to_camel_case("{}_{}".format(django_model_meta.object_name, field.name))
else:
name = "{app_label}{object_name}{field_name}Choices".format(
app_label=to_camel_case(django_model_meta.app_label.title()),
object_name=django_model_meta.object_name,
field_name=to_camel_case(field.name.title()),
)
return name
def convert_choice_field_to_enum(field, name=None):
if name is None:
name = generate_enum_name(field.model._meta, field)
choices = field.choices
return convert_choices_to_named_enum_with_descriptions(name, choices)
def convert_django_field_with_choices(
field, registry=None, convert_choices_to_enum=True
):
if registry is not None:
converted = registry.get_converted_field(field)
if converted:
return converted
choices = getattr(field, "choices", None)
if choices and convert_choices_to_enum:
EnumCls = convert_choice_field_to_enum(field)
required = not (field.blank or field.null)
converted = EnumCls(
description=get_django_field_description(field), required=required
).mount_as(BlankValueField)
else:
converted = convert_django_field(field, registry)
if registry is not None:
registry.register_converted_field(field, converted)
return converted
def get_django_field_description(field):
return str(field.help_text) if field.help_text else None
@singledispatch
def convert_django_field(field, registry=None):
raise Exception(
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
)
@convert_django_field.register(models.CharField)
@convert_django_field.register(models.TextField)
@convert_django_field.register(models.EmailField)
@convert_django_field.register(models.SlugField)
@convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField)
@convert_django_field.register(models.FilePathField)
def convert_field_to_string(field, registry=None):
return String(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.AutoField)
def convert_field_to_id(field, registry=None):
return ID(description=get_django_field_description(field), required=not field.null)
if hasattr(models, "SmallAutoField"):
@convert_django_field.register(models.SmallAutoField)
def convert_field_small_to_id(field, registry=None):
return convert_field_to_id(field, registry)
@convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None):
return UUID(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.BigIntegerField)
@convert_django_field.register(models.IntegerField)
def convert_field_to_int(field, registry=None):
return Int(description=get_django_field_description(field), required=not field.null)
@convert_django_field.register(models.NullBooleanField)
@convert_django_field.register(models.BooleanField)
def convert_field_to_boolean(field, registry=None):
return Boolean(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.DecimalField)
def convert_field_to_decimal(field, registry=None):
return Decimal(description=field.help_text, required=not field.null)
@convert_django_field.register(models.FloatField)
@convert_django_field.register(models.DurationField)
def convert_field_to_float(field, registry=None):
return Float(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.DateTimeField)
def convert_datetime_to_string(field, registry=None):
return DateTime(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.DateField)
def convert_date_to_string(field, registry=None):
return Date(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.TimeField)
def convert_time_to_string(field, registry=None):
return Time(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.OneToOneRel)
def convert_onetoone_field_to_djangomodel(field, registry=None):
model = field.related_model
def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return
return Field(_type, required=not field.null)
return Dynamic(dynamic_type)
@convert_django_field.register(models.ManyToManyField)
@convert_django_field.register(models.ManyToManyRel)
@convert_django_field.register(models.ManyToOneRel)
def convert_field_to_list_or_connection(field, registry=None):
model = field.related_model
def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return
if isinstance(field, models.ManyToManyField):
description = get_django_field_description(field)
else:
description = get_django_field_description(field.field)
# If there is a connection, we should transform the field
# into a DjangoConnectionField
if _type._meta.connection:
# Use a DjangoFilterConnectionField if there are
# defined filter_fields or a filterset_class in the
# DjangoObjectType Meta
if _type._meta.filter_fields or _type._meta.filterset_class:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(
_type, required=True, description=description
)
return DjangoConnectionField(_type, required=True, description=description)
return DjangoListField(
_type,
required=True, # A Set is always returned, never None.
description=description,
)
return Dynamic(dynamic_type)
@convert_django_field.register(models.OneToOneField)
@convert_django_field.register(models.ForeignKey)
def convert_field_to_djangomodel(field, registry=None):
model = field.related_model
def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return
return Field(
_type,
description=get_django_field_description(field),
required=not field.null,
)
return Dynamic(dynamic_type)
@convert_django_field.register(ArrayField)
def convert_postgres_array_to_list(field, registry=None):
inner_type = convert_django_field(field.base_field)
if not isinstance(inner_type, (List, NonNull)):
inner_type = (
NonNull(type(inner_type))
if inner_type.kwargs["required"]
else type(inner_type)
)
return List(
inner_type,
description=get_django_field_description(field),
required=not field.null,
)
@convert_django_field.register(HStoreField)
@convert_django_field.register(PGJSONField)
@convert_django_field.register(JSONField)
def convert_pg_and_json_field_to_string(field, registry=None):
return JSONString(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(RangeField)
def convert_postgres_range_to_string(field, registry=None):
inner_type = convert_django_field(field.base_field)
if not isinstance(inner_type, (List, NonNull)):
inner_type = (
NonNull(type(inner_type))
if inner_type.kwargs["required"]
else type(inner_type)
)
return List(
inner_type,
description=get_django_field_description(field),
required=not field.null,
)
# Register Django lazy()-wrapped values as GraphQL description/help_text.
# This is needed for using lazy translations, see https://github.com/graphql-python/graphql-core-next/issues/58.
register_description(Promise)

View File

@ -1,4 +0,0 @@
from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]

View File

@ -1,17 +0,0 @@
import traceback
from django.utils.encoding import force_str
from .types import DjangoDebugException
def wrap_exception(exception):
return DjangoDebugException(
message=force_str(exception),
exc_type=force_str(type(exception)),
stack="".join(
traceback.format_exception(
etype=type(exception), value=exception, tb=exception.__traceback__
)
),
)

View File

@ -1,10 +0,0 @@
from graphene import ObjectType, String
class DjangoDebugException(ObjectType):
class Meta:
description = "Represents a single exception raised."
exc_type = String(required=True, description="The class of the exception")
message = String(required=True, description="The message of the exception")
stack = String(required=True, description="The stack trace")

View File

@ -1,71 +0,0 @@
from django.db import connections
from promise import Promise
from .sql.tracking import unwrap_cursor, wrap_cursor
from .exception.formating import wrap_exception
from .types import DjangoDebug
class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
self.object = DjangoDebug(sql=[], exceptions=[])
self.enable_instrumentation()
def get_debug_promise(self):
if not self.debug_promise:
self.debug_promise = Promise.all(self.promises)
self.promises = []
return self.debug_promise.then(self.on_resolve_all_promises).get()
def on_resolve_error(self, value):
if hasattr(self, "object"):
self.object.exceptions.append(wrap_exception(value))
return Promise.reject(value)
def on_resolve_all_promises(self, values):
if self.promises:
self.debug_promise = None
return self.get_debug_promise()
self.disable_instrumentation()
return self.object
def add_promise(self, promise):
if self.debug_promise:
self.promises.append(promise)
def enable_instrumentation(self):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, self)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
class DjangoDebugMiddleware(object):
def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, "django_debug", None)
if not django_debug:
if context is None:
raise Exception("DjangoDebug cannot be executed in None contexts")
try:
context.django_debug = DjangoDebugContext()
except Exception:
raise Exception(
"DjangoDebug need the context to be writable, context received: {}.".format(
context.__class__.__name__
)
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise()
try:
promise = next(root, info, **args)
except Exception as e:
return context.django_debug.on_resolve_error(e)
context.django_debug.add_promise(promise)
return promise

View File

@ -1,169 +0,0 @@
# Code obtained from django-debug-toolbar sql panel tracking
from __future__ import absolute_import, unicode_literals
import json
from threading import local
from time import time
from django.utils.encoding import force_str
from .types import DjangoDebugSQL
class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@property
def Wrapper(self):
if self.enabled:
return NormalCursorWrapper
return ExceptionCursorWrapper
def recording(self, v):
self.enabled = v
state = ThreadLocalState()
recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor
def cursor():
return state.Wrapper(connection._graphene_cursor(), connection, panel)
connection.cursor = cursor
return cursor
def unwrap_cursor(connection):
if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor
del connection._graphene_cursor
class ExceptionCursorWrapper(object):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
"""
def __init__(self, cursor, db, logger):
pass
def __getattr__(self, attr):
raise SQLQueryTriggered()
class NormalCursorWrapper(object):
"""
Wraps a cursor and logs queries.
"""
def __init__(self, cursor, db, logger):
self.cursor = cursor
# Instance of a BaseDatabaseWrapper subclass
self.db = db
# logger must implement a ``record`` method
self.logger = logger
def _quote_expr(self, element):
if isinstance(element, str):
return "'%s'" % force_str(element).replace("'", "''")
else:
return repr(element)
def _quote_params(self, params):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value)) for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_str(param, strings_only=True)
except UnicodeDecodeError:
return "(encoded string)"
def _record(self, method, sql, params):
start_time = time()
try:
return method(sql, params)
finally:
stop_time = time()
duration = stop_time - start_time
_params = ""
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, "alias", "default")
conn = self.db.connection
vendor = getattr(conn, "vendor", "unknown")
params = {
"vendor": vendor,
"alias": alias,
"sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
),
"duration": duration,
"raw_sql": sql,
"params": _params,
"start_time": start_time,
"stop_time": stop_time,
"is_slow": duration > 10,
"is_select": sql.lower().strip().startswith("select"),
}
if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = "unknown"
params.update(
{
"trans_id": self.logger.get_transaction_id(alias),
"trans_status": conn.get_transaction_status(),
"iso_level": iso_level,
"encoding": conn.encoding,
}
)
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility
self.logger.object.sql.append(_sql)
def callproc(self, procname, params=None):
return self._record(self.cursor.callproc, procname, params)
def execute(self, sql, params=None):
return self._record(self.cursor.execute, sql, params)
def executemany(self, sql, param_list):
return self._record(self.cursor.executemany, sql, param_list)
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()

View File

@ -1,41 +0,0 @@
from graphene import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType):
class Meta:
description = "Represents a single database query made to a Django managed DB."
vendor = String(
required=True,
description=(
"The type of database being used (e.g. postrgesql, mysql, sqlite)."
),
)
alias = String(
required=True, description="The Django database alias (e.g. 'default')."
)
sql = String(description="The actual SQL sent to this database.")
duration = Float(
required=True, description="Duration of this database query in seconds."
)
raw_sql = String(
required=True, description="The raw SQL of this query, without params."
)
params = String(
required=True, description="JSON encoded database query parameters."
)
start_time = Float(required=True, description="Start time of this database query.")
stop_time = Float(required=True, description="Stop time of this database query.")
is_slow = Boolean(
required=True,
description="Whether this database query took more than 10 seconds.",
)
is_select = Boolean(
required=True, description="Whether this database query was a SELECT."
)
# Postgres
trans_id = String(description="Postgres transaction ID if available.")
trans_status = String(description="Postgres transaction status if available.")
iso_level = String(description="Postgres isolation level if available.")
encoding = String(description="Postgres connection encoding if available.")

View File

@ -1,313 +0,0 @@
import graphene
import pytest
from graphene.relay import Node
from graphene_django import DjangoConnectionField, DjangoObjectType
from ...tests.models import Reporter
from ..middleware import DjangoDebugMiddleware
from ..types import DjangoDebug
class context(object):
pass
def test_should_query_field():
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
query = """
query ReporterQuery {
reporter {
lastName
}
_debug {
sql {
rawSql
}
}
}
"""
expected = {
"reporter": {"lastName": "ABA"},
"_debug": {"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]},
}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
@pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_nested_field(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
r2.save()
r2.pets.add(r1)
r1.pets.add(r2)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
query = """
query ReporterQuery {
reporter {
lastName
pets { edges { node {
lastName
pets { edges { node { lastName } } }
} } }
}
_debug {
sql {
rawSql
}
}
}
"""
expected = {
"reporter": {
"lastName": "ABA",
"pets": {
"edges": [
{
"node": {
"lastName": "Griffin",
"pets": {"edges": [{"node": {"lastName": "ABA"}}]},
}
}
]
},
}
}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
query = str(Reporter.objects.order_by("pk")[:1].query)
assert result.data["_debug"]["sql"][0]["rawSql"] == query
assert "COUNT" in result.data["_debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["_debug"]["sql"][2]["rawSql"]
assert "COUNT" in result.data["_debug"]["sql"][3]["rawSql"]
assert "tests_reporter_pets" in result.data["_debug"]["sql"][4]["rawSql"]
assert len(result.data["_debug"]["sql"]) == 5
assert result.data["reporter"] == expected["reporter"]
def test_should_query_list():
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = """
query ReporterQuery {
allReporters {
lastName
}
_debug {
sql {
rawSql
}
}
}
"""
expected = {
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
"_debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
@pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_connection(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
node {
lastName
}
}
}
_debug {
sql {
rawSql
}
}
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data["allReporters"] == expected["allReporters"]
assert len(result.data["_debug"]["sql"]) == 2
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query
@pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_connectionfilter(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
node {
lastName
}
}
}
_debug {
sql {
rawSql
}
}
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data["allReporters"] == expected["allReporters"]
assert len(result.data["_debug"]["sql"]) == 2
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query
def test_should_query_stack_trace():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_reporter(self, info, **args):
raise Exception("caught stack trace")
query = """
query ReporterQuery {
reporter {
lastName
}
_debug {
exceptions {
message
stack
}
}
}
"""
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert result.errors
assert len(result.data["_debug"]["exceptions"])
debug_exception = result.data["_debug"]["exceptions"][0]
assert debug_exception["stack"].count("\n") > 1
assert "test_query.py" in debug_exception["stack"]
assert debug_exception["message"] == "caught stack trace"

View File

@ -1,14 +0,0 @@
from graphene import List, ObjectType
from .sql.types import DjangoDebugSQL
from .exception.types import DjangoDebugException
class DjangoDebug(ObjectType):
class Meta:
description = "Debugging information for the current query."
sql = List(DjangoDebugSQL, description="Executed SQL queries for this API query.")
exceptions = List(
DjangoDebugException, description="Raise exceptions for this API query."
)

View File

@ -1,249 +0,0 @@
from functools import partial
from django.db.models.query import QuerySet
from graphql_relay.connection.arrayconnection import (
connection_from_array_slice,
cursor_to_offset,
get_offset_with_default,
offset_to_cursor,
)
from promise import Promise
from graphene import Int, NonNull
from graphene.relay import ConnectionField
from graphene.relay.connection import connection_adapter, page_info_adapter
from graphene.types import Field, List
from .settings import graphene_settings
from .utils import maybe_queryset
class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs):
from .types import DjangoObjectType
if isinstance(_type, NonNull):
_type = _type.of_type
# Django would never return a Set of None vvvvvvv
super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs)
assert issubclass(
self._underlying_type, DjangoObjectType
), "DjangoListField only accepts DjangoObjectType types"
@property
def _underlying_type(self):
_type = self._type
while hasattr(_type, "of_type"):
_type = _type.of_type
return _type
@property
def model(self):
return self._underlying_type._meta.model
def get_manager(self):
return self.model._default_manager
@staticmethod
def list_resolver(
django_object_type, resolver, default_manager, root, info, **args
):
queryset = maybe_queryset(resolver(root, info, **args))
if queryset is None:
queryset = maybe_queryset(default_manager)
if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
return queryset
def wrap_resolve(self, parent_resolver):
resolver = super(DjangoListField, self).wrap_resolve(parent_resolver)
_type = self.type
if isinstance(_type, NonNull):
_type = _type.of_type
django_object_type = _type.of_type.of_type
return partial(
self.list_resolver, django_object_type, resolver, self.get_manager(),
)
class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
)
self.enforce_first_or_last = kwargs.pop(
"enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
kwargs.setdefault("offset", Int())
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
non_null = False
if isinstance(_type, NonNull):
_type = _type.of_type
non_null = True
assert issubclass(
_type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
connection_type = _type._meta.connection
if non_null:
return NonNull(connection_type)
return connection_type
@property
def connection_type(self):
type = self.type
if isinstance(type, NonNull):
return type.of_type
return type
@property
def node_type(self):
return self.connection_type._meta.node
@property
def model(self):
return self.node_type._meta.model
def get_manager(self):
if self.on:
return getattr(self.model, self.on)
else:
return self.model._default_manager
@classmethod
def resolve_queryset(cls, connection, queryset, info, args):
# queryset is the resolved iterable from ObjectType
return connection._meta.node.get_queryset(queryset, info)
@classmethod
def resolve_connection(cls, connection, args, iterable, max_limit=None):
# Remove the offset parameter and convert it to an after cursor.
offset = args.pop("offset", None)
after = args.get("after")
if offset:
if after:
offset += cursor_to_offset(after) + 1
# input offset starts at 1 while the graphene offset starts at 0
args["after"] = offset_to_cursor(offset - 1)
iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet):
list_length = iterable.count()
else:
list_length = len(iterable)
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
# If after is higher than list_length, connection_from_list_slice
# would try to do a negative slicing which makes django throw an
# AssertionError
after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)
if max_limit is not None and args.get("first", None) is None:
if args.get("last", None) is not None:
after = list_length - args["last"]
else:
args["first"] = max_limit
connection = connection_from_array_slice(
iterable[after:],
args,
slice_start=after,
array_length=list_length,
array_slice_length=list_slice_length,
connection_type=partial(connection_adapter, connection),
edge_type=connection.Edge,
page_info_type=page_info_adapter,
)
connection.iterable = iterable
connection.length = list_length
return connection
@classmethod
def connection_resolver(
cls,
resolver,
connection,
default_manager,
queryset_resolver,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
offset = args.get("offset")
before = args.get("before")
if enforce_first_or_last:
assert first or last, (
"You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
"Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit)
args["first"] = min(first, max_limit)
if last:
assert last <= max_limit, (
"Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
).format(last, info.field_name, max_limit)
args["last"] = min(last, max_limit)
if offset is not None:
assert before is None, (
"You can't provide a `before` value at the same time as an `offset` value to properly paginate the `{}` connection."
).format(info.field_name)
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
# or a resolve_foo (does not accept queryset)
iterable = resolver(root, info, **args)
if iterable is None:
iterable = default_manager
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args)
on_resolve = partial(
cls.resolve_connection, connection, args, max_limit=max_limit
)
if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
return on_resolve(iterable)
def wrap_resolve(self, parent_resolver):
return partial(
self.connection_resolver,
parent_resolver,
self.connection_type,
self.get_manager(),
self.get_queryset_resolver(),
self.max_limit,
self.enforce_first_or_last,
)
def get_queryset_resolver(self):
return self.resolve_queryset

View File

@ -1,29 +0,0 @@
import warnings
from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .fields import DjangoFilterConnectionField
from .filters import (
ArrayFilter,
GlobalIDFilter,
GlobalIDMultipleChoiceFilter,
ListFilter,
RangeFilter,
TypedFilter,
)
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
"TypedFilter",
]

View File

@ -1,109 +0,0 @@
from collections import OrderedDict
from functools import partial
from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class
def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
self,
type,
fields=None,
order_by=None,
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs
):
self._fields = fields
self._provided_filterset_class = filterset_class
self._filterset_class = None
self._filtering_args = None
self._extra_filter_meta = extra_filter_meta
self._base_args = None
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
@property
def args(self):
return to_arguments(self._base_args or OrderedDict(), self.filtering_args)
@args.setter
def args(self, args):
self._base_args = args
@property
def filterset_class(self):
if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model, fields=fields)
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
filterset_class = (
self._provided_filterset_class or self.node_type._meta.filterset_class
)
self._filterset_class = get_filterset_class(filterset_class, **meta)
return self._filterset_class
@property
def filtering_args(self):
if not self._filtering_args:
self._filtering_args = get_filtering_args_from_filterset(
self.filterset_class, self.node_type
)
return self._filtering_args
@classmethod
def resolve_queryset(
cls, connection, iterable, info, args, filtering_args, filterset_class
):
def filter_kwargs():
kwargs = {}
for k, v in args.items():
if k in filtering_args:
if k == "order_by" and v is not None:
v = to_snake_case(v)
kwargs[k] = convert_enum(v)
return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
connection, iterable, info, args
)
filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context
)
if filterset.is_valid():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())
def get_queryset_resolver(self):
return partial(
self.resolve_queryset,
filterset_class=self.filterset_class,
filtering_args=self.filtering_args,
)

View File

@ -1,25 +0,0 @@
import warnings
from ...utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .array_filter import ArrayFilter
from .global_id_filter import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .list_filter import ListFilter
from .range_filter import RangeFilter
from .typed_filter import TypedFilter
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
"TypedFilter",
]

View File

@ -1,27 +0,0 @@
from django_filters.constants import EMPTY_VALUES
from .typed_filter import TypedFilter
class ArrayFilter(TypedFilter):
"""
Filter made for PostgreSQL ArrayField.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get the filter applied with
an empty list since it's a valid value but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value in EMPTY_VALUES and value != []:
return qs
if self.distinct:
qs = qs.distinct()
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
qs = self.get_method(qs)(**{lookup: value})
return qs

View File

@ -1,28 +0,0 @@
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)

View File

@ -1,26 +0,0 @@
from .typed_filter import TypedFilter
class ListFilter(TypedFilter):
"""
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super(ListFilter, self).filter(qs, value)

View File

@ -1,24 +0,0 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from .typed_filter import TypedFilter
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(TypedFilter):
field_class = RangeField

View File

@ -1,27 +0,0 @@
from django_filters import Filter
from graphene.types.utils import get_type
class TypedFilter(Filter):
"""
Filter class for which the input GraphQL type can explicitly be provided.
If it is not provided, when building the schema, it will try to guess
it from the field.
"""
def __init__(self, input_type=None, *args, **kwargs):
self._input_type = input_type
super(TypedFilter, self).__init__(*args, **kwargs)
@property
def input_type(self):
input_type = get_type(self._input_type)
if input_type is not None:
if not callable(getattr(input_type, "get_type", None)):
raise ValueError(
"Wrong `input_type` for {}: it only accepts graphene types, got {}".format(
self.__class__.__name__, input_type
)
)
return input_type

View File

@ -1,51 +0,0 @@
import itertools
from django.db import models
from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: {"filter_class": GlobalIDFilter},
models.OneToOneField: {"filter_class": GlobalIDFilter},
models.ForeignKey: {"filter_class": GlobalIDFilter},
models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter},
models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter},
}
class GrapheneFilterSetMixin(BaseFilterSet):
""" A django_filters.filterset.BaseFilterSet with default filter overrides
to handle global IDs """
FILTER_DEFAULTS = dict(
itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
)
)
def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality
"""
return type(
"Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin),
{},
)
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data
"""
meta.update({"model": model})
meta_class = type(str("Meta"), (object,), meta)
filterset = type(
str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin),
{"Meta": meta_class},
)
return filterset

View File

@ -1,151 +0,0 @@
from mock import MagicMock
import pytest
from django.db import models
from django.db.models.query import QuerySet
from django_filters import filters
from django_filters import FilterSet
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.filter import ArrayFilter, ListFilter
from ...compat import ArrayField
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
STORE = {"events": []}
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())
@pytest.fixture
def EventFilterSet():
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
# Those are actually usable with our Query fixture bellow
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
random_field__contains = ArrayFilter(
field_name="random_field", lookup_expr="contains"
)
random_field__overlap = ArrayFilter(
field_name="random_field", lookup_expr="overlap"
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
return EventFilterSet
@pytest.fixture
def EventType(EventFilterSet):
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet
return EventType
@pytest.fixture
def Query(EventType):
"""
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields.
"""
class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)
def resolve_events(self, info, **kwargs):
events = [
Event(name="Live Show", tags=["concert", "music", "rock"],),
Event(name="Musical", tags=["movie", "music"],),
Event(name="Ballet", tags=["concert", "dance"],),
Event(name="Speech", tags=[],),
]
STORE["events"] = events
m_queryset = MagicMock(spec=QuerySet)
m_queryset.model = Event
def filter_events(**kwargs):
if "tags__contains" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags)
),
STORE["events"],
)
)
if "tags__overlap" in kwargs:
STORE["events"] = list(
filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags)
),
STORE["events"],
)
)
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
)
)
def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
return m_queryset
def mock_queryset_none(*args, **kwargs):
STORE["events"] = []
return m_queryset
def mock_queryset_count(*args, **kwargs):
return len(STORE["events"])
m_queryset.all.return_value = m_queryset
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)
return m_queryset
return Query

View File

@ -1,30 +0,0 @@
import django_filters
from django_filters import OrderingFilter
from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = {
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact", "in"],
}
order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=("first_name",))
class PetFilter(django_filters.FilterSet):
class Meta:
model = Pet
fields = ["name"]

View File

@ -1,87 +0,0 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_multiple(Query):
"""
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Contains: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_one(Query):
"""
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Contains: ["music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_empty_list(Query):
"""
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Contains: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
{"node": {"name": "Speech"}},
]

View File

@ -1,130 +0,0 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_no_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["movie", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_empty_list(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Speech"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_filter_schema_type(Query):
"""
Check that the type in the filter is an array field like on the object type.
"""
schema = Schema(query=Query)
schema_str = str(schema)
assert (
'''type EventType implements Node {
"""The ID of the object"""
id: ID!
name: String!
tags: [String!]!
tagIds: [Int!]!
randomField: [Boolean!]!
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"name": "String",
"name_Contains": "String",
"tags_Contains": "[String!]",
"tags_Overlap": "[String!]",
"tags": "[String!]",
"tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]",
"randomField_Contains": "[Boolean!]",
"randomField_Overlap": "[Boolean!]",
"randomField": "[Boolean!]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert (
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
)

View File

@ -1,84 +0,0 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_multiple(Query):
"""
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Overlap: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_one(Query):
"""
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Overlap: ["music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_empty_list(Query):
"""
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Overlap: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []

View File

@ -1,160 +0,0 @@
import pytest
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType, DjangoConnectionField
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"lang": ["exact", "in"],
"reporter__a_choice": ["exact", "in"],
}
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
@pytest.fixture
def reporter_article_data():
john = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
jane = Reporter.objects.create(
first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2
)
Article.objects.create(
headline="Article Node 1", reporter=john, editor=john, lang="es",
)
Article.objects.create(
headline="Article Node 2", reporter=john, editor=john, lang="en",
)
Article.objects.create(
headline="Article Node 3", reporter=jane, editor=jane, lang="en",
)
def test_filter_enum_on_connection(schema, reporter_article_data):
"""
Check that we can filter with enums on a connection.
"""
query = """
query {
allArticles(lang: ES) {
edges {
node {
headline
}
}
}
}
"""
expected = {"allArticles": {"edges": [{"node": {"headline": "Article Node 1"}},]}}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_on_foreign_key_enum_field(schema, reporter_article_data):
"""
Check that we can filter with enums on a field from a foreign key.
"""
query = """
query {
allArticles(reporter_AChoice: A_1) {
edges {
node {
headline
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
{"node": {"headline": "Article Node 2"}},
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_enum_field_schema_type(schema):
"""
Check that the type in the filter is an enum like on the object type.
"""
schema_str = str(schema)
assert (
'''type ArticleType implements Node {
"""The ID of the object"""
id: ID!
headline: String!
pubDate: Date!
pubDateTime: DateTime!
reporter: ReporterType!
editor: ReporterType!
"""Language"""
lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"reporter_AChoice": "TestsReporterAChoiceChoices",
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str

File diff suppressed because it is too large Load Diff

View File

@ -1,448 +0,0 @@
from datetime import datetime
import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet, Person, Reporter, Article, Film
from graphene_django.filter.tests.filters import ArticleFilter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def query():
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"id": ["exact", "in"],
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
# choice filter using enum
filter_fields = {"reporter_type": ["exact", "in"]}
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilter
class FilmNode(DjangoObjectType):
class Meta:
model = Film
interfaces = (Node,)
fields = "__all__"
# choice filter not using enum
filter_fields = {
"genre": ["exact", "in"],
}
convert_choices_to_enum = False
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {"name": ["in"]}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
"""
This custom filter take a string as input with comma separated values.
Note that the value here is already a list as it has been transformed by the BaseInFilter class.
"""
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
fields = "__all__"
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
articles = DjangoFilterConnectionField(ArticleNode)
films = DjangoFilterConnectionField(FilmNode)
reporters = DjangoFilterConnectionField(ReporterNode)
return Query
def test_string_in_filter(query):
"""
Test in filter on a string field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=query)
query = """
query {
pets (name_In: ["Brutus", "Jojo, the rabbit"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Brutus"}},
{"node": {"name": "Jojo, the rabbit"}},
]
def test_string_in_filter_with_otjer_filter(query):
"""
Test in filter on a string field which has also a custom filter doing a similar operation.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=query)
query = """
query {
people (name_In: ["John", "Michael"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_string_in_filter_with_declared_filter(query):
"""
Test in filter on a string field with a custom filterset class.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=query)
query = """
query {
people (names: "John,Michael") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_int_in_filter(query):
"""
Test in filter on an integer field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=query)
query = """
query {
pets (age_In: [3]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Jojo, the rabbit"}},
]
query = """
query {
pets (age_In: [3, 12]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Brutus"}},
{"node": {"name": "Mimi"}},
{"node": {"name": "Jojo, the rabbit"}},
]
def test_in_filter_with_empty_list(query):
"""
Check that using a in filter with an empty list provided as input returns no objects.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=query)
query = """
query {
pets (name_In: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert len(result.data["pets"]["edges"]) == 0
def test_choice_in_filter_without_enum(query):
"""
Test in filter o an choice field not using an enum (Film.genre).
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
documentary_film = Film.objects.create(genre="do")
documentary_film.reporters.add(john_doe)
action_film = Film.objects.create(genre="ac")
action_film.reporters.add(john_doe)
other_film = Film.objects.create(genre="ot")
other_film.reporters.add(john_doe)
other_film.reporters.add(jean_bon)
schema = Schema(query=query)
query = """
query {
films (genre_In: ["do", "ac"]) {
edges {
node {
genre
reporters {
edges {
node {
lastName
}
}
}
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["films"]["edges"] == [
{
"node": {
"genre": "do",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
{
"node": {
"genre": "ac",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
]
def test_fk_id_in_filter(query):
"""
Test in filter on an foreign key relationship.
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
sara_croche = Reporter.objects.create(
first_name="Sara", last_name="Croche", email="sara@croche.com"
)
Article.objects.create(
headline="A",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=john_doe,
editor=john_doe,
)
Article.objects.create(
headline="B",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=jean_bon,
editor=jean_bon,
)
Article.objects.create(
headline="C",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=sara_croche,
editor=sara_croche,
)
schema = Schema(query=query)
query = """
query {
articles (reporter_In: [%s, %s]) {
edges {
node {
headline
reporter {
lastName
}
}
}
}
}
""" % (
john_doe.id,
jean_bon.id,
)
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A", "reporter": {"lastName": "Doe"}}},
{"node": {"headline": "B", "reporter": {"lastName": "Bon"}}},
]
def test_enum_in_filter(query):
"""
Test in filter on a choice field using an enum (Reporter.reporter_type).
"""
Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1
)
Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None
)
schema = Schema(query=query)
query = """
query {
reporters (reporterType_In: [A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2, A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]

View File

@ -1,115 +0,0 @@
import json
import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
def test_int_range_filter():
"""
Test range filter on an integer field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query {
pets (age_Range: [4, 9]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Picotin"}},
]
def test_range_filter_with_invalid_input():
"""
Test range filter used with invalid inputs raise an error.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query ($rangeValue: [Int]) {
pets (age_Range: $rangeValue) {
edges {
node {
name
}
}
}
}
"""
expected_error = json.dumps(
{
"age__range": [
{
"message": "Invalid range specified: it needs to contain 2 values.",
"code": "invalid",
}
]
}
)
# Empty list
result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1
assert result.errors[0].message == expected_error
# Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1
assert result.errors[0].message == expected_error
# More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1
assert result.errors[0].message == expected_error

View File

@ -1,157 +0,0 @@
import pytest
from django_filters import FilterSet
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import (
DjangoFilterConnectionField,
TypedFilter,
ListFilter,
)
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ArticleFilterSet(FilterSet):
class Meta:
model = Article
fields = {
"lang": ["exact", "in"],
}
lang__contains = TypedFilter(
field_name="lang", lookup_expr="icontains", input_type=graphene.String
)
lang__in_str = ListFilter(
field_name="lang",
lookup_expr="in",
input_type=graphene.List(graphene.String),
)
first_n = TypedFilter(input_type=graphene.Int, method="first_n_filter")
only_first = TypedFilter(
input_type=graphene.Boolean, method="only_first_filter"
)
def first_n_filter(self, queryset, _name, value):
return queryset[:value]
def only_first_filter(self, queryset, _name, value):
if value:
return queryset[:1]
else:
return queryset
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilterSet
class Query(graphene.ObjectType):
articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
def test_typed_filter_schema(schema):
"""
Check that the type provided in the filter is reflected in the schema.
"""
schema_str = str(schema)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"lang_Contains": "String",
"lang_InStr": "[String]",
"firstN": "Int",
"onlyFirst": "Boolean",
}
all_articles_filters = (
schema_str.split(" articles(")[1]
.split("): ArticleTypeConnection\n")[0]
.split(", ")
)
for filter_field, gql_type in filters.items():
assert "{}: {} = null".format(filter_field, gql_type) in all_articles_filters
def test_typed_filters_work(schema):
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
Article.objects.create(
headline="A", reporter=reporter, editor=reporter, lang="es",
)
Article.objects.create(
headline="B", reporter=reporter, editor=reporter, lang="es",
)
Article.objects.create(
headline="C", reporter=reporter, editor=reporter, lang="en",
)
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "C"}},
]
query = "query { articles (firstN: 2) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
]

View File

@ -1,155 +0,0 @@
import graphene
from django import forms
from django_filters.utils import get_model_field, get_field_parts
from django_filters.filters import Filter, BaseCSVFilter
from .filters import ArrayFilter, ListFilter, RangeFilter, TypedFilter
from .filterset import custom_filterset_factory, setup_filterset
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
def get_field_type(registry, model, field_name):
"""
Try to get a model field corresponding Graphql type from the DjangoObjectType.
"""
object_type = registry.get_type_for_model(model)
if object_type:
object_type_field = object_type._meta.fields.get(field_name)
if object_type_field:
field_type = object_type_field.type
if isinstance(field_type, graphene.NonNull):
field_type = field_type.of_type
return field_type
return None
def get_filtering_args_from_filterset(filterset_class, type):
"""
Inspect a FilterSet and produce the arguments to pass to a Graphene Field.
These arguments will be available to filter against in the GraphQL API.
"""
from ..forms.converter import convert_form_field
args = {}
model = filterset_class._meta.model
registry = type._meta.registry
for name, filter_field in filterset_class.base_filters.items():
filter_type = filter_field.lookup_expr
required = filter_field.extra.get("required", False)
field_type = None
form_field = None
if (
isinstance(filter_field, TypedFilter)
and filter_field.input_type is not None
):
# First check if the filter input type has been explicitely given
field_type = filter_field.input_type
else:
if name not in filterset_class.declared_filters or isinstance(
filter_field, TypedFilter
):
# Get the filter field for filters that are no explicitly declared.
if filter_type == "isnull":
field = graphene.Boolean(required=required)
else:
model_field = get_model_field(model, filter_field.field_name)
# Get the form field either from:
# 1. the formfield corresponding to the model field
# 2. the field defined on filter
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(required=required)
if not form_field:
form_field = filter_field.field
# First try to get the matching field type from the GraphQL DjangoObjectType
if model_field:
if (
isinstance(form_field, forms.ModelChoiceField)
or isinstance(form_field, forms.ModelMultipleChoiceField)
or isinstance(form_field, GlobalIDMultipleChoiceField)
or isinstance(form_field, GlobalIDFormField)
):
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
field_type = get_field_type(
registry, model_field.related_model, "id"
)
else:
field_type = get_field_type(
registry, model_field.model, model_field.name
)
if not field_type:
# Fallback on converting the form field either because:
# - it's an explicitly declared filters
# - we did not manage to get the type from the model type
form_field = form_field or filter_field.field
field_type = convert_form_field(form_field).get_type()
if isinstance(filter_field, ListFilter) or isinstance(
filter_field, RangeFilter
):
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in `replace_csv_filters` method for more details.
field_type = graphene.List(field_type)
args[name] = graphene.Argument(
field_type, description=filter_field.label, required=required,
)
return args
def get_filterset_class(filterset_class, **meta):
"""
Get the class to be used as the FilterSet.
"""
if filterset_class:
# If were given a FilterSet class, then set it up.
graphene_filterset_class = setup_filterset(filterset_class)
else:
# Otherwise create one.
graphene_filterset_class = custom_filterset_factory(**meta)
replace_csv_filters(graphene_filterset_class)
return graphene_filterset_class
def replace_csv_filters(filterset_class):
"""
Replace the "in" and "range" filters (that are not explicitly declared)
to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but our custom InFilter/RangeFilter filter class that use the input
value as filter argument on the queryset.
This is because those BaseCSVFilter are expecting a string as input with
comma separated values.
But with GraphQl we can actually have a list as input and have a proper
type verification of each value in the list.
See issue https://github.com/graphql-python/graphene-django/issues/1068.
"""
for name, filter_field in list(filterset_class.base_filters.items()):
# Do not touch any declared filters
if name in filterset_class.declared_filters:
continue
filter_type = filter_field.lookup_expr
if filter_type == "in":
filterset_class.base_filters[name] = ListFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
method=filter_field.method,
exclude=filter_field.exclude,
**filter_field.extra
)
elif filter_type == "range":
filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
method=filter_field.method,
exclude=filter_field.exclude,
**filter_field.extra
)

View File

@ -1 +0,0 @@
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa

View File

@ -1,99 +0,0 @@
from functools import singledispatch
from django import forms
from django.core.exceptions import ImproperlyConfigured
from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
def get_form_field_description(field):
return str(field.help_text) if field.help_text else None
@singledispatch
def convert_form_field(field):
raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) "
"to Graphene type" % (field, field.__class__)
)
@convert_form_field.register(forms.fields.BaseTemporalField)
@convert_form_field.register(forms.CharField)
@convert_form_field.register(forms.EmailField)
@convert_form_field.register(forms.SlugField)
@convert_form_field.register(forms.URLField)
@convert_form_field.register(forms.ChoiceField)
@convert_form_field.register(forms.RegexField)
@convert_form_field.register(forms.Field)
def convert_form_field_to_string(field):
return String(
description=get_form_field_description(field), required=field.required
)
@convert_form_field.register(forms.UUIDField)
def convert_form_field_to_uuid(field):
return UUID(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.IntegerField)
@convert_form_field.register(forms.NumberInput)
def convert_form_field_to_int(field):
return Int(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.BooleanField)
def convert_form_field_to_boolean(field):
return Boolean(
description=get_form_field_description(field), required=field.required
)
@convert_form_field.register(forms.NullBooleanField)
def convert_form_field_to_nullboolean(field):
return Boolean(description=get_form_field_description(field))
@convert_form_field.register(forms.DecimalField)
@convert_form_field.register(forms.FloatField)
def convert_form_field_to_float(field):
return Float(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.MultipleChoiceField)
def convert_form_field_to_string_list(field):
return List(
String, description=get_form_field_description(field), required=field.required
)
@convert_form_field.register(forms.ModelMultipleChoiceField)
@convert_form_field.register(GlobalIDMultipleChoiceField)
def convert_form_field_to_id_list(field):
return List(ID, required=field.required)
@convert_form_field.register(forms.DateField)
def convert_form_field_to_date(field):
return Date(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.DateTimeField)
def convert_form_field_to_datetime(field):
return DateTime(
description=get_form_field_description(field), required=field.required
)
@convert_form_field.register(forms.TimeField)
def convert_form_field_to_time(field):
return Time(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.ModelChoiceField)
@convert_form_field.register(GlobalIDFormField)
def convert_form_field_to_id(field):
return ID(required=field.required)

View File

@ -1,40 +0,0 @@
import binascii
from django.core.exceptions import ValidationError
from django.forms import CharField, Field, MultipleChoiceField
from django.utils.translation import gettext_lazy as _
from graphql_relay import from_global_id
class GlobalIDFormField(Field):
default_error_messages = {"invalid": _("Invalid ID specified.")}
def clean(self, value):
if not value and not self.required:
return None
try:
_type, _id = from_global_id(value)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
raise ValidationError(self.error_messages["invalid"])
try:
CharField().clean(_id)
CharField().clean(_type)
except ValidationError:
raise ValidationError(self.error_messages["invalid"])
return value
class GlobalIDMultipleChoiceField(MultipleChoiceField):
default_error_messages = {
"invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
"invalid_list": _("Enter a list of values."),
}
def valid_value(self, value):
# Clean will raise a validation error if there is a problem
GlobalIDFormField().clean(value)
return True

View File

@ -1,192 +0,0 @@
# from django import forms
from collections import OrderedDict
import graphene
from graphene import Field, InputField
from graphene.relay.mutation import ClientIDMutation
from graphene.types.mutation import MutationOptions
# from graphene.types.inputobjecttype import (
# InputObjectTypeOptions,
# InputObjectType,
# )
from graphene.types.utils import yank_fields_from_attrs
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.registry import get_global_registry
from ..types import ErrorType
from .converter import convert_form_field
def fields_for_form(form, only_fields, exclude_fields):
fields = OrderedDict()
for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name
in exclude_fields # or
# name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_form_field(field)
return fields
class BaseDjangoFormMutation(ClientIDMutation):
class Meta:
abstract = True
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
form = cls.get_form(root, info, **input)
if form.is_valid():
return cls.perform_mutate(form, info)
else:
errors = ErrorType.from_errors(form.errors)
_set_errors_flag_to_context(info)
return cls(errors=errors, **form.data)
@classmethod
def get_form(cls, root, info, **input):
form_kwargs = cls.get_form_kwargs(root, info, **input)
return cls._meta.form_class(**form_kwargs)
@classmethod
def get_form_kwargs(cls, root, info, **input):
kwargs = {"data": input}
pk = input.pop("id", None)
if pk:
instance = cls._meta.model._default_manager.get(pk=pk)
kwargs["instance"] = instance
return kwargs
class DjangoFormMutationOptions(MutationOptions):
form_class = None
class DjangoFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(
cls, form_class=None, only_fields=(), exclude_fields=(), **options
):
if not form_class:
raise Exception("form_class is required for DjangoFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
output_fields = fields_for_form(form, only_fields, exclude_fields)
_meta = DjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def perform_mutate(cls, form, info):
if hasattr(form, "save"):
# `save` method won't exist on plain Django forms, but this mutation can
# in theory be used with `ModelForm`s as well and we do want to save them.
form.save()
return cls(errors=[], **form.cleaned_data)
class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions):
model = None
return_field_name = None
class DjangoModelFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(
cls,
form_class=None,
model=None,
return_field_name=None,
only_fields=(),
exclude_fields=(),
**options
):
if not form_class:
raise Exception("form_class is required for DjangoModelFormMutation")
if not model:
model = form_class._meta.model
if not model:
raise Exception("model is required for DjangoModelFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
if "id" not in exclude_fields:
input_fields["id"] = graphene.ID()
registry = get_global_registry()
model_type = registry.get_type_for_model(model)
if not model_type:
raise Exception("No type registered for model: {}".format(model.__name__))
if not return_field_name:
model_name = model.__name__
return_field_name = model_name[:1].lower() + model_name[1:]
output_fields = OrderedDict()
output_fields[return_field_name] = graphene.Field(model_type)
_meta = DjangoModelDjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.model = model
_meta.return_field_name = return_field_name
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
form = cls.get_form(root, info, **input)
if form.is_valid():
return cls.perform_mutate(form, info)
else:
errors = ErrorType.from_errors(form.errors)
_set_errors_flag_to_context(info)
return cls(errors=errors)
@classmethod
def perform_mutate(cls, form, info):
obj = form.save()
kwargs = {cls._meta.return_field_name: obj}
return cls(errors=[], **kwargs)
def _set_errors_flag_to_context(info):
# This is not ideal but necessary to keep the response errors empty
if info and info.context:
setattr(info.context, MUTATION_ERRORS_FLAG, True)

View File

@ -1,121 +0,0 @@
from django import forms
from py.test import raises
import graphene
from graphene import (
String,
Int,
Boolean,
Float,
ID,
UUID,
List,
NonNull,
DateTime,
Date,
Time,
)
from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text="Custom Help Text")
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == "Custom Help Text"
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_form_field(None)
assert "Don't know how to convert the Django form field" in str(excinfo.value)
def test_should_date_convert_date():
assert_conversion(forms.DateField, Date)
def test_should_time_convert_time():
assert_conversion(forms.TimeField, Time)
def test_should_date_time_convert_date_time():
assert_conversion(forms.DateTimeField, DateTime)
def test_should_char_convert_string():
assert_conversion(forms.CharField, String)
def test_should_email_convert_string():
assert_conversion(forms.EmailField, String)
def test_should_slug_convert_string():
assert_conversion(forms.SlugField, String)
def test_should_url_convert_string():
assert_conversion(forms.URLField, String)
def test_should_choice_convert_string():
assert_conversion(forms.ChoiceField, String)
def test_should_base_field_convert_string():
assert_conversion(forms.Field, String)
def test_should_regex_convert_string():
assert_conversion(forms.RegexField, String, "[0-9]+")
def test_should_uuid_convert_string():
if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID)
def test_should_integer_convert_int():
assert_conversion(forms.IntegerField, Int)
def test_should_boolean_convert_boolean():
field = assert_conversion(forms.BooleanField, Boolean)
assert isinstance(field.type, NonNull)
def test_should_nullboolean_convert_boolean():
field = assert_conversion(forms.NullBooleanField, Boolean)
assert not isinstance(field.type, NonNull)
def test_should_float_convert_float():
assert_conversion(forms.FloatField, Float)
def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, Float)
def test_should_multiple_choice_convert_list():
field = forms.MultipleChoiceField()
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == String
def test_should_model_multiple_choice_convert_connectionorlist():
field = forms.ModelMultipleChoiceField(queryset=None)
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == ID
def test_should_manytoone_convert_connectionorlist():
field = forms.ModelChoiceField(queryset=None)
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, ID)

View File

@ -1,385 +0,0 @@
import pytest
from django import forms
from django.core.exceptions import ValidationError
from py.test import raises
from graphene import Field, ObjectType, Schema, String
from graphene_django import DjangoObjectType
from graphene_django.tests.forms import PetForm
from graphene_django.tests.models import Pet
from graphene_django.tests.mutations import PetMutation
from ..mutation import DjangoFormMutation, DjangoModelFormMutation
class MyForm(forms.Form):
text = forms.CharField()
def clean_text(self):
text = self.cleaned_data["text"]
if text == "INVALID_INPUT":
raise ValidationError("Invalid input")
return text
def save(self):
pass
def test_needs_form_class():
with raises(Exception) as exc:
class MyMutation(DjangoFormMutation):
pass
assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
def test_has_output_fields():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
assert "text" in MyMutation.Input._meta.fields
def test_mutation_error_camelcased(graphene_settings):
class ExtraPetForm(PetForm):
test_field = forms.CharField(required=True)
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta:
form_class = ExtraPetForm
result = PetMutation.mutate_and_get_payload(None, None)
assert {f.field for f in result.errors} == {"name", "age", "testField"}
graphene_settings.CAMELCASE_ERRORS = False
result = PetMutation.mutate_and_get_payload(None, None)
assert {f.field for f in result.errors} == {"name", "age", "test_field"}
class MockQuery(ObjectType):
a = String()
def test_form_invalid_form():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
class Mutation(ObjectType):
my_mutation = MyMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
""" mutation MyMutation {
myMutation(input: { text: "INVALID_INPUT" }) {
errors {
field
messages
}
text
}
}
"""
)
assert result.errors is None
assert result.data["myMutation"]["errors"] == [
{"field": "text", "messages": ["Invalid input"]}
]
def test_form_valid_input():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
class Mutation(ObjectType):
my_mutation = MyMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
""" mutation MyMutation {
myMutation(input: { text: "VALID_INPUT" }) {
errors {
field
messages
}
text
}
}
"""
)
assert result.errors is None
assert result.data["myMutation"]["errors"] == []
assert result.data["myMutation"]["text"] == "VALID_INPUT"
def test_default_meta_fields():
assert PetMutation._meta.model is Pet
assert PetMutation._meta.return_field_name == "pet"
assert "pet" in PetMutation._meta.fields
def test_default_input_meta_fields():
assert PetMutation._meta.model is Pet
assert PetMutation._meta.return_field_name == "pet"
assert "name" in PetMutation.Input._meta.fields
assert "client_mutation_id" in PetMutation.Input._meta.fields
assert "id" in PetMutation.Input._meta.fields
def test_exclude_fields_input_meta_fields():
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta:
form_class = PetForm
exclude_fields = ["id"]
assert PetMutation._meta.model is Pet
assert PetMutation._meta.return_field_name == "pet"
assert "name" in PetMutation.Input._meta.fields
assert "age" in PetMutation.Input._meta.fields
assert "client_mutation_id" in PetMutation.Input._meta.fields
assert "id" not in PetMutation.Input._meta.fields
def test_custom_return_field_name():
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta:
form_class = PetForm
model = Pet
return_field_name = "animal"
assert PetMutation._meta.model is Pet
assert PetMutation._meta.return_field_name == "animal"
assert "animal" in PetMutation._meta.fields
def test_model_form_mutation_mutate_existing():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
pet = Pet.objects.create(name="Axel", age=10)
result = schema.execute(
""" mutation PetMutation($pk: ID!) {
petMutation(input: { id: $pk, name: "Mia", age: 10 }) {
pet {
name
age
}
}
}
""",
variable_values={"pk": pet.pk},
)
assert result.errors is None
assert result.data["petMutation"]["pet"] == {"name": "Mia", "age": 10}
assert Pet.objects.count() == 1
pet.refresh_from_db()
assert pet.name == "Mia"
def test_model_form_mutation_creates_new():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
""" mutation PetMutation {
petMutation(input: { name: "Mia", age: 10 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
)
assert result.errors is None
assert result.data["petMutation"]["pet"] == {"name": "Mia", "age": 10}
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Mia"
assert pet.age == 10
def test_model_form_mutation_invalid_input():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
""" mutation PetMutation {
petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
)
assert result.errors is None
assert result.data["petMutation"]["pet"] is None
assert result.data["petMutation"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert Pet.objects.count() == 0
def test_model_form_mutation_mutate_invalid_form():
result = PetMutation.mutate_and_get_payload(None, None)
# A pet was not created
Pet.objects.count() == 0
fields_w_error = [e.field for e in result.errors]
assert len(result.errors) == 2
assert result.errors[0].messages == ["This field is required."]
assert result.errors[1].messages == ["This field is required."]
assert "age" in fields_w_error
assert "name" in fields_w_error
def test_model_form_mutation_multiple_creation_valid():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
"""
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 10 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
)
assert result.errors is None
assert result.data["petMutation1"]["pet"] == {"name": "Mia", "age": 10}
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 2
pet1 = Pet.objects.first()
assert pet1.name == "Mia"
assert pet1.age == 10
pet2 = Pet.objects.last()
assert pet2.name == "Enzo"
assert pet2.age == 0
def test_model_form_mutation_multiple_creation_invalid():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
"""
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
)
assert result.errors is None
assert result.data["petMutation1"]["pet"] is None
assert result.data["petMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Enzo"
assert pet.age == 0

View File

@ -1 +0,0 @@
from ..types import ErrorType # noqa Import ErrorType for backwards compatability

View File

@ -1,115 +0,0 @@
import os
import importlib
import json
import functools
from django.core.management.base import BaseCommand, CommandError
from django.utils import autoreload
from graphql import print_schema
from graphene_django.settings import graphene_settings
class CommandArguments(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
"--schema",
type=str,
dest="schema",
default=graphene_settings.SCHEMA,
help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
)
parser.add_argument(
"--out",
type=str,
dest="out",
default=graphene_settings.SCHEMA_OUTPUT,
help="Output file, --out=- prints to stdout (default: schema.json)",
)
parser.add_argument(
"--indent",
type=int,
dest="indent",
default=graphene_settings.SCHEMA_INDENT,
help="Output file indent (default: None)",
)
parser.add_argument(
"--watch",
dest="watch",
default=False,
action="store_true",
help="Updates the schema on file changes (default: False)",
)
class Command(CommandArguments):
help = "Dump Graphene schema as a JSON or GraphQL file"
can_import_settings = True
requires_system_checks = False
def save_json_file(self, out, schema_dict, indent):
with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent, sort_keys=True)
def save_graphql_file(self, out, schema):
with open(out, "w", encoding="utf-8") as outfile:
outfile.write(print_schema(schema.graphql_schema))
def get_schema(self, schema, out, indent):
schema_dict = {"data": schema.introspect()}
if out == "-" or out == "-.json":
self.stdout.write(json.dumps(schema_dict, indent=indent, sort_keys=True))
elif out == "-.graphql":
self.stdout.write(print_schema(schema))
else:
# Determine format
_, file_extension = os.path.splitext(out)
if file_extension == ".graphql":
self.save_graphql_file(out, schema)
elif file_extension == ".json":
self.save_json_file(out, schema_dict, indent)
else:
raise CommandError(
'Unrecognised file format "{}"'.format(file_extension)
)
style = getattr(self, "style", None)
success = getattr(style, "SUCCESS", lambda x: x)
self.stdout.write(
success("Successfully dumped GraphQL schema to {}".format(out))
)
def handle(self, *args, **options):
options_schema = options.get("schema")
if options_schema and type(options_schema) is str:
module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name)
elif options_schema:
schema = options_schema
else:
schema = graphene_settings.SCHEMA
out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
if not schema:
raise CommandError(
"Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
)
indent = options.get("indent")
watch = options.get("watch")
if watch:
autoreload.run_with_reloader(
functools.partial(self.get_schema, schema, out, indent)
)
else:
self.get_schema(schema, out, indent)

View File

@ -1,43 +0,0 @@
class Registry(object):
def __init__(self):
self._registry = {}
self._field_registry = {}
def register(self, cls):
from .types import DjangoObjectType
assert issubclass(
cls, DjangoObjectType
), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# )
if not getattr(cls._meta, "skip_registry", False):
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model):
return self._registry.get(model)
def register_converted_field(self, field, converted):
self._field_registry[field] = converted
def get_converted_field(self, field):
return self._field_registry.get(field)
registry = None
def get_global_registry():
global registry
if not registry:
registry = Registry()
return registry
def reset_global_registry():
global registry
registry = None

View File

@ -1,16 +0,0 @@
from django.db import models
class MyFakeModel(models.Model):
cool_name = models.CharField(max_length=50)
created = models.DateTimeField(auto_now_add=True)
class MyFakeModelWithPassword(models.Model):
cool_name = models.CharField(max_length=50)
password = models.CharField(max_length=50)
class MyFakeModelWithDate(models.Model):
cool_name = models.CharField(max_length=50)
last_edited = models.DateField()

View File

@ -1,175 +0,0 @@
from collections import OrderedDict
from django.shortcuts import get_object_or_404
from rest_framework import serializers
import graphene
from graphene.relay.mutation import ClientIDMutation
from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions
from graphene.types.objecttype import yank_fields_from_attrs
from ..types import ErrorType
from .serializer_converter import convert_serializer_field
class SerializerMutationOptions(MutationOptions):
lookup_field = None
model_class = None
model_operations = ["create", "update"]
serializer_class = None
def fields_for_serializer(
serializer,
only_fields,
exclude_fields,
is_input=False,
convert_choices_to_enum=True,
lookup_field=None,
):
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = any(
[
name in exclude_fields,
field.write_only
and not is_input, # don't show write_only fields in Query
field.read_only
and is_input
and lookup_field != name, # don't show read_only fields in Input
]
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_serializer_field(
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum
)
return fields
class SerializerMutation(ClientIDMutation):
class Meta:
abstract = True
errors = graphene.List(
ErrorType, description="May contain more than one error for same field."
)
@classmethod
def __init_subclass_with_meta__(
cls,
lookup_field=None,
serializer_class=None,
model_class=None,
model_operations=("create", "update"),
only_fields=(),
exclude_fields=(),
convert_choices_to_enum=True,
_meta=None,
**options
):
if not serializer_class:
raise Exception("serializer_class is required for the SerializerMutation")
if "update" not in model_operations and "create" not in model_operations:
raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class()
if model_class is None:
serializer_meta = getattr(serializer_class, "Meta", None)
if serializer_meta:
model_class = getattr(serializer_meta, "model", None)
if lookup_field is None and model_class:
lookup_field = model_class._meta.pk.name
input_fields = fields_for_serializer(
serializer,
only_fields,
exclude_fields,
is_input=True,
convert_choices_to_enum=convert_choices_to_enum,
lookup_field=lookup_field,
)
output_fields = fields_for_serializer(
serializer,
only_fields,
exclude_fields,
is_input=False,
convert_choices_to_enum=convert_choices_to_enum,
lookup_field=lookup_field,
)
if not _meta:
_meta = SerializerMutationOptions(cls)
_meta.lookup_field = lookup_field
_meta.model_operations = model_operations
_meta.serializer_class = serializer_class
_meta.model_class = model_class
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(SerializerMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
lookup_field = cls._meta.lookup_field
model_class = cls._meta.model_class
if model_class:
if "update" in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(
model_class, **{lookup_field: input[lookup_field]}
)
partial = True
elif "create" in cls._meta.model_operations:
instance = None
partial = False
else:
raise Exception(
'Invalid update operation. Input parameter "{}" required.'.format(
lookup_field
)
)
return {
"instance": instance,
"data": input,
"context": {"request": info.context},
"partial": partial,
}
return {"data": input, "context": {"request": info.context}}
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs)
if serializer.is_valid():
return cls.perform_mutate(serializer, info)
else:
errors = ErrorType.from_errors(serializer.errors)
return cls(errors=errors)
@classmethod
def perform_mutate(cls, serializer, info):
obj = serializer.save()
kwargs = {}
for f, field in serializer.fields.items():
if not field.write_only:
if isinstance(field, serializers.SerializerMethodField):
kwargs[f] = field.to_representation(obj)
else:
kwargs[f] = field.get_attribute(obj)
return cls(errors=None, **kwargs)

View File

@ -1,159 +0,0 @@
from functools import singledispatch
from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers
import graphene
from ..registry import get_global_registry
from ..converter import convert_choices_to_named_enum_with_descriptions
from .types import DictType
@singledispatch
def get_graphene_type_from_serializer_field(field):
raise ImproperlyConfigured(
"Don't know how to convert the serializer field %s (%s) "
"to Graphene type" % (field, field.__class__)
)
def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True):
"""
Converts a django rest frameworks field to a graphql field
and marks the field as required if we are creating an input type
and the field itself is required
"""
if isinstance(field, serializers.ChoiceField) and not convert_choices_to_enum:
graphql_type = graphene.String
else:
graphql_type = get_graphene_type_from_serializer_field(field)
args = []
kwargs = {"description": field.help_text, "required": is_input and field.required}
# if it is a tuple or a list it means that we are returning
# the graphql type and the child type
if isinstance(graphql_type, (list, tuple)):
kwargs["of_type"] = graphql_type[1]
graphql_type = graphql_type[0]
if isinstance(field, serializers.ModelSerializer):
if is_input:
graphql_type = convert_serializer_to_input_type(field.__class__)
else:
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
elif isinstance(field, serializers.ListSerializer):
field = field.child
if is_input:
kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
else:
del kwargs["of_type"]
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
return graphql_type(*args, **kwargs)
def convert_serializer_to_input_type(serializer_class):
cached_type = convert_serializer_to_input_type.cache.get(
serializer_class.__name__, None
)
if cached_type:
return cached_type
serializer = serializer_class()
items = {
name: convert_serializer_field(field)
for name, field in serializer.fields.items()
}
ret_type = type(
"{}Input".format(serializer.__class__.__name__),
(graphene.InputObjectType,),
items,
)
convert_serializer_to_input_type.cache[serializer_class.__name__] = ret_type
return ret_type
convert_serializer_to_input_type.cache = {}
@get_graphene_type_from_serializer_field.register(serializers.Field)
def convert_serializer_field_to_string(field):
return graphene.String
@get_graphene_type_from_serializer_field.register(serializers.ModelSerializer)
def convert_serializer_to_field(field):
return graphene.Field
@get_graphene_type_from_serializer_field.register(serializers.ListSerializer)
def convert_list_serializer_to_field(field):
child_type = get_graphene_type_from_serializer_field(field.child)
return (graphene.List, child_type)
@get_graphene_type_from_serializer_field.register(serializers.IntegerField)
def convert_serializer_field_to_int(field):
return graphene.Int
@get_graphene_type_from_serializer_field.register(serializers.BooleanField)
def convert_serializer_field_to_bool(field):
return graphene.Boolean
@get_graphene_type_from_serializer_field.register(serializers.FloatField)
@get_graphene_type_from_serializer_field.register(serializers.DecimalField)
def convert_serializer_field_to_float(field):
return graphene.Float
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
def convert_serializer_field_to_datetime_time(field):
return graphene.types.datetime.DateTime
@get_graphene_type_from_serializer_field.register(serializers.DateField)
def convert_serializer_field_to_date_time(field):
return graphene.types.datetime.Date
@get_graphene_type_from_serializer_field.register(serializers.TimeField)
def convert_serializer_field_to_time(field):
return graphene.types.datetime.Time
@get_graphene_type_from_serializer_field.register(serializers.ListField)
def convert_serializer_field_to_list(field, is_input=True):
child_type = get_graphene_type_from_serializer_field(field.child)
return (graphene.List, child_type)
@get_graphene_type_from_serializer_field.register(serializers.DictField)
def convert_serializer_field_to_dict(field):
return DictType
@get_graphene_type_from_serializer_field.register(serializers.JSONField)
def convert_serializer_field_to_jsonstring(field):
return graphene.types.json.JSONString
@get_graphene_type_from_serializer_field.register(serializers.MultipleChoiceField)
def convert_serializer_field_to_list_of_enum(field):
child_type = convert_serializer_field_to_enum(field)
return (graphene.List, child_type)
@get_graphene_type_from_serializer_field.register(serializers.ChoiceField)
def convert_serializer_field_to_enum(field):
# enums require a name
name = field.field_name or field.source or "Choices"
return convert_choices_to_named_enum_with_descriptions(name, field.choices)

View File

@ -1,220 +0,0 @@
import copy
import graphene
from django.db import models
from graphene import InputObjectType
from py.test import raises
from rest_framework import serializers
from ..serializer_converter import convert_serializer_field
from ..types import DictType
def _get_type(
rest_framework_field, is_input=True, convert_choices_to_enum=True, **kwargs
):
# prevents the following error:
# AssertionError: The `source` argument is not meaningful when applied to a `child=` field.
# Remove `source=` from the field declaration.
# since we are reusing the same child in when testing the required attribute
if "child" in kwargs:
kwargs["child"] = copy.deepcopy(kwargs["child"])
field = rest_framework_field(**kwargs)
return convert_serializer_field(
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum
)
def assert_conversion(rest_framework_field, graphene_field, **kwargs):
graphene_type = _get_type(
rest_framework_field, help_text="Custom Help Text", **kwargs
)
assert isinstance(graphene_type, graphene_field)
graphene_type_required = _get_type(
rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
)
assert isinstance(graphene_type_required, graphene_field)
return graphene_type
def test_should_unknown_rest_framework_field_raise_exception():
with raises(Exception) as excinfo:
convert_serializer_field(None)
assert "Don't know how to convert the serializer field" in str(excinfo.value)
def test_should_char_convert_string():
assert_conversion(serializers.CharField, graphene.String)
def test_should_email_convert_string():
assert_conversion(serializers.EmailField, graphene.String)
def test_should_slug_convert_string():
assert_conversion(serializers.SlugField, graphene.String)
def test_should_url_convert_string():
assert_conversion(serializers.URLField, graphene.String)
def test_should_choice_convert_enum():
field = assert_conversion(
serializers.ChoiceField,
graphene.Enum,
choices=[("h", "Hello"), ("w", "World")],
source="word",
)
assert field._meta.enum.__members__["H"].value == "h"
assert field._meta.enum.__members__["H"].description == "Hello"
assert field._meta.enum.__members__["W"].value == "w"
assert field._meta.enum.__members__["W"].description == "World"
def test_should_choice_convert_string_if_enum_disabled():
assert_conversion(
serializers.ChoiceField,
graphene.String,
choices=[("h", "Hello"), ("w", "World")],
source="word",
convert_choices_to_enum=False,
)
def test_should_base_field_convert_string():
assert_conversion(serializers.Field, graphene.String)
def test_should_regex_convert_string():
assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
def test_should_uuid_convert_string():
if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String)
def test_should_model_convert_field():
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = None
fields = "__all__"
assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
def test_should_date_time_convert_datetime():
assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime)
def test_should_date_convert_date():
assert_conversion(serializers.DateField, graphene.types.datetime.Date)
def test_should_time_convert_time():
assert_conversion(serializers.TimeField, graphene.types.datetime.Time)
def test_should_integer_convert_int():
assert_conversion(serializers.IntegerField, graphene.Int)
def test_should_boolean_convert_boolean():
assert_conversion(serializers.BooleanField, graphene.Boolean)
def test_should_float_convert_float():
assert_conversion(serializers.FloatField, graphene.Float)
def test_should_decimal_convert_float():
assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
)
def test_should_list_convert_to_list():
class StringListField(serializers.ListField):
child = serializers.CharField()
field_a = assert_conversion(
serializers.ListField,
graphene.List,
child=serializers.IntegerField(min_value=0, max_value=100),
)
assert field_a.of_type == graphene.Int
field_b = assert_conversion(StringListField, graphene.List)
assert field_b.of_type == graphene.String
def test_should_list_serializer_convert_to_list():
class FooModel(models.Model):
pass
class ChildSerializer(serializers.ModelSerializer):
class Meta:
model = FooModel
fields = "__all__"
class ParentSerializer(serializers.ModelSerializer):
child = ChildSerializer(many=True)
class Meta:
model = FooModel
fields = "__all__"
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=True
)
assert isinstance(converted_type, graphene.List)
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=False
)
assert isinstance(converted_type, graphene.List)
assert converted_type.of_type is None
def test_should_dict_convert_dict():
assert_conversion(serializers.DictField, DictType)
def test_should_duration_convert_string():
assert_conversion(serializers.DurationField, graphene.String)
def test_should_file_convert_string():
assert_conversion(serializers.FileField, graphene.String)
def test_should_filepath_convert_string():
assert_conversion(serializers.FilePathField, graphene.Enum, path="/")
def test_should_ip_convert_string():
assert_conversion(serializers.IPAddressField, graphene.String)
def test_should_image_convert_string():
assert_conversion(serializers.ImageField, graphene.String)
def test_should_json_convert_jsonstring():
assert_conversion(serializers.JSONField, graphene.types.json.JSONString)
def test_should_multiplechoicefield_convert_to_list_of_enum():
field = assert_conversion(
serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
)
assert issubclass(field.of_type, graphene.Enum)

View File

@ -1,66 +0,0 @@
from django.db import models
from rest_framework import serializers
import graphene
from graphene import Schema
from graphene_django import DjangoObjectType
from graphene_django.rest_framework.mutation import SerializerMutation
class MyFakeChildModel(models.Model):
name = models.CharField(max_length=50)
created = models.DateTimeField(auto_now_add=True)
class MyFakeParentModel(models.Model):
name = models.CharField(max_length=50)
created = models.DateTimeField(auto_now_add=True)
child1 = models.OneToOneField(
MyFakeChildModel, related_name="parent1", on_delete=models.CASCADE
)
child2 = models.OneToOneField(
MyFakeChildModel, related_name="parent2", on_delete=models.CASCADE
)
class ParentType(DjangoObjectType):
class Meta:
model = MyFakeParentModel
interfaces = (graphene.relay.Node,)
fields = "__all__"
class ChildType(DjangoObjectType):
class Meta:
model = MyFakeChildModel
interfaces = (graphene.relay.Node,)
fields = "__all__"
class MyModelChildSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeChildModel
fields = "__all__"
class MyModelParentSerializer(serializers.ModelSerializer):
child1 = MyModelChildSerializer()
child2 = MyModelChildSerializer()
class Meta:
model = MyFakeParentModel
fields = "__all__"
class MyParentModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelParentSerializer
class Mutation(graphene.ObjectType):
createParentWithChild = MyParentModelMutation.Field()
def test_create_schema():
schema = Schema(mutation=Mutation, types=[ParentType, ChildType])
assert schema

View File

@ -1,286 +0,0 @@
import datetime
from py.test import raises
from rest_framework import serializers
from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType
from ...types import DjangoObjectType
from ..models import MyFakeModel, MyFakeModelWithDate, MyFakeModelWithPassword
from ..mutation import SerializerMutation
def mock_info():
return ResolveInfo(
None,
None,
None,
None,
path=None,
schema=None,
fragments=None,
root_value=None,
operation=None,
variable_values=None,
context=None,
is_awaitable=None,
)
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeModel
fields = "__all__"
class MyModelSerializerWithMethod(serializers.ModelSerializer):
days_since_last_edit = serializers.SerializerMethodField()
class Meta:
model = MyFakeModelWithDate
fields = "__all__"
def get_days_since_last_edit(self, obj):
now = datetime.date(2020, 1, 8)
return (now - obj.last_edited).days
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer):
text = serializers.CharField()
model = MyModelSerializer()
def create(self, validated_data):
return validated_data
def test_needs_serializer_class():
with raises(Exception) as exc:
class MyMutation(SerializerMutation):
pass
assert str(exc.value) == "serializer_class is required for the SerializerMutation"
def test_has_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
assert "text" in MyMutation._meta.fields
assert "model" in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
assert "text" in MyMutation.Input._meta.fields
assert "model" in MyMutation.Input._meta.fields
def test_exclude_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
exclude_fields = ["created"]
assert "cool_name" in MyMutation._meta.fields
assert "created" not in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
assert "cool_name" in MyMutation.Input._meta.fields
assert "created" not in MyMutation.Input._meta.fields
def test_write_only_field():
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
password = serializers.CharField(write_only=True)
class Meta:
model = MyFakeModelWithPassword
fields = ["cool_name", "password"]
class MyMutation(SerializerMutation):
class Meta:
serializer_class = WriteOnlyFieldModelSerializer
result = MyMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "New Narf", "password": "admin"}
)
assert hasattr(result, "cool_name")
assert not hasattr(
result, "password"
), "'password' is write_only field and shouldn't be visible"
def test_write_only_field_using_extra_kwargs():
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeModelWithPassword
fields = ["cool_name", "password"]
extra_kwargs = {"password": {"write_only": True}}
class MyMutation(SerializerMutation):
class Meta:
serializer_class = WriteOnlyFieldModelSerializer
result = MyMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "New Narf", "password": "admin"}
)
assert hasattr(result, "cool_name")
assert not hasattr(
result, "password"
), "'password' is write_only field and shouldn't be visible"
def test_read_only_fields():
class ReadOnlyFieldModelSerializer(serializers.ModelSerializer):
id = serializers.CharField(read_only=True)
cool_name = serializers.CharField(read_only=True)
class Meta:
model = MyFakeModelWithPassword
lookup_field = "id"
fields = ["id", "cool_name", "password"]
class MyMutation(SerializerMutation):
class Meta:
serializer_class = ReadOnlyFieldModelSerializer
assert "password" in MyMutation.Input._meta.fields
assert "id" in MyMutation.Input._meta.fields
assert (
"cool_name" not in MyMutation.Input._meta.fields
), "'cool_name' is read_only field and shouldn't be on arguments"
def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType):
class Meta:
model = MyFakeModel
fields = "__all__"
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
model_field = MyMutation._meta.fields["model"]
assert isinstance(model_field, Field)
assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields["model"]
model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType)
assert "cool_name" in model_input_type._meta.fields
assert "created" in model_input_type._meta.fields
def test_mutate_and_get_payload_success():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
result = MyMutation.mutate_and_get_payload(
None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
)
assert result.errors is None
def test_model_add_mutate_and_get_payload_success():
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert result.errors is None
assert result.cool_name == "Narf"
assert isinstance(result.created, datetime.datetime)
def test_model_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
)
assert result.errors is None
assert result.cool_name == "New Narf"
def test_model_partial_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"id": instance.id}
)
assert result.errors is None
assert result.cool_name == "Narf"
def test_model_invalid_update_mutate_and_get_payload_success():
class InvalidModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ["update"]
with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert '"id" required' in str(exc.value)
def test_perform_mutate_success():
class MyMethodMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializerWithMethod
result = MyMethodMutation.mutate_and_get_payload(
None,
mock_info(),
**{"cool_name": "Narf", "last_edited": datetime.date(2020, 1, 4)}
)
assert result.errors is None
assert result.cool_name == "Narf"
assert result.days_since_last_edit == 4
def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
# missing required fields
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_model_mutate_and_get_payload_error():
# missing required fields
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_mutation_error_camelcased(graphene_settings):
graphene_settings.CAMELCASE_ERRORS = True
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert result.errors[0].field == "coolName"
def test_invalid_serializer_operations():
with raises(Exception) as exc:
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ["Add"]
assert "model_operations" in str(exc.value)

View File

@ -1,7 +0,0 @@
import graphene
from graphene.types.unmountedtype import UnmountedType
class DictType(UnmountedType):
key = graphene.String()
value = graphene.String()

View File

@ -1,140 +0,0 @@
"""
Settings for Graphene are all namespaced in the GRAPHENE setting.
For example your project's `settings.py` file might look like this:
GRAPHENE = {
'SCHEMA': 'my_app.schema.schema'
'MIDDLEWARE': (
'graphene_django.debug.DjangoDebugMiddleware',
)
}
This module provides the `graphene_settings` object, that is used to access
Graphene settings, checking for user settings first, then falling
back to the defaults.
"""
from __future__ import unicode_literals
from django.conf import settings
from django.test.signals import setting_changed
import importlib # Available in Python 3.1+
# Copied shamelessly from Django REST Framework
DEFAULTS = {
"SCHEMA": None,
"SCHEMA_OUTPUT": "schema.json",
"SCHEMA_INDENT": 2,
"MIDDLEWARE": (),
# Set to True if the connection fields must have
# either the first or last argument
"RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
# Max items returned in ConnectionFields / FilterConnectionFields
"RELAY_CONNECTION_MAX_LIMIT": 100,
"CAMELCASE_ERRORS": True,
# Set to True to enable v2 naming convention for choice field Enum's
"DJANGO_CHOICE_FIELD_ENUM_V2_NAMING": False,
"DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME": None,
# Use a separate path for handling subscriptions.
"SUBSCRIPTION_PATH": None,
# By default GraphiQL headers editor tab is enabled, set to False to hide it
# This sets headerEditorEnabled GraphiQL option, for details go to
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
"GRAPHIQL_HEADER_EDITOR_ENABLED": True,
"ATOMIC_MUTATIONS": False,
}
if settings.DEBUG:
DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
# List of settings that may be in string import notation.
IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
def perform_import(val, setting_name):
"""
If the given setting is a string import notation,
then perform the necessary import or imports.
"""
if val is None:
return None
elif isinstance(val, str):
return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val]
return val
def import_from_string(val, setting_name):
"""
Attempt to import a class from a string representation.
"""
try:
# Nod to tastypie's use of importlib.
parts = val.split(".")
module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg)
class GrapheneSettings(object):
"""
A settings object, that allows API settings to be accessed as properties.
For example:
from graphene_django.settings import settings
print(settings.SCHEMA)
Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal.
"""
def __init__(self, user_settings=None, defaults=None, import_strings=None):
if user_settings:
self._user_settings = user_settings
self.defaults = defaults or DEFAULTS
self.import_strings = import_strings or IMPORT_STRINGS
@property
def user_settings(self):
if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, "GRAPHENE", {})
return self._user_settings
def __getattr__(self, attr):
if attr not in self.defaults:
raise AttributeError("Invalid Graphene setting: '%s'" % attr)
try:
# Check if present in user settings
val = self.user_settings[attr]
except KeyError:
# Fall back to defaults
val = self.defaults[attr]
# Coerce import strings into classes
if attr in self.import_strings:
val = perform_import(val, attr)
# Cache the result
setattr(self, attr, val)
return val
graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_graphene_settings(*args, **kwargs):
global graphene_settings
setting, value = kwargs["setting"], kwargs["value"]
if setting == "GRAPHENE":
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)
setting_changed.connect(reload_graphene_settings)

View File

@ -1,203 +0,0 @@
(function (
document,
GRAPHENE_SETTINGS,
GraphiQL,
React,
ReactDOM,
SubscriptionsTransportWs,
fetch,
history,
location,
) {
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ("; " + document.cookie).split("; csrftoken=");
if (cookies.length == 2) {
csrftoken = cookies.pop().split(";").shift();
} else {
csrftoken = document.querySelector("[name=csrfmiddlewaretoken]").value;
}
// Collect the URL parameters
var parameters = {};
location.hash
.substr(1)
.split("&")
.forEach(function (entry) {
var eq = entry.indexOf("=");
if (eq >= 0) {
parameters[decodeURIComponent(entry.slice(0, eq))] = decodeURIComponent(
entry.slice(eq + 1),
);
}
});
// Produce a Location fragment string from a parameter object.
function locationQuery(params) {
return (
"#" +
Object.keys(params)
.map(function (key) {
return (
encodeURIComponent(key) + "=" + encodeURIComponent(params[key])
);
})
.join("&")
);
}
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
var graphqlParamNames = {
query: true,
variables: true,
operationName: true,
};
var otherParams = {};
for (var k in parameters) {
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
otherParams[k] = parameters[k];
}
}
var fetchURL = locationQuery(otherParams);
// Defines a GraphQL fetcher using the fetch API.
function httpClient(graphQLParams, opts) {
if (typeof opts === 'undefined') {
opts = {};
}
var headers = opts.headers || {};
headers['Accept'] = headers['Accept'] || 'application/json';
headers['Content-Type'] = headers['Content-Type'] || 'application/json';
if (csrftoken) {
headers['X-CSRFToken'] = csrftoken
}
return fetch(fetchURL, {
method: "post",
headers: headers,
body: JSON.stringify(graphQLParams),
credentials: "include",
})
.then(function (response) {
return response.text();
})
.then(function (responseBody) {
try {
return JSON.parse(responseBody);
} catch (error) {
return responseBody;
}
});
}
// Derive the subscription URL. If the SUBSCRIPTION_URL setting is specified, uses that value. Otherwise
// assumes the current window location with an appropriate websocket protocol.
var subscribeURL =
location.origin.replace(/^http/, "ws") +
(GRAPHENE_SETTINGS.subscriptionPath || location.pathname);
// Create a subscription client.
var subscriptionClient = new SubscriptionsTransportWs.SubscriptionClient(
subscribeURL,
{
// Reconnect after any interruptions.
reconnect: true,
// Delay socket initialization until the first subscription is started.
lazy: true,
},
);
// Keep a reference to the currently-active subscription, if available.
var activeSubscription = null;
// Define a GraphQL fetcher that can intelligently route queries based on the operation type.
function graphQLFetcher(graphQLParams, opts) {
var operationType = getOperationType(graphQLParams);
// If we're about to execute a new operation, and we have an active subscription,
// unsubscribe before continuing.
if (activeSubscription) {
activeSubscription.unsubscribe();
activeSubscription = null;
}
if (operationType === "subscription") {
return {
subscribe: function (observer) {
activeSubscription = subscriptionClient;
return subscriptionClient.request(graphQLParams, opts).subscribe(observer);
},
};
} else {
return httpClient(graphQLParams, opts);
}
}
// Determine the type of operation being executed for a given set of GraphQL parameters.
function getOperationType(graphQLParams) {
// Run a regex against the query to determine the operation type (query, mutation, subscription).
var operationRegex = new RegExp(
// Look for lines that start with an operation keyword, ignoring whitespace.
"^\\s*(query|mutation|subscription)\\s*" +
// The operation keyword should be followed by whitespace and the operationName in the GraphQL parameters (if available).
(graphQLParams.operationName ? ("\\s+" + graphQLParams.operationName) : "") +
// The line should eventually encounter an opening curly brace.
"[^\\{]*\\{",
// Enable multiline matching.
"m",
);
var match = operationRegex.exec(graphQLParams.query);
if (!match) {
return "query";
}
return match[1];
}
// When the query and variables string is edited, update the URL bar so
// that it can be easily shared.
function onEditQuery(newQuery) {
parameters.query = newQuery;
updateURL();
}
function onEditVariables(newVariables) {
parameters.variables = newVariables;
updateURL();
}
function onEditOperationName(newOperationName) {
parameters.operationName = newOperationName;
updateURL();
}
function updateURL() {
history.replaceState(null, null, locationQuery(parameters));
}
var options = {
fetcher: graphQLFetcher,
onEditQuery: onEditQuery,
onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName,
headerEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled,
query: parameters.query,
};
if (parameters.variables) {
options.variables = parameters.variables;
}
if (parameters.operation_name) {
options.operationName = parameters.operation_name;
}
// Render <GraphiQL /> into the body.
ReactDOM.render(
React.createElement(GraphiQL, options),
document.getElementById("editor"),
);
})(
document,
window.GRAPHENE_SETTINGS,
window.GraphiQL,
window.React,
window.ReactDOM,
window.SubscriptionsTransportWs,
window.fetch,
window.history,
window.location,
);

View File

@ -1,53 +0,0 @@
<!--
The request to this GraphQL server provided the header "Accept: text/html"
and as a result has been presented GraphiQL - an in-browser IDE for
exploring GraphQL.
If you wish to receive JSON, provide the header "Accept: application/json" or
add "&raw" to the end of the URL within a browser.
-->
{% load static %}
<!DOCTYPE html>
<html>
<head>
<style>
html, body, #editor {
height: 100%;
margin: 0;
overflow: hidden;
width: 100%;
}
</style>
<link href="https://cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.css"
integrity="{{graphiql_css_sri}}"
rel="stylesheet"
crossorigin="anonymous" />
<script src="https://cdn.jsdelivr.net/npm/whatwg-fetch@{{whatwg_fetch_version}}/dist/fetch.umd.js"
integrity="{{whatwg_fetch_sri}}"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/react@{{react_version}}/umd/react.production.min.js"
integrity="{{react_sri}}"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/react-dom@{{react_version}}/umd/react-dom.production.min.js"
integrity="{{react_dom_sri}}"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.js"
integrity="{{graphiql_sri}}"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/subscriptions-transport-ws@{{subscriptions_transport_ws_version}}/browser/client.js"
integrity="{{subscriptions_transport_ws_sri}}"
crossorigin="anonymous"></script>
</head>
<body>
<div id="editor"></div>
{% csrf_token %}
<script type="application/javascript">
window.GRAPHENE_SETTINGS = {
{% if subscription_path %}
subscriptionPath: "{{subscription_path}}",
{% endif %}
graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }},
};
</script>
<script src="{% static 'graphene_django/graphiql.js' %}"></script>
</body>
</html>

View File

@ -1,16 +0,0 @@
from django import forms
from django.core.exceptions import ValidationError
from .models import Pet
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = "__all__"
def clean_age(self):
age = self.cleaned_data["age"]
if age >= 99:
raise ValidationError("Too old")
return age

View File

@ -1,44 +0,0 @@
# https://github.com/graphql-python/graphene-django/issues/520
import datetime
from django import forms
import graphene
from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType
from py.test import raises
from py.test import mark
from rest_framework import serializers
from ...types import DjangoObjectType
from ...rest_framework.models import MyFakeModel
from ...rest_framework.mutation import SerializerMutation
from ...forms.mutation import DjangoFormMutation
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeModel
fields = "__all__"
class MyForm(forms.Form):
text = forms.CharField()
def test_can_use_form_and_serializer_mutations():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
class MyFormMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
class Mutation(graphene.ObjectType):
my_mutation = MyMutation.Field()
my_form_mutation = MyFormMutation.Field()
graphene.Schema(mutation=Mutation)

View File

@ -1,119 +0,0 @@
from __future__ import absolute_import
from django.db import models
from django.utils.translation import gettext_lazy as _
CHOICES = ((1, "this"), (2, _("that")))
class Person(models.Model):
name = models.CharField(max_length=30)
class Pet(models.Model):
name = models.CharField(max_length=30)
age = models.PositiveIntegerField()
class FilmDetails(models.Model):
location = models.CharField(max_length=30)
film = models.OneToOneField(
"Film", on_delete=models.CASCADE, related_name="details"
)
class Film(models.Model):
genre = models.CharField(
max_length=2,
help_text="Genre",
choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")],
default="ot",
)
reporters = models.ManyToManyField("Reporter", related_name="films")
class DoeReporterManager(models.Manager):
def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
class Reporter(models.Model):
first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30)
email = models.EmailField()
pets = models.ManyToManyField("self")
a_choice = models.CharField(max_length=30, choices=CHOICES, blank=True)
objects = models.Manager()
doe_objects = DoeReporterManager()
reporter_type = models.IntegerField(
"Reporter Type",
null=True,
blank=True,
choices=[(1, "Regular"), (2, "CNN Reporter")],
)
def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name)
def __init__(self, *args, **kwargs):
"""
Override the init method so that during runtime, Django
can know that this object can be a CNNReporter by casting
it to the proxy model. Otherwise, as far as Django knows,
when a CNNReporter is pulled from the database, it is still
of type Reporter. This was added to test proxy model support.
"""
super(Reporter, self).__init__(*args, **kwargs)
if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter
def some_method(self):
return 123
class CNNReporterManager(models.Manager):
def get_queryset(self):
return super(CNNReporterManager, self).get_queryset().filter(reporter_type=2)
class CNNReporter(Reporter):
"""
This class is a proxy model for Reporter, used for testing
proxy model support
"""
class Meta:
proxy = True
objects = CNNReporterManager()
class Article(models.Model):
headline = models.CharField(max_length=100)
pub_date = models.DateField(auto_now_add=True)
pub_date_time = models.DateTimeField(auto_now_add=True)
reporter = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="articles"
)
editor = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
)
lang = models.CharField(
max_length=2,
help_text="Language",
choices=[("es", "Spanish"), ("en", "English")],
default="es",
)
importance = models.IntegerField(
"Importance",
null=True,
blank=True,
choices=[(1, "Very important"), (2, "Not as important")],
)
def __str__(self): # __unicode__ on Python 2
return self.headline
class Meta:
ordering = ("headline",)

View File

@ -1,18 +0,0 @@
from graphene import Field
from graphene_django.forms.mutation import DjangoFormMutation, DjangoModelFormMutation
from .forms import PetForm
from .types import PetType
class PetFormMutation(DjangoFormMutation):
class Meta:
form_class = PetForm
class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta:
form_class = PetForm

View File

@ -1,40 +0,0 @@
import graphene
from graphene import Schema, relay
from ..types import DjangoObjectType
from .models import Article, Reporter
class Character(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
fields = "__all__"
def get_node(self, info, id):
pass
class Human(DjangoObjectType):
raises = graphene.String()
class Meta:
model = Article
interfaces = (relay.Node,)
fields = "__all__"
def resolve_raises(self, info):
raise Exception("This field should raise exception")
def get_node(self, info, id):
pass
class Query(graphene.ObjectType):
human = graphene.Field(Human)
def resolve_human(self, info):
return Human()
schema = Schema(query=Query)

View File

@ -1,32 +0,0 @@
import graphene
from graphene import ObjectType, Schema
from .mutations import PetFormMutation, PetMutation
class QueryRoot(ObjectType):
thrower = graphene.String(required=True)
request = graphene.String(required=True)
test = graphene.String(who=graphene.String())
def resolve_thrower(self, info):
raise Exception("Throws!")
def resolve_request(self, info):
return info.context.GET.get("q")
def resolve_test(self, info, who=None):
return "Hello %s" % (who or "World")
class MutationRoot(ObjectType):
pet_form_mutation = PetFormMutation.Field()
pet_mutation = PetMutation.Field()
write_test = graphene.Field(QueryRoot)
def resolve_write_test(self, info):
return QueryRoot()
schema = Schema(query=QueryRoot, mutation=MutationRoot)

View File

@ -1,58 +0,0 @@
from textwrap import dedent
from django.core import management
from io import StringIO
from mock import mock_open, patch
from graphene import ObjectType, Schema, String
@patch("graphene_django.management.commands.graphql_schema.Command.save_json_file")
def test_generate_json_file_on_call_graphql_schema(savefile_mock):
out = StringIO()
management.call_command("graphql_schema", schema="", stdout=out)
assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()
@patch("json.dump")
def test_json_files_are_canonical(dump_mock):
open_mock = mock_open()
with patch("graphene_django.management.commands.graphql_schema.open", open_mock):
management.call_command("graphql_schema", schema="")
open_mock.assert_called_once()
dump_mock.assert_called_once()
assert dump_mock.call_args[1][
"sort_keys"
], "json.mock() should be used to sort the output"
assert (
dump_mock.call_args[1]["indent"] > 0
), "output should be pretty-printed by default"
def test_generate_graphql_file_on_call_graphql_schema():
class Query(ObjectType):
hi = String()
mock_schema = Schema(query=Query)
open_mock = mock_open()
with patch("graphene_django.management.commands.graphql_schema.open", open_mock):
management.call_command(
"graphql_schema", schema=mock_schema, out="schema.graphql"
)
open_mock.assert_called_once()
handle = open_mock()
assert handle.write.called_once()
schema_output = handle.write.call_args[0][0]
assert schema_output == dedent(
"""\
type Query {
hi: String
}
"""
)

View File

@ -1,468 +0,0 @@
from collections import namedtuple
import pytest
from django.db import models
from django.utils.translation import gettext_lazy as _
from py.test import raises
import graphene
from graphene import NonNull
from graphene.relay import ConnectionField, Node
from graphene.types.datetime import Date, DateTime, Time
from graphene.types.json import JSONString
from ..compat import (
ArrayField,
HStoreField,
JSONField,
PGJSONField,
MissingType,
RangeField,
)
from ..converter import (
convert_django_field,
convert_django_field_with_choices,
generate_enum_name,
)
from ..registry import Registry
from ..types import DjangoObjectType
from .models import Article, Film, FilmDetails, Reporter
# from graphene.core.types.custom_scalars import DateTime, Time, JSONString
def assert_conversion(django_field, graphene_field, *args, **kwargs):
_kwargs = kwargs.copy()
if "null" not in kwargs:
_kwargs["null"] = True
field = django_field(help_text="Custom Help Text", *args, **_kwargs)
graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == "Custom Help Text"
_kwargs = kwargs.copy()
if "null" not in kwargs:
_kwargs["null"] = False
nonnull_field = django_field(*args, **_kwargs)
if not nonnull_field.null:
nonnull_graphene_type = convert_django_field(nonnull_field)
nonnull_field = nonnull_graphene_type.Field()
assert isinstance(nonnull_field.type, graphene.NonNull)
return nonnull_field
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_django_field(None)
assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string():
assert_conversion(models.DateTimeField, DateTime)
def test_should_date_convert_string():
assert_conversion(models.DateField, Date)
def test_should_time_convert_string():
assert_conversion(models.TimeField, Time)
def test_should_char_convert_string():
assert_conversion(models.CharField, graphene.String)
def test_should_text_convert_string():
assert_conversion(models.TextField, graphene.String)
def test_should_email_convert_string():
assert_conversion(models.EmailField, graphene.String)
def test_should_slug_convert_string():
assert_conversion(models.SlugField, graphene.String)
def test_should_url_convert_string():
assert_conversion(models.URLField, graphene.String)
def test_should_ipaddress_convert_string():
assert_conversion(models.GenericIPAddressField, graphene.String)
def test_should_file_convert_string():
assert_conversion(models.FileField, graphene.String)
def test_should_image_convert_string():
assert_conversion(models.ImageField, graphene.String)
def test_should_file_path_field_convert_string():
assert_conversion(models.FilePathField, graphene.String)
def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID, primary_key=True)
def test_should_big_auto_convert_id():
assert_conversion(models.BigAutoField, graphene.ID, primary_key=True)
def test_should_small_auto_convert_id():
if hasattr(models, "SmallAutoField"):
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True)
def test_should_uuid_convert_id():
assert_conversion(models.UUIDField, graphene.UUID)
def test_should_auto_convert_duration():
assert_conversion(models.DurationField, graphene.Float)
def test_should_positive_integer_convert_int():
assert_conversion(models.PositiveIntegerField, graphene.Int)
def test_should_positive_small_convert_int():
assert_conversion(models.PositiveSmallIntegerField, graphene.Int)
def test_should_small_integer_convert_int():
assert_conversion(models.SmallIntegerField, graphene.Int)
def test_should_big_integer_convert_int():
assert_conversion(models.BigIntegerField, graphene.Int)
def test_should_integer_convert_int():
assert_conversion(models.IntegerField, graphene.Int)
def test_should_boolean_convert_boolean():
assert_conversion(models.BooleanField, graphene.Boolean, null=True)
def test_should_boolean_convert_non_null_boolean():
field = assert_conversion(models.BooleanField, graphene.Boolean, null=False)
assert isinstance(field.type, graphene.NonNull)
assert field.type.of_type == graphene.Boolean
def test_should_nullboolean_convert_boolean():
assert_conversion(models.NullBooleanField, graphene.Boolean)
def test_field_with_choices_convert_enum():
field = models.CharField(
help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
)
class TranslatedModel(models.Model):
language = field
class Meta:
app_label = "test"
graphene_type = convert_django_field_with_choices(field).type.of_type
assert graphene_type._meta.name == "TestTranslatedModelLanguageChoices"
assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_grouped_choices():
field = models.CharField(
help_text="Language",
choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
)
class GroupedChoicesModel(models.Model):
language = field
class Meta:
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_gettext():
field = models.CharField(
help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
)
class TranslatedChoicesModel(models.Model):
language = field
class Meta:
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_collision():
field = models.CharField(
help_text="Timezone",
choices=(
("Etc/GMT+1+2", "Fake choice to produce double collision"),
("Etc/GMT+1", "Greenwich Mean Time +1"),
("Etc/GMT-1", "Greenwich Mean Time -1"),
),
)
class CollisionChoicesModel(models.Model):
timezone = field
class Meta:
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_convert_enum_false():
field = models.CharField(
help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
)
class TranslatedModel(models.Model):
language = field
class Meta:
app_label = "test"
graphene_type = convert_django_field_with_choices(
field, convert_choices_to_enum=False
)
assert isinstance(graphene_type, graphene.String)
def test_should_float_convert_float():
assert_conversion(models.FloatField, graphene.Float)
def test_should_float_convert_decimal():
assert_conversion(models.DecimalField, graphene.Decimal)
def test_should_manytomany_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_django_field(Reporter._meta.local_many_to_many[0], registry)
assert not dynamic_field.get_type()
def test_should_manytomany_convert_connectionorlist_list():
class A(DjangoObjectType):
class Meta:
model = Reporter
fields = "__all__"
graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
# A NonNull List of NonNull A ([A!]!)
# https://github.com/graphql-python/graphene-django/issues/448
assert isinstance(dynamic_field.type, NonNull)
assert isinstance(dynamic_field.type.of_type, graphene.List)
assert isinstance(dynamic_field.type.of_type.of_type, NonNull)
assert dynamic_field.type.of_type.of_type.of_type == A
def test_should_manytomany_convert_connectionorlist_connection():
class A(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField)
assert dynamic_field.type.of_type == A._meta.connection
def test_should_manytoone_convert_connectionorlist():
class A(DjangoObjectType):
class Meta:
model = Article
fields = "__all__"
graphene_field = convert_django_field(Reporter.articles.rel, A._meta.registry)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
# a NonNull List of NonNull A ([A!]!)
assert isinstance(dynamic_field.type, NonNull)
assert isinstance(dynamic_field.type.of_type, graphene.List)
assert isinstance(dynamic_field.type.of_type.of_type, NonNull)
assert dynamic_field.type.of_type.of_type.of_type == A
def test_should_onetoone_reverse_convert_model():
class A(DjangoObjectType):
class Meta:
model = FilmDetails
fields = "__all__"
graphene_field = convert_django_field(Film.details.related, A._meta.registry)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
assert dynamic_field.type == A
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_should_postgres_array_convert_list():
field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100)
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.NonNull)
assert field.type.of_type.of_type.of_type == graphene.String
field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100, null=True)
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.String
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(
ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type.of_type, graphene.NonNull)
assert field.type.of_type.of_type.of_type.of_type == graphene.String
field = assert_conversion(
ArrayField,
graphene.List,
ArrayField(models.CharField(max_length=100, null=True)),
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List)
assert field.type.of_type.of_type.of_type == graphene.String
@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(PGJSONField is MissingType, reason="PGJSONField should exist")
def test_should_postgres_json_convert_string():
assert_conversion(PGJSONField, JSONString)
@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
def test_should_json_convert_string():
assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.NonNull)
assert field.type.of_type.of_type.of_type == graphene.Int
def test_generate_enum_name():
MockDjangoModelMeta = namedtuple("DjangoMeta", ["app_label", "object_name"])
# Simple case
field = graphene.Field(graphene.String, name="type")
model_meta = MockDjangoModelMeta(app_label="users", object_name="User")
assert generate_enum_name(model_meta, field) == "UsersUserTypeChoices"
# More complicated multiple work case
field = graphene.Field(graphene.String, name="fizz_buzz")
model_meta = MockDjangoModelMeta(
app_label="some_long_app_name", object_name="SomeObject"
)
assert (
generate_enum_name(model_meta, field)
== "SomeLongAppNameSomeObjectFizzBuzzChoices"
)
def test_generate_v2_enum_name(graphene_settings):
MockDjangoModelMeta = namedtuple("DjangoMeta", ["app_label", "object_name"])
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING = True
# Simple case
field = graphene.Field(graphene.String, name="type")
model_meta = MockDjangoModelMeta(app_label="users", object_name="User")
assert generate_enum_name(model_meta, field) == "UserType"
# More complicated multiple work case
field = graphene.Field(graphene.String, name="fizz_buzz")
model_meta = MockDjangoModelMeta(
app_label="some_long_app_name", object_name="SomeObject"
)
assert generate_enum_name(model_meta, field) == "SomeObjectFizzBuzz"
def test_choice_enum_blank_value():
"""Test that choice fields with blank values work"""
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
fields = (
"first_name",
"a_choice",
)
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(root, info):
return Reporter.objects.first()
schema = graphene.Schema(query=Query)
# Create model with empty choice option
Reporter.objects.create(
first_name="Bridget", last_name="Jones", email="bridget@example.com"
)
result = schema.execute(
"""
query {
reporter {
firstName
aChoice
}
}
"""
)
assert not result.errors
assert result.data == {
"reporter": {"firstName": "Bridget", "aChoice": None},
}

View File

@ -1,502 +0,0 @@
import datetime
from django.db.models import Count
import pytest
from graphene import List, NonNull, ObjectType, Schema, String
from ..fields import DjangoListField
from ..types import DjangoObjectType
from .models import Article as ArticleModel
from .models import Reporter as ReporterModel
class TestDjangoListField:
def test_only_django_object_types(self):
class TestType(ObjectType):
foo = String()
with pytest.raises(AssertionError):
list_field = DjangoListField(TestType)
def test_only_import_paths(self):
list_field = DjangoListField("graphene_django.tests.schema.Human")
from .schema import Human
assert list_field._type.of_type.of_type is Human
def test_non_null_type(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
list_field = DjangoListField(NonNull(Reporter))
assert isinstance(list_field.type, List)
assert isinstance(list_field.type.of_type, NonNull)
assert list_field.type.of_type.of_type is Reporter
def test_get_django_model(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
list_field = DjangoListField(Reporter)
assert list_field.model is ReporterModel
def test_list_field_default_queryset(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = schema.execute(query)
assert not result.errors
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
}
def test_list_field_queryset_is_not_cached(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": []}
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = schema.execute(query)
assert not result.errors
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
}
def test_override_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return ReporterModel.objects.filter(first_name="Tara")
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
def test_nested_list_field(self):
class Article(DjangoObjectType):
class Meta:
model = ArticleModel
fields = ("headline",)
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
class Query(ObjectType):
reporters = DjangoListField(Reporter)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
articles {
headline
}
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
ArticleModel.objects.create(
headline="Not so good news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {
"reporters": [
{
"firstName": "Tara",
"articles": [
{"headline": "Amazing news"},
{"headline": "Not so good news"},
],
},
{"firstName": "Debra", "articles": []},
]
}
def test_override_resolver_nested_list_field(self):
class Article(DjangoObjectType):
class Meta:
model = ArticleModel
fields = ("headline",)
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
def resolve_articles(reporter, info):
return reporter.articles.filter(headline__contains="Amazing")
class Query(ObjectType):
reporters = DjangoListField(Reporter)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
articles {
headline
}
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
ArticleModel.objects.create(
headline="Not so good news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {
"reporters": [
{"firstName": "Tara", "articles": [{"headline": "Amazing news"}]},
{"firstName": "Debra", "articles": []},
]
}
def test_get_queryset_filter(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return ReporterModel.objects.all()
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
def test_resolve_list(self):
"""Resolving a plain list should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_foreign_key(self):
class Article(DjangoObjectType):
class Meta:
model = ArticleModel
fields = ("headline",)
@classmethod
def get_queryset(cls, queryset, info):
# Rose tinted glasses
return queryset.exclude(headline__contains="Not so good")
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
class Query(ObjectType):
reporters = DjangoListField(Reporter)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
articles {
headline
}
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
ArticleModel.objects.create(
headline="Not so good news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {
"reporters": [
{"firstName": "Tara", "articles": [{"headline": "Amazing news"}]},
{"firstName": "Debra", "articles": []},
]
}
def test_resolve_list_external_resolver(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_filter_external_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return ReporterModel.objects.all()
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}

View File

@ -1,40 +0,0 @@
from django.core.exceptions import ValidationError
from py.test import raises
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
def test_global_id_valid():
field = GlobalIDFormField()
field.clean("TXlUeXBlOmFiYw==")
def test_global_id_invalid():
field = GlobalIDFormField()
with raises(ValidationError):
field.clean("badvalue")
def test_global_id_multiple_valid():
field = GlobalIDMultipleChoiceField()
field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
def test_global_id_multiple_invalid():
field = GlobalIDMultipleChoiceField()
with raises(ValidationError):
field.clean(["badvalue", "another bad avue"])
def test_global_id_none():
field = GlobalIDFormField()
with raises(ValidationError):
field.clean(None)
def test_global_id_none_optional():
field = GlobalIDFormField(required=False)
field.clean(None)

File diff suppressed because it is too large Load Diff

View File

@ -1,55 +0,0 @@
from py.test import raises
from ..registry import Registry
from ..types import DjangoObjectType
from .models import Reporter
def test_should_raise_if_no_model():
with raises(Exception) as excinfo:
class Character1(DjangoObjectType):
fields = "__all__"
assert "valid Django Model" in str(excinfo.value)
def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo:
class Character2(DjangoObjectType):
class Meta:
model = 1
fields = "__all__"
assert "valid Django Model" in str(excinfo.value)
def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType):
class Meta:
model = Reporter
registry = Registry()
fields = "__all__"
fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [
"id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType):
class Meta:
model = Reporter
fields = ("id", "email")
assert list(Reporter2._meta.fields.keys()) == ["id", "email"]

View File

@ -1,691 +0,0 @@
from collections import OrderedDict, defaultdict
from textwrap import dedent
import pytest
from django.db import models
from mock import patch
from graphene import Connection, Field, Interface, ObjectType, Schema, String
from graphene.relay import Node
from .. import registry
from ..filter import DjangoFilterConnectionField
from ..types import DjangoObjectType, DjangoObjectTypeOptions
from .models import Article as ArticleModel
from .models import Reporter as ReporterModel
class Reporter(DjangoObjectType):
"""Reporter description"""
class Meta:
model = ReporterModel
fields = "__all__"
class ArticleConnection(Connection):
"""Article Connection"""
test = String()
def resolve_test():
return "test"
class Meta:
abstract = True
class Article(DjangoObjectType):
"""Article description"""
class Meta:
model = ArticleModel
interfaces = (Node,)
connection_class = ArticleConnection
fields = "__all__"
class RootQuery(ObjectType):
node = Node.Field()
schema = Schema(query=RootQuery, types=[Article, Reporter])
def test_django_interface():
assert issubclass(Node, Interface)
assert issubclass(Node, Node)
@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
def test_django_get_node(get):
article = Article.get_node(None, 1)
get.assert_called_with(pk=1)
assert article.id == 1
def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields
fields = list(fields.keys())
assert fields[:-2] == [
"id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_django_objecttype_with_node_have_correct_fields():
fields = Article._meta.fields
assert list(fields.keys()) == [
"id",
"headline",
"pub_date",
"pub_date_time",
"reporter",
"editor",
"lang",
"importance",
]
def test_django_objecttype_with_custom_meta():
class ArticleTypeOptions(DjangoObjectTypeOptions):
"""Article Type Options"""
class ArticleType(DjangoObjectType):
class Meta:
abstract = True
@classmethod
def __init_subclass_with_meta__(cls, **options):
options.setdefault("_meta", ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(**options)
class Article(ArticleType):
class Meta:
model = ArticleModel
fields = "__all__"
assert isinstance(Article._meta, ArticleTypeOptions)
def test_schema_representation():
expected = dedent(
"""\
schema {
query: RootQuery
}
\"""Article description\"""
type Article implements Node {
\"""The ID of the object\"""
id: ID!
headline: String!
pubDate: Date!
pubDateTime: DateTime!
reporter: Reporter!
editor: Reporter!
\"""Language\"""
lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}
\"""An object with an ID\"""
interface Node {
\"""The ID of the object\"""
id: ID!
}
\"""
The `Date` scalar type represents a Date
value as specified by
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
\"""
scalar Date
\"""
The `DateTime` scalar type represents a DateTime
value as specified by
[iso8601](https://en.wikipedia.org/wiki/ISO_8601).
\"""
scalar DateTime
\"""An enumeration.\"""
enum TestsArticleLangChoices {
\"""Spanish\"""
ES
\"""English\"""
EN
}
\"""An enumeration.\"""
enum TestsArticleImportanceChoices {
\"""Very important\"""
A_1
\"""Not as important\"""
A_2
}
\"""Reporter description\"""
type Reporter {
id: ID!
firstName: String!
lastName: String!
email: String!
pets: [Reporter!]!
aChoice: TestsReporterAChoiceChoices
reporterType: TestsReporterReporterTypeChoices
articles(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null): ArticleConnection!
}
\"""An enumeration.\"""
enum TestsReporterAChoiceChoices {
\"""this\"""
A_1
\"""that\"""
A_2
}
\"""An enumeration.\"""
enum TestsReporterReporterTypeChoices {
\"""Regular\"""
A_1
\"""CNN Reporter\"""
A_2
}
type ArticleConnection {
\"""Pagination data for this connection.\"""
pageInfo: PageInfo!
\"""Contains the nodes in this connection.\"""
edges: [ArticleEdge]!
test: String
}
\"""
The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.
\"""
type PageInfo {
\"""When paginating forwards, are there more items?\"""
hasNextPage: Boolean!
\"""When paginating backwards, are there more items?\"""
hasPreviousPage: Boolean!
\"""When paginating backwards, the cursor to continue.\"""
startCursor: String
\"""When paginating forwards, the cursor to continue.\"""
endCursor: String
}
\"""A Relay edge containing a `Article` and its cursor.\"""
type ArticleEdge {
\"""The item at the end of the edge\"""
node: Article
\"""A cursor for use in pagination\"""
cursor: String!
}
type RootQuery {
node(
\"""The ID of the object\"""
id: ID!
): Node
}
"""
)
assert str(schema) == expected
def with_local_registry(func):
def inner(*args, **kwargs):
old = registry.get_global_registry()
try:
retval = func(*args, **kwargs)
except Exception as e:
registry.registry = old
raise e
else:
registry.registry = old
return retval
return inner
@with_local_registry
def test_django_objecttype_only_fields():
with pytest.warns(DeprecationWarning):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
only_fields = ("id", "email", "films")
fields = list(Reporter._meta.fields.keys())
assert fields == ["id", "email", "films"]
@with_local_registry
def test_django_objecttype_fields():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("id", "email", "films")
fields = list(Reporter._meta.fields.keys())
assert fields == ["id", "email", "films"]
@with_local_registry
def test_django_objecttype_fields_empty():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ()
fields = list(Reporter._meta.fields.keys())
assert fields == []
@with_local_registry
def test_django_objecttype_only_fields_and_fields():
with pytest.raises(Exception):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
only_fields = ("id", "email", "films")
fields = ("id", "email", "films")
@with_local_registry
def test_django_objecttype_all_fields():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = "__all__"
fields = list(Reporter._meta.fields.keys())
assert len(fields) == len(ReporterModel._meta.get_fields())
@with_local_registry
def test_django_objecttype_exclude_fields():
with pytest.warns(DeprecationWarning):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude_fields = ["email"]
fields = list(Reporter._meta.fields.keys())
assert "email" not in fields
@with_local_registry
def test_django_objecttype_exclude():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = ["email"]
fields = list(Reporter._meta.fields.keys())
assert "email" not in fields
@with_local_registry
def test_django_objecttype_exclude_fields_and_exclude():
with pytest.raises(Exception):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = ["email"]
exclude_fields = ["email"]
@with_local_registry
def test_django_objecttype_exclude_and_only():
with pytest.raises(AssertionError):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = ["email"]
fields = ["id"]
@with_local_registry
def test_django_objecttype_fields_exclude_type_checking():
with pytest.raises(TypeError):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = "foo"
with pytest.raises(TypeError):
class Reporter2(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = "foo"
@with_local_registry
def test_django_objecttype_fields_exist_on_model():
with pytest.warns(UserWarning, match=r"Field name .* doesn't exist"):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ["first_name", "foo", "email"]
with pytest.warns(
UserWarning,
match=r"Field name .* matches an attribute on Django model .* but it's not a model field",
) as record:
class Reporter2(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ["first_name", "some_method", "email"]
# Don't warn if selecting a custom field
with pytest.warns(None) as record:
class Reporter3(DjangoObjectType):
custom_field = String()
class Meta:
model = ReporterModel
fields = ["first_name", "custom_field", "email"]
assert len(record) == 0
@with_local_registry
def test_django_objecttype_exclude_fields_exist_on_model():
with pytest.warns(
UserWarning,
match=r"Django model .* does not have a field or attribute named .*",
):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = ["foo"]
# Don't warn if selecting a custom field
with pytest.warns(
UserWarning,
match=r"Excluding the custom field .* on DjangoObjectType .* has no effect.",
):
class Reporter3(DjangoObjectType):
custom_field = String()
class Meta:
model = ReporterModel
exclude = ["custom_field"]
# Don't warn on exclude fields
with pytest.warns(None) as record:
class Reporter4(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = ["email", "first_name"]
assert len(record) == 0
@with_local_registry
def test_django_objecttype_neither_fields_nor_exclude():
with pytest.warns(
DeprecationWarning,
match=r"Creating a DjangoObjectType without either the `fields` "
"or the `exclude` option is deprecated.",
):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
with pytest.warns(None) as record:
class Reporter2(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ["email"]
assert len(record) == 0
with pytest.warns(None) as record:
class Reporter3(DjangoObjectType):
class Meta:
model = ReporterModel
exclude = ["email"]
assert len(record) == 0
def custom_enum_name(field):
return "CustomEnum{}".format(field.name.title())
class TestDjangoObjectType:
@pytest.fixture
def PetModel(self):
class PetModel(models.Model):
kind = models.CharField(choices=(("cat", "Cat"), ("dog", "Dog")))
cuteness = models.IntegerField(
choices=((1, "Kind of cute"), (2, "Pretty cute"), (3, "OMG SO CUTE!!!"))
)
yield PetModel
# Clear Django model cache so we don't get warnings when creating the
# model multiple times
PetModel._meta.apps.all_models = defaultdict(OrderedDict)
def test_django_objecttype_convert_choices_enum_false(self, PetModel):
class Pet(DjangoObjectType):
class Meta:
model = PetModel
convert_choices_to_enum = False
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: String!
cuteness: Int!
}
"""
)
def test_django_objecttype_convert_choices_enum_list(self, PetModel):
class Pet(DjangoObjectType):
class Meta:
model = PetModel
convert_choices_to_enum = ["kind"]
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: TestsPetModelKindChoices!
cuteness: Int!
}
\"""An enumeration.\"""
enum TestsPetModelKindChoices {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}
"""
)
def test_django_objecttype_convert_choices_enum_empty_list(self, PetModel):
class Pet(DjangoObjectType):
class Meta:
model = PetModel
convert_choices_to_enum = []
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: String!
cuteness: Int!
}
"""
)
def test_django_objecttype_convert_choices_enum_naming_collisions(
self, PetModel, graphene_settings
):
class PetModelKind(DjangoObjectType):
class Meta:
model = PetModel
fields = ["id", "kind"]
class Query(ObjectType):
pet = Field(PetModelKind)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: PetModelKind
}
type PetModelKind {
id: ID!
kind: TestsPetModelKindChoices!
}
\"""An enumeration.\"""
enum TestsPetModelKindChoices {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}
"""
)
def test_django_objecttype_choices_custom_enum_name(
self, PetModel, graphene_settings
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME = (
"graphene_django.tests.test_types.custom_enum_name"
)
class PetModelKind(DjangoObjectType):
class Meta:
model = PetModel
fields = ["id", "kind"]
class Query(ObjectType):
pet = Field(PetModelKind)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: PetModelKind
}
type PetModelKind {
id: ID!
kind: CustomEnumKind!
}
\"""An enumeration.\"""
enum CustomEnumKind {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}
"""
)
@with_local_registry
def test_django_objecttype_name_connection_propagation():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
name = "CustomReporterName"
fields = "__all__"
filter_fields = ["email"]
interfaces = (Node,)
class Query(ObjectType):
reporter = Node.Field(Reporter)
reporters = DjangoFilterConnectionField(Reporter)
assert Reporter._meta.name == "CustomReporterName"
schema = str(Schema(query=Query))
assert "type CustomReporterName implements Node {" in schema
assert "type CustomReporterNameConnection {" in schema
assert "type CustomReporterNameEdge {" in schema
assert "type Reporter implements Node {" not in schema
assert "type ReporterConnection {" not in schema
assert "type ReporterEdge {" not in schema

View File

@ -1,88 +0,0 @@
import json
import pytest
from django.utils.translation import gettext_lazy
from mock import patch
from ..utils import camelize, get_model_fields, GraphQLTestCase
from .models import Film, Reporter
from ..utils.testing import graphql_query
def test_get_model_fields_no_duplication():
reporter_fields = get_model_fields(Reporter)
reporter_name_set = set([field[0] for field in reporter_fields])
assert len(reporter_fields) == len(reporter_name_set)
film_fields = get_model_fields(Film)
film_name_set = set([field[0] for field in film_fields])
assert len(film_fields) == len(film_name_set)
def test_camelize():
assert camelize({}) == {}
assert camelize("value_a") == "value_a"
assert camelize({"value_a": "value_b"}) == {"valueA": "value_b"}
assert camelize({"value_a": ["value_b"]}) == {"valueA": ["value_b"]}
assert camelize({"value_a": ["value_b"]}) == {"valueA": ["value_b"]}
assert camelize({"nested_field": {"value_a": ["error"], "value_b": ["error"]}}) == {
"nestedField": {"valueA": ["error"], "valueB": ["error"]}
}
assert camelize({"value_a": gettext_lazy("value_b")}) == {"valueA": "value_b"}
assert camelize({"value_a": [gettext_lazy("value_b")]}) == {"valueA": ["value_b"]}
assert camelize(gettext_lazy("value_a")) == "value_a"
assert camelize({gettext_lazy("value_a"): gettext_lazy("value_b")}) == {
"valueA": "value_b"
}
assert camelize({0: {"field_a": ["errors"]}}) == {0: {"fieldA": ["errors"]}}
@pytest.mark.django_db
@patch("graphene_django.utils.testing.Client.post")
def test_graphql_test_case_operation_name(post_mock):
"""
Test that `GraphQLTestCase.query()`'s `operation_name` argument produces an `operationName` field.
"""
class TestClass(GraphQLTestCase):
GRAPHQL_SCHEMA = True
def runTest(self):
pass
tc = TestClass()
tc._pre_setup()
tc.setUpClass()
tc.query("query { }", operation_name="QueryName")
body = json.loads(post_mock.call_args.args[1])
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
assert (
"operationName",
"QueryName",
) in body.items(), "Field 'operationName' is not present in the final request."
@pytest.mark.django_db
@patch("graphene_django.utils.testing.Client.post")
def test_graphql_query_case_operation_name(post_mock):
graphql_query("query { }", operation_name="QueryName")
body = json.loads(post_mock.call_args.args[1])
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
assert (
"operationName",
"QueryName",
) in body.items(), "Field 'operationName' is not present in the final request."
@pytest.fixture
def client_query(client):
def func(*args, **kwargs):
return graphql_query(*args, client=client, **kwargs)
return func
def test_pytest_fixture_usage(client_query):
response = client_query("query { test }")
content = json.loads(response.content)
assert content == {"data": {"test": "Hello World"}}

View File

@ -1,834 +0,0 @@
import json
import pytest
from mock import patch
from django.db import connection
from graphene_django.settings import graphene_settings
from .models import Pet
try:
from urllib import urlencode
except ImportError:
from urllib.parse import urlencode
def url_string(string="/graphql", **url_params):
if url_params:
string += "?" + urlencode(url_params)
return string
def batch_url_string(**url_params):
return url_string("/graphql/batch", **url_params)
def response_json(response):
return json.loads(response.content.decode())
j = lambda **kwargs: json.dumps(kwargs)
jl = lambda **kwargs: json.dumps([kwargs])
def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_graphiql(client):
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
)
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_json(client):
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
)
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_query_param(client):
response = client.get(url_string(query="{test}"))
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_variable_values(client):
response = client.get(
url_string(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_get_with_operation_name(client):
response = client.get(
url_string(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
)
)
assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_reports_validation_errors(client):
response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 9}],
"path": None,
},
{
"message": "Cannot query field 'unknownTwo' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 21}],
"path": None,
},
]
}
def test_errors_when_missing_operation_name(client):
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
"""
)
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{
"message": "Must provide operation name if query contains multiple operations.",
"locations": None,
"path": None,
}
]
}
def test_errors_when_sending_a_mutation_via_get(client):
response = client.get(
url_string(
query="""
mutation TestMutation { writeTest { test } }
"""
)
)
assert response.status_code == 405
assert response_json(response) == {
"errors": [
{"message": "Can only perform a mutation operation from a POST request."}
]
}
def test_errors_when_selecting_a_mutation_within_a_get(client):
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
""",
operationName="TestMutation",
)
)
assert response.status_code == 405
assert response_json(response) == {
"errors": [
{"message": "Can only perform a mutation operation from a POST request."}
]
}
def test_allows_mutation_to_exist_within_a_get(client):
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
""",
operationName="TestQuery",
)
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_batch_allows_post_with_json_encoding(client):
response = client.post(
batch_url_string(), jl(id=1, query="{test}"), "application/json"
)
assert response.status_code == 200
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello World"}, "status": 200}
]
def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "Received an empty list in the batch request."}]
}
def test_allows_sending_a_mutation_via_post(client):
response = client.post(
url_string(),
j(query="mutation TestMutation { writeTest { test } }"),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
def test_allows_post_with_url_encoding(client):
response = client.post(
url_string(),
urlencode(dict(query="{test}")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_supports_post_json_query_with_string_variables(client):
response = client.post(
url_string(),
j(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_batch_supports_post_json_query_with_string_variables(client):
response = client.post(
batch_url_string(),
jl(
id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
]
def test_supports_post_json_query_with_json_variables(client):
response = client.post(
url_string(),
j(
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(
batch_url_string(),
jl(
id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
]
def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post(
url_string(),
urlencode(
dict(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_supports_post_json_quey_with_get_variable_values(client):
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
j(query="query helloWho($who: String){ test(who: $who) }"),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_post_url_encoded_query_with_get_variable_values(client):
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_supports_post_raw_text_query_with_get_variable_values(client):
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
"query helloWho($who: String){ test(who: $who) }",
"application/graphql",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_post_with_operation_name(client):
response = client.post(
url_string(),
j(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_batch_allows_post_with_operation_name(client):
response = client.post(
batch_url_string(),
jl(
id=1,
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [
{
"id": 1,
"data": {"test": "Hello World", "shared": "Hello Everyone"},
"status": 200,
}
]
def test_allows_post_with_get_operation_name(client):
response = client.post(
url_string(operationName="helloWorld"),
"""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
"application/graphql",
)
assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
@pytest.mark.urls("graphene_django.tests.urls_inherited")
def test_inherited_class_with_attributes_works(client):
inherited_url = "/graphql/inherited/"
# Check schema and pretty attributes work
response = client.post(url_string(inherited_url, query="{test}"))
assert response.content.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
# Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == 200
@pytest.mark.urls("graphene_django.tests.urls_pretty")
def test_supports_pretty_printing(client):
response = client.get(url_string(query="{test}"))
assert response.content.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
def test_supports_pretty_printing_by_request(client):
response = client.get(url_string(query="{test}", pretty="1"))
assert response.content.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query="{thrower}"))
assert response.status_code == 200
assert response_json(response) == {
"data": None,
"errors": [
{
"locations": [{"column": 2, "line": 1}],
"path": ["thrower"],
"message": "Throws!",
}
],
}
def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query="syntaxerror"))
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{
"locations": [{"column": 1, "line": 1}],
"message": "Syntax Error: Unexpected Name 'syntaxerror'.",
"path": None,
}
]
}
def test_handles_errors_caused_by_a_lack_of_query(client):
response = client.get(url_string())
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "Must provide query string."}]
}
def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), "[]", "application/json")
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "The received data is not a valid JSON query."}]
}
def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "POST body sent invalid JSON."}]
}
def test_handles_django_request_error(client, monkeypatch):
def mocked_read(*args):
raise IOError("foo-bar")
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
valid_json = json.dumps(dict(foo="bar"))
response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == 400
assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "POST body sent invalid JSON."}]
}
def test_handles_plain_post_text(client):
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
"query helloWho($who: String){ test(who: $who) }",
"text/plain",
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "Must provide query string."}]
}
def test_handles_poorly_formed_variables(client):
response = client.get(
url_string(
query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
)
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [{"message": "Variables are invalid JSON."}]
}
def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query="{test}"))
assert response.status_code == 405
assert response["Allow"] == "GET, POST"
assert response_json(response) == {
"errors": [{"message": "GraphQL only supports GET and POST requests."}]
}
def test_passes_request_into_context_request(client):
response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == 200
assert response_json(response) == {"data": {"request": "testing"}}
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
)
def test_form_mutation_multiple_creation_invalid_atomic_request(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": True, "ATOMIC_REQUESTS": False}
)
def test_form_mutation_multiple_creation_invalid_atomic_mutation_1(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", True)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_form_mutation_multiple_creation_invalid_atomic_mutation_2(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_form_mutation_multiple_creation_invalid_non_atomic(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Enzo"
assert pet.age == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
)
def test_model_form_mutation_multiple_creation_invalid_atomic_request(client):
query = """
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petMutation1"]["pet"] is None
assert content["data"]["petMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_model_form_mutation_multiple_creation_invalid_non_atomic(client):
query = """
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petMutation1"]["pet"] is None
assert content["data"]["petMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Enzo"
assert pet.age == 0
@patch("graphene_django.utils.utils.transaction.set_rollback")
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
)
def test_query_errors_atomic_request(set_rollback_mock, client):
client.get(url_string(query="force error"))
set_rollback_mock.assert_called_once_with(True)
@patch("graphene_django.utils.utils.transaction.set_rollback")
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_query_errors_non_atomic(set_rollback_mock, client):
client.get(url_string(query="force error"))
set_rollback_mock.assert_not_called()

View File

@ -1,9 +0,0 @@
from graphene_django.types import DjangoObjectType
from .models import Pet
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"

View File

@ -1,8 +0,0 @@
from django.conf.urls import url
from ..views import GraphQLView
urlpatterns = [
url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
url(r"^graphql", GraphQLView.as_view(graphiql=True)),
]

View File

@ -1,13 +0,0 @@
from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
class CustomGraphQLView(GraphQLView):
schema = schema
graphiql = True
pretty = True
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]

View File

@ -1,6 +0,0 @@
from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]

View File

@ -1,305 +0,0 @@
import warnings
from collections import OrderedDict
from typing import Type
import graphene
from django.db.models import Model
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry
from .settings import graphene_settings
from .utils import (
DJANGO_FILTER_INSTALLED,
camelize,
get_model_fields,
is_valid_django_model,
)
ALL_FIELDS = "__all__"
def construct_fields(
model, registry, only_fields, exclude_fields, convert_choices_to_enum
):
_model_fields = get_model_fields(model)
fields = OrderedDict()
for name, field in _model_fields:
is_not_in_only = (
only_fields is not None
and only_fields != ALL_FIELDS
and name not in only_fields
)
# is_already_created = name in options.fields
is_excluded = (
exclude_fields is not None and name in exclude_fields
) # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith("+")
if is_not_in_only or is_excluded or is_no_backref:
# We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields.
# Or when there is no back reference.
continue
_convert_choices_to_enum = convert_choices_to_enum
if not isinstance(_convert_choices_to_enum, bool):
# then `convert_choices_to_enum` is a list of field names to convert
if name in _convert_choices_to_enum:
_convert_choices_to_enum = True
else:
_convert_choices_to_enum = False
converted = convert_django_field_with_choices(
field, registry, convert_choices_to_enum=_convert_choices_to_enum
)
fields[name] = converted
return fields
def validate_fields(type_, model, fields, only_fields, exclude_fields):
# Validate the given fields against the model's fields and custom fields
all_field_names = set(fields.keys())
only_fields = only_fields if only_fields is not ALL_FIELDS else ()
for name in only_fields or ():
if name in all_field_names:
continue
if hasattr(model, name):
warnings.warn(
(
'Field name "{field_name}" matches an attribute on Django model "{app_label}.{object_name}" '
"but it's not a model field so Graphene cannot determine what type it should be. "
'Either define the type of the field on DjangoObjectType "{type_}" or remove it from the "fields" list.'
).format(
field_name=name,
app_label=model._meta.app_label,
object_name=model._meta.object_name,
type_=type_,
)
)
else:
warnings.warn(
(
'Field name "{field_name}" doesn\'t exist on Django model "{app_label}.{object_name}". '
'Consider removing the field from the "fields" list of DjangoObjectType "{type_}" because it has no effect.'
).format(
field_name=name,
app_label=model._meta.app_label,
object_name=model._meta.object_name,
type_=type_,
)
)
# Validate exclude fields
for name in exclude_fields or ():
if name in all_field_names:
# Field is a custom field
warnings.warn(
(
'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. '
'Either remove the custom field or remove the field from the "exclude" list.'
).format(field_name=name, type_=type_)
)
else:
if not hasattr(model, name):
warnings.warn(
(
'Django model "{app_label}.{object_name}" does not have a field or attribute named "{field_name}". '
'Consider removing the field from the "exclude" list of DjangoObjectType "{type_}" because it has no effect'
).format(
field_name=name,
app_label=model._meta.app_label,
object_name=model._meta.object_name,
type_=type_,
)
)
class DjangoObjectTypeOptions(ObjectTypeOptions):
model = None # type: Model
registry = None # type: Registry
connection = None # type: Type[Connection]
filter_fields = ()
filterset_class = None
class DjangoObjectType(ObjectType):
@classmethod
def __init_subclass_with_meta__(
cls,
model=None,
registry=None,
skip_registry=False,
only_fields=None, # deprecated in favour of `fields`
fields=None,
exclude_fields=None, # deprecated in favour of `exclude`
exclude=None,
filter_fields=None,
filterset_class=None,
connection=None,
connection_class=None,
use_connection=None,
interfaces=(),
convert_choices_to_enum=True,
_meta=None,
**options
):
assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(cls.__name__, model)
if not registry:
registry = get_global_registry()
assert isinstance(registry, Registry), (
"The attribute registry in {} needs to be an instance of "
'Registry, received "{}".'
).format(cls.__name__, registry)
if filter_fields and filterset_class:
raise Exception("Can't set both filter_fields and filterset_class")
if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class):
raise Exception(
(
"Can only set filter_fields or filterset_class if "
"Django-Filter is installed"
)
)
assert not (fields and exclude), (
"Cannot set both 'fields' and 'exclude' options on "
"DjangoObjectType {class_name}.".format(class_name=cls.__name__)
)
# Alias only_fields -> fields
if only_fields and fields:
raise Exception("Can't set both only_fields and fields")
if only_fields:
warnings.warn(
"Defining `only_fields` is deprecated in favour of `fields`.",
DeprecationWarning,
stacklevel=2,
)
fields = only_fields
if fields and fields != ALL_FIELDS and not isinstance(fields, (list, tuple)):
raise TypeError(
'The `fields` option must be a list or tuple or "__all__". '
"Got %s." % type(fields).__name__
)
# Alias exclude_fields -> exclude
if exclude_fields and exclude:
raise Exception("Can't set both exclude_fields and exclude")
if exclude_fields:
warnings.warn(
"Defining `exclude_fields` is deprecated in favour of `exclude`.",
DeprecationWarning,
stacklevel=2,
)
exclude = exclude_fields
if exclude and not isinstance(exclude, (list, tuple)):
raise TypeError(
"The `exclude` option must be a list or tuple. Got %s."
% type(exclude).__name__
)
if fields is None and exclude is None:
warnings.warn(
"Creating a DjangoObjectType without either the `fields` "
"or the `exclude` option is deprecated. Add an explicit `fields "
"= '__all__'` option on DjangoObjectType {class_name} to use all "
"fields".format(class_name=cls.__name__,),
DeprecationWarning,
stacklevel=2,
)
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, fields, exclude, convert_choices_to_enum),
_as=graphene.Field,
)
if use_connection is None and interfaces:
use_connection = any(
(issubclass(interface, Node) for interface in interfaces)
)
if use_connection and not connection:
# We create the connection automatically
if not connection_class:
connection_class = Connection
connection = connection_class.create_type(
"{}Connection".format(options.get("name") or cls.__name__), node=cls
)
if connection is not None:
assert issubclass(connection, Connection), (
"The connection must be a Connection. Received {}"
).format(connection.__name__)
if not _meta:
_meta = DjangoObjectTypeOptions(cls)
_meta.model = model
_meta.registry = registry
_meta.filter_fields = filter_fields
_meta.filterset_class = filterset_class
_meta.fields = django_fields
_meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options
)
# Validate fields
validate_fields(cls, model, _meta.fields, fields, exclude)
if not skip_registry:
registry.register(cls)
def resolve_id(self, info):
return self.pk
@classmethod
def is_type_of(cls, root, info):
if isinstance(root, cls):
return True
if not is_valid_django_model(root.__class__):
raise Exception(('Received incompatible instance "{}".').format(root))
if cls._meta.model._meta.proxy:
model = root._meta.model
else:
model = root._meta.model._meta.concrete_model
return model == cls._meta.model
@classmethod
def get_queryset(cls, queryset, info):
return queryset
@classmethod
def get_node(cls, info, id):
queryset = cls.get_queryset(cls._meta.model.objects, info)
try:
return queryset.get(pk=id)
except cls._meta.model.DoesNotExist:
return None
class ErrorType(ObjectType):
field = graphene.String(required=True)
messages = graphene.List(graphene.NonNull(graphene.String), required=True)
@classmethod
def from_errors(cls, errors):
data = camelize(errors) if graphene_settings.CAMELCASE_ERRORS else errors
return [cls(field=key, messages=value) for key, value in data.items()]

View File

@ -1,19 +0,0 @@
from .testing import GraphQLTestCase
from .utils import (
DJANGO_FILTER_INSTALLED,
camelize,
get_model_fields,
get_reverse_fields,
is_valid_django_model,
maybe_queryset,
)
__all__ = [
"DJANGO_FILTER_INSTALLED",
"get_reverse_fields",
"maybe_queryset",
"get_model_fields",
"camelize",
"is_valid_django_model",
"GraphQLTestCase",
]

View File

@ -1,6 +0,0 @@
import re
from text_unidecode import unidecode
def to_const(string):
return re.sub(r"[\W|^]+", "_", unidecode(string)).upper()

View File

@ -1,153 +0,0 @@
import json
import warnings
from django.test import Client, TestCase, TransactionTestCase
DEFAULT_GRAPHQL_URL = "/graphql/"
def graphql_query(
query,
operation_name=None,
input_data=None,
variables=None,
headers=None,
client=None,
graphql_url=None,
):
"""
Args:
query (string) - GraphQL query to run
operation_name (string) - If the query is a mutation or named query, you must
supply the op_name. For annon queries ("{ ... }"),
should be None (default).
input_data (dict) - If provided, the $input variable in GraphQL will be set
to this value. If both ``input_data`` and ``variables``,
are provided, the ``input`` field in the ``variables``
dict will be overwritten with this value.
variables (dict) - If provided, the "variables" field in GraphQL will be
set to this value.
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
will be set to this value. Keys should be prepended with
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
use "HTTP_AUTHORIZATION" as the key).
client (django.test.Client) - Test client. Defaults to django.test.Client.
graphql_url (string) - URL to graphql endpoint. Defaults to "/graphql".
Returns:
Response object from client
"""
if client is None:
client = Client()
if not graphql_url:
graphql_url = DEFAULT_GRAPHQL_URL
body = {"query": query}
if operation_name:
body["operationName"] = operation_name
if variables:
body["variables"] = variables
if input_data:
if "variables" in body:
body["variables"]["input"] = input_data
else:
body["variables"] = {"input": input_data}
if headers:
resp = client.post(
graphql_url, json.dumps(body), content_type="application/json", **headers
)
else:
resp = client.post(
graphql_url, json.dumps(body), content_type="application/json"
)
return resp
class GraphQLTestMixin(object):
"""
Based on: https://www.sam.today/blog/testing-graphql-with-graphene-django/
"""
# URL to graphql endpoint
GRAPHQL_URL = DEFAULT_GRAPHQL_URL
def query(
self, query, operation_name=None, input_data=None, variables=None, headers=None
):
"""
Args:
query (string) - GraphQL query to run
operation_name (string) - If the query is a mutation or named query, you must
supply the op_name. For annon queries ("{ ... }"),
should be None (default).
input_data (dict) - If provided, the $input variable in GraphQL will be set
to this value. If both ``input_data`` and ``variables``,
are provided, the ``input`` field in the ``variables``
dict will be overwritten with this value.
variables (dict) - If provided, the "variables" field in GraphQL will be
set to this value.
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
will be set to this value. Keys should be prepended with
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
use "HTTP_AUTHORIZATION" as the key).
Returns:
Response object from client
"""
return graphql_query(
query,
operation_name=operation_name,
input_data=input_data,
variables=variables,
headers=headers,
client=self.client,
graphql_url=self.GRAPHQL_URL,
)
@property
def _client(self):
pass
@_client.getter
def _client(self):
warnings.warn(
"Using `_client` is deprecated in favour of `client`.",
PendingDeprecationWarning,
stacklevel=2,
)
return self.client
@_client.setter
def _client(self, client):
warnings.warn(
"Using `_client` is deprecated in favour of `client`.",
PendingDeprecationWarning,
stacklevel=2,
)
self.client = client
def assertResponseNoErrors(self, resp, msg=None):
"""
Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`,
the call was fine.
:resp HttpResponse: Response
"""
content = json.loads(resp.content)
self.assertEqual(resp.status_code, 200, msg or content)
self.assertNotIn("errors", list(content.keys()), msg or content)
def assertResponseHasErrors(self, resp, msg=None):
"""
Assert that the call was failing. Take care: Even with errors, GraphQL returns status 200!
:resp HttpResponse: Response
"""
content = json.loads(resp.content)
self.assertIn("errors", list(content.keys()), msg or content)
class GraphQLTestCase(GraphQLTestMixin, TestCase):
pass
class GraphQLTransactionTestCase(GraphQLTestMixin, TransactionTestCase):
pass

View File

@ -1,9 +0,0 @@
from ..str_converters import to_const
def test_to_const():
assert to_const('snakes $1. on a "#plane') == "SNAKES_1_ON_A_PLANE"
def test_to_const_unicode():
assert to_const(u"Skoða þetta unicode stöff") == "SKODA_THETTA_UNICODE_STOFF"

View File

@ -1,45 +0,0 @@
import pytest
from .. import GraphQLTestCase
from ...tests.test_types import with_local_registry
from django.test import Client
@with_local_registry
def test_graphql_test_case_deprecated_client_getter():
"""
`GraphQLTestCase._client`' getter should raise pending deprecation warning.
"""
class TestClass(GraphQLTestCase):
GRAPHQL_SCHEMA = True
def runTest(self):
pass
tc = TestClass()
tc._pre_setup()
tc.setUpClass()
with pytest.warns(PendingDeprecationWarning):
tc._client
@with_local_registry
def test_graphql_test_case_deprecated_client_setter():
"""
`GraphQLTestCase._client`' setter should raise pending deprecation warning.
"""
class TestClass(GraphQLTestCase):
GRAPHQL_SCHEMA = True
def runTest(self):
pass
tc = TestClass()
tc._pre_setup()
tc.setUpClass()
with pytest.warns(PendingDeprecationWarning):
tc._client = Client()

View File

@ -1,107 +0,0 @@
import inspect
from django.db import connection, models, transaction
from django.db.models.manager import Manager
from django.utils.encoding import force_str
from django.utils.functional import Promise
from graphene.utils.str_converters import to_camel_case
try:
import django_filters # noqa
DJANGO_FILTER_INSTALLED = True
except ImportError:
DJANGO_FILTER_INSTALLED = False
def isiterable(value):
try:
iter(value)
except TypeError:
return False
return True
def _camelize_django_str(s):
if isinstance(s, Promise):
s = force_str(s)
return to_camel_case(s) if isinstance(s, str) else s
def camelize(data):
if isinstance(data, dict):
return {_camelize_django_str(k): camelize(v) for k, v in data.items()}
if isiterable(data) and not isinstance(data, (str, Promise)):
return [camelize(d) for d in data]
return data
def get_reverse_fields(model, local_field_names):
for name, attr in model.__dict__.items():
# Don't duplicate any local fields
if name in local_field_names:
continue
# "rel" for FK and M2M relations and "related" for O2O Relations
related = getattr(attr, "rel", None) or getattr(attr, "related", None)
if isinstance(related, models.ManyToOneRel):
yield (name, related)
elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
yield (name, related)
def maybe_queryset(value):
if isinstance(value, Manager):
value = value.get_queryset()
return value
def get_model_fields(model):
local_fields = [
(field.name, field)
for field in sorted(
list(model._meta.fields) + list(model._meta.local_many_to_many)
)
]
# Make sure we don't duplicate local fields with "reverse" version
local_field_names = [field[0] for field in local_fields]
reverse_fields = get_reverse_fields(model, local_field_names)
all_fields = local_fields + list(reverse_fields)
return all_fields
def is_valid_django_model(model):
return inspect.isclass(model) and issubclass(model, models.Model)
def import_single_dispatch():
try:
from functools import singledispatch
except ImportError:
singledispatch = None
if not singledispatch:
try:
from singledispatch import singledispatch
except ImportError:
pass
if not singledispatch:
raise Exception(
"It seems your python version does not include "
"functools.singledispatch. Please install the 'singledispatch' "
"package. More information here: "
"https://pypi.python.org/pypi/singledispatch"
)
return singledispatch
def set_rollback():
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)

View File

@ -1,398 +0,0 @@
import inspect
import json
import re
from django.db import connection, transaction
from django.http import HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBadRequest
from django.shortcuts import render
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import View
from graphql import OperationType, get_operation_ast, parse, validate
from graphql.error import GraphQLError
from graphql.error import format_error as format_graphql_error
from graphql.execution import ExecutionResult
from graphene import Schema
from graphql.execution.middleware import MiddlewareManager
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.utils.utils import set_rollback
from .settings import graphene_settings
class HttpError(Exception):
def __init__(self, response, message=None, *args, **kwargs):
self.response = response
self.message = message = message or response.content.decode()
super(HttpError, self).__init__(message, *args, **kwargs)
def get_accepted_content_types(request):
def qualify(x):
parts = x.split(";", 1)
if len(parts) == 2:
match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
if match:
return parts[0].strip(), float(match.group(2))
return parts[0].strip(), 1
raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
qualified_content_types = map(qualify, raw_content_types)
return list(
x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
)
def instantiate_middleware(middlewares):
for middleware in middlewares:
if inspect.isclass(middleware):
yield middleware()
continue
yield middleware
class GraphQLView(View):
graphiql_template = "graphene/graphiql.html"
# Polyfill for window.fetch.
whatwg_fetch_version = "3.6.2"
whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c="
# React and ReactDOM.
react_version = "17.0.2"
react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8="
react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0="
# The GraphiQL React app.
graphiql_version = "1.4.1" # "1.0.3"
graphiql_sri = "sha256-JUMkXBQWZMfJ7fGEsTXalxVA10lzKOS9loXdLjwZKi4=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
graphiql_css_sri = "sha256-Md3vdR7PDzWyo/aGfsFVF4tvS5/eAUWuIsg9QHUusCY=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="
# The websocket transport library for subscriptions.
subscriptions_transport_ws_version = "0.9.18"
subscriptions_transport_ws_sri = (
"sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw="
)
schema = None
graphiql = False
middleware = None
root_value = None
pretty = False
batch = False
subscription_path = None
execution_context_class = None
def __init__(
self,
schema=None,
middleware=None,
root_value=None,
graphiql=False,
pretty=False,
batch=False,
subscription_path=None,
execution_context_class=None,
):
if not schema:
schema = graphene_settings.SCHEMA
if middleware is None:
middleware = graphene_settings.MIDDLEWARE
self.schema = self.schema or schema
if middleware is not None:
if isinstance(middleware, MiddlewareManager):
self.middleware = middleware
else:
self.middleware = list(instantiate_middleware(middleware))
self.root_value = root_value
self.pretty = self.pretty or pretty
self.graphiql = self.graphiql or graphiql
self.batch = self.batch or batch
self.execution_context_class = execution_context_class
if subscription_path is None:
self.subscription_path = graphene_settings.SUBSCRIPTION_PATH
assert isinstance(
self.schema, Schema
), "A Schema is required to be provided to GraphQLView."
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
# noinspection PyUnusedLocal
def get_root_value(self, request):
return self.root_value
def get_middleware(self, request):
return self.middleware
def get_context(self, request):
return request
@method_decorator(ensure_csrf_cookie)
def dispatch(self, request, *args, **kwargs):
try:
if request.method.lower() not in ("get", "post"):
raise HttpError(
HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
if show_graphiql:
return self.render_graphiql(
request,
# Dependency parameters.
whatwg_fetch_version=self.whatwg_fetch_version,
whatwg_fetch_sri=self.whatwg_fetch_sri,
react_version=self.react_version,
react_sri=self.react_sri,
react_dom_sri=self.react_dom_sri,
graphiql_version=self.graphiql_version,
graphiql_sri=self.graphiql_sri,
graphiql_css_sri=self.graphiql_css_sri,
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
# The SUBSCRIPTION_PATH setting.
subscription_path=self.subscription_path,
# GraphiQL headers tab,
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
)
if self.batch:
responses = [self.get_response(request, entry) for entry in data]
result = "[{}]".format(
",".join([response[0] for response in responses])
)
status_code = (
responses
and max(responses, key=lambda response: response[1])[1]
or 200
)
else:
result, status_code = self.get_response(request, data, show_graphiql)
return HttpResponse(
status=status_code, content=result, content_type="application/json"
)
except HttpError as e:
response = e.response
response["Content-Type"] = "application/json"
response.content = self.json_encode(
request, {"errors": [self.format_error(e)]}
)
return response
def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(request, data)
execution_result = self.execute_graphql_request(
request, data, query, variables, operation_name, show_graphiql
)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
set_rollback()
status_code = 200
if execution_result:
response = {}
if execution_result.errors:
set_rollback()
response["errors"] = [
self.format_error(e) for e in execution_result.errors
]
if execution_result.errors and any(
not getattr(e, "path", None) for e in execution_result.errors
):
status_code = 400
else:
response["data"] = execution_result.data
if self.batch:
response["id"] = id
response["status"] = status_code
result = self.json_encode(request, response, pretty=show_graphiql)
else:
result = None
return result, status_code
def render_graphiql(self, request, **data):
return render(request, self.graphiql_template, data)
def json_encode(self, request, d, pretty=False):
if not (self.pretty or pretty) and not request.GET.get("pretty"):
return json.dumps(d, separators=(",", ":"))
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
def parse_body(self, request):
content_type = self.get_content_type(request)
if content_type == "application/graphql":
return {"query": request.body.decode()}
elif content_type == "application/json":
# noinspection PyBroadException
try:
body = request.body.decode("utf-8")
except Exception as e:
raise HttpError(HttpResponseBadRequest(str(e)))
try:
request_json = json.loads(body)
if self.batch:
assert isinstance(request_json, list), (
"Batch requests should receive a list, but received {}."
).format(repr(request_json))
assert (
len(request_json) > 0
), "Received an empty list in the batch request."
else:
assert isinstance(
request_json, dict
), "The received data is not a valid JSON query."
return request_json
except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
elif content_type in [
"application/x-www-form-urlencoded",
"multipart/form-data",
]:
return request.POST
return {}
def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False
):
if not query:
if show_graphiql:
return None
raise HttpError(HttpResponseBadRequest("Must provide query string."))
try:
document = parse(query)
except Exception as e:
return ExecutionResult(errors=[e])
if request.method.lower() == "get":
operation_ast = get_operation_ast(document, operation_name)
if operation_ast and operation_ast.operation != OperationType.QUERY:
if show_graphiql:
return None
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_ast.operation.value
),
)
)
validation_errors = validate(self.schema.graphql_schema, document)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
try:
extra_options = {}
if self.execution_context_class:
extra_options["execution_context_class"] = self.execution_context_class
options = {
"source": query,
"root_value": self.get_root_value(request),
"variable_values": variables,
"operation_name": operation_name,
"context_value": self.get_context(request),
"middleware": self.get_middleware(request),
}
options.update(extra_options)
operation_ast = get_operation_ast(document, operation_name)
if (
operation_ast
and operation_ast.operation == OperationType.MUTATION
and (
graphene_settings.ATOMIC_MUTATIONS is True
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
)
):
with transaction.atomic():
result = self.schema.execute(**options)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
transaction.set_rollback(True)
return result
return self.schema.execute(**options)
except Exception as e:
return ExecutionResult(errors=[e])
@classmethod
def can_display_graphiql(cls, request, data):
raw = "raw" in request.GET or "raw" in data
return not raw and cls.request_wants_html(request)
@classmethod
def request_wants_html(cls, request):
accepted = get_accepted_content_types(request)
accepted_length = len(accepted)
# the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number
html_priority = (
accepted_length - accepted.index("text/html")
if "text/html" in accepted
else 0
)
json_priority = (
accepted_length - accepted.index("application/json")
if "application/json" in accepted
else 0
)
return html_priority > json_priority
@staticmethod
def get_graphql_params(request, data):
query = request.GET.get("query") or data.get("query")
variables = request.GET.get("variables") or data.get("variables")
id = request.GET.get("id") or data.get("id")
if variables and isinstance(variables, str):
try:
variables = json.loads(variables)
except Exception:
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
operation_name = request.GET.get("operationName") or data.get("operationName")
if operation_name == "null":
operation_name = None
return query, variables, operation_name, id
@staticmethod
def format_error(error):
if isinstance(error, GraphQLError):
return format_graphql_error(error)
return {"message": str(error)}
@staticmethod
def get_content_type(request):
meta = request.META
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower()