mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 09:57:53 +03:00 
			
		
		
		
	Reformatted files using black
This commit is contained in:
		
							parent
							
								
									96789b291f
								
							
						
					
					
						commit
						54ef52e1c6
					
				| 
						 | 
				
			
			@ -1,14 +1,6 @@
 | 
			
		|||
from .types import (
 | 
			
		||||
    DjangoObjectType,
 | 
			
		||||
)
 | 
			
		||||
from .fields import (
 | 
			
		||||
    DjangoConnectionField,
 | 
			
		||||
)
 | 
			
		||||
from .types import DjangoObjectType
 | 
			
		||||
from .fields import DjangoConnectionField
 | 
			
		||||
 | 
			
		||||
__version__ = '2.1rc1'
 | 
			
		||||
__version__ = "2.1rc1"
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    '__version__',
 | 
			
		||||
    'DjangoObjectType',
 | 
			
		||||
    'DjangoConnectionField'
 | 
			
		||||
]
 | 
			
		||||
__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,7 +7,7 @@ try:
 | 
			
		|||
    # and we cannot have psycopg2 on PyPy
 | 
			
		||||
    from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
 | 
			
		||||
except ImportError:
 | 
			
		||||
    ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
 | 
			
		||||
    ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,8 +1,22 @@
 | 
			
		|||
from django.db import models
 | 
			
		||||
from django.utils.encoding import force_text
 | 
			
		||||
 | 
			
		||||
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
 | 
			
		||||
                      NonNull, String, UUID, DateTime, Date, Time)
 | 
			
		||||
from graphene import (
 | 
			
		||||
    ID,
 | 
			
		||||
    Boolean,
 | 
			
		||||
    Dynamic,
 | 
			
		||||
    Enum,
 | 
			
		||||
    Field,
 | 
			
		||||
    Float,
 | 
			
		||||
    Int,
 | 
			
		||||
    List,
 | 
			
		||||
    NonNull,
 | 
			
		||||
    String,
 | 
			
		||||
    UUID,
 | 
			
		||||
    DateTime,
 | 
			
		||||
    Date,
 | 
			
		||||
    Time,
 | 
			
		||||
)
 | 
			
		||||
from graphene.types.json import JSONString
 | 
			
		||||
from graphene.utils.str_converters import to_camel_case, to_const
 | 
			
		||||
from graphql import assert_valid_name
 | 
			
		||||
| 
						 | 
				
			
			@ -32,7 +46,7 @@ def get_choices(choices):
 | 
			
		|||
        else:
 | 
			
		||||
            name = convert_choice_name(value)
 | 
			
		||||
            while name in converted_names:
 | 
			
		||||
                name += '_' + str(len(converted_names))
 | 
			
		||||
                name += "_" + str(len(converted_names))
 | 
			
		||||
            converted_names.append(name)
 | 
			
		||||
            description = help_text
 | 
			
		||||
            yield name, value, description
 | 
			
		||||
| 
						 | 
				
			
			@ -43,16 +57,15 @@ def convert_django_field_with_choices(field, registry=None):
 | 
			
		|||
        converted = registry.get_converted_field(field)
 | 
			
		||||
        if converted:
 | 
			
		||||
            return converted
 | 
			
		||||
    choices = getattr(field, 'choices', None)
 | 
			
		||||
    choices = getattr(field, "choices", None)
 | 
			
		||||
    if choices:
 | 
			
		||||
        meta = field.model._meta
 | 
			
		||||
        name = to_camel_case('{}_{}'.format(meta.object_name, field.name))
 | 
			
		||||
        name = to_camel_case("{}_{}".format(meta.object_name, field.name))
 | 
			
		||||
        choices = list(get_choices(choices))
 | 
			
		||||
        named_choices = [(c[0], c[1]) for c in choices]
 | 
			
		||||
        named_choices_descriptions = {c[0]: c[2] for c in choices}
 | 
			
		||||
 | 
			
		||||
        class EnumWithDescriptionsType(object):
 | 
			
		||||
 | 
			
		||||
            @property
 | 
			
		||||
            def description(self):
 | 
			
		||||
                return named_choices_descriptions[self.name]
 | 
			
		||||
| 
						 | 
				
			
			@ -69,8 +82,8 @@ def convert_django_field_with_choices(field, registry=None):
 | 
			
		|||
@singledispatch
 | 
			
		||||
def convert_django_field(field, registry=None):
 | 
			
		||||
    raise Exception(
 | 
			
		||||
        "Don't know how to convert the Django field %s (%s)" %
 | 
			
		||||
        (field, field.__class__))
 | 
			
		||||
        "Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@convert_django_field.register(models.CharField)
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
 | 
			
		|||
 | 
			
		||||
        # We do this for a bug in Django 1.8, where null attr
 | 
			
		||||
        # is not available in the OneToOneRel instance
 | 
			
		||||
        null = getattr(field, 'null', True)
 | 
			
		||||
        null = getattr(field, "null", True)
 | 
			
		||||
        return Field(_type, required=not null)
 | 
			
		||||
 | 
			
		||||
    return Dynamic(dynamic_type)
 | 
			
		||||
| 
						 | 
				
			
			@ -171,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
 | 
			
		|||
            # defined filter_fields in the DjangoObjectType Meta
 | 
			
		||||
            if _type._meta.filter_fields:
 | 
			
		||||
                from .filter.fields import DjangoFilterConnectionField
 | 
			
		||||
 | 
			
		||||
                return DjangoFilterConnectionField(_type)
 | 
			
		||||
 | 
			
		||||
            return DjangoConnectionField(_type)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,4 @@
 | 
			
		|||
from .middleware import DjangoDebugMiddleware
 | 
			
		||||
from .types import DjangoDebug
 | 
			
		||||
 | 
			
		||||
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
 | 
			
		||||
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,7 +7,6 @@ from .types import DjangoDebug
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugContext(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.debug_promise = None
 | 
			
		||||
        self.promises = []
 | 
			
		||||
| 
						 | 
				
			
			@ -38,20 +37,21 @@ class DjangoDebugContext(object):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugMiddleware(object):
 | 
			
		||||
 | 
			
		||||
    def resolve(self, next, root, info, **args):
 | 
			
		||||
        context = info.context
 | 
			
		||||
        django_debug = getattr(context, 'django_debug', None)
 | 
			
		||||
        django_debug = getattr(context, "django_debug", None)
 | 
			
		||||
        if not django_debug:
 | 
			
		||||
            if context is None:
 | 
			
		||||
                raise Exception('DjangoDebug cannot be executed in None contexts')
 | 
			
		||||
                raise Exception("DjangoDebug cannot be executed in None contexts")
 | 
			
		||||
            try:
 | 
			
		||||
                context.django_debug = DjangoDebugContext()
 | 
			
		||||
            except Exception:
 | 
			
		||||
                raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
 | 
			
		||||
                    context.__class__.__name__
 | 
			
		||||
                ))
 | 
			
		||||
        if info.schema.get_type('DjangoDebug') == info.return_type:
 | 
			
		||||
                raise Exception(
 | 
			
		||||
                    "DjangoDebug need the context to be writable, context received: {}.".format(
 | 
			
		||||
                        context.__class__.__name__
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
        if info.schema.get_type("DjangoDebug") == info.return_type:
 | 
			
		||||
            return context.django_debug.get_debug_promise()
 | 
			
		||||
        promise = next(root, info, **args)
 | 
			
		||||
        context.django_debug.add_promise(promise)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ThreadLocalState(local):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.enabled = True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -35,7 +34,7 @@ recording = state.recording  # export function
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def wrap_cursor(connection, panel):
 | 
			
		||||
    if not hasattr(connection, '_graphene_cursor'):
 | 
			
		||||
    if not hasattr(connection, "_graphene_cursor"):
 | 
			
		||||
        connection._graphene_cursor = connection.cursor
 | 
			
		||||
 | 
			
		||||
        def cursor():
 | 
			
		||||
| 
						 | 
				
			
			@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def unwrap_cursor(connection):
 | 
			
		||||
    if hasattr(connection, '_graphene_cursor'):
 | 
			
		||||
    if hasattr(connection, "_graphene_cursor"):
 | 
			
		||||
        previous_cursor = connection._graphene_cursor
 | 
			
		||||
        connection.cursor = previous_cursor
 | 
			
		||||
        del connection._graphene_cursor
 | 
			
		||||
| 
						 | 
				
			
			@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
 | 
			
		|||
        if not params:
 | 
			
		||||
            return params
 | 
			
		||||
        if isinstance(params, dict):
 | 
			
		||||
            return dict((key, self._quote_expr(value))
 | 
			
		||||
                        for key, value in params.items())
 | 
			
		||||
            return dict((key, self._quote_expr(value)) for key, value in params.items())
 | 
			
		||||
        return list(map(self._quote_expr, params))
 | 
			
		||||
 | 
			
		||||
    def _decode(self, param):
 | 
			
		||||
        try:
 | 
			
		||||
            return force_text(param, strings_only=True)
 | 
			
		||||
        except UnicodeDecodeError:
 | 
			
		||||
            return '(encoded string)'
 | 
			
		||||
            return "(encoded string)"
 | 
			
		||||
 | 
			
		||||
    def _record(self, method, sql, params):
 | 
			
		||||
        start_time = time()
 | 
			
		||||
| 
						 | 
				
			
			@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
 | 
			
		|||
            return method(sql, params)
 | 
			
		||||
        finally:
 | 
			
		||||
            stop_time = time()
 | 
			
		||||
            duration = (stop_time - start_time)
 | 
			
		||||
            _params = ''
 | 
			
		||||
            duration = stop_time - start_time
 | 
			
		||||
            _params = ""
 | 
			
		||||
            try:
 | 
			
		||||
                _params = json.dumps(list(map(self._decode, params)))
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass  # object not JSON serializable
 | 
			
		||||
 | 
			
		||||
            alias = getattr(self.db, 'alias', 'default')
 | 
			
		||||
            alias = getattr(self.db, "alias", "default")
 | 
			
		||||
            conn = self.db.connection
 | 
			
		||||
            vendor = getattr(conn, 'vendor', 'unknown')
 | 
			
		||||
            vendor = getattr(conn, "vendor", "unknown")
 | 
			
		||||
 | 
			
		||||
            params = {
 | 
			
		||||
                'vendor': vendor,
 | 
			
		||||
                'alias': alias,
 | 
			
		||||
                'sql': self.db.ops.last_executed_query(
 | 
			
		||||
                    self.cursor, sql, self._quote_params(params)),
 | 
			
		||||
                'duration': duration,
 | 
			
		||||
                'raw_sql': sql,
 | 
			
		||||
                'params': _params,
 | 
			
		||||
                'start_time': start_time,
 | 
			
		||||
                'stop_time': stop_time,
 | 
			
		||||
                'is_slow': duration > 10,
 | 
			
		||||
                'is_select': sql.lower().strip().startswith('select'),
 | 
			
		||||
                "vendor": vendor,
 | 
			
		||||
                "alias": alias,
 | 
			
		||||
                "sql": self.db.ops.last_executed_query(
 | 
			
		||||
                    self.cursor, sql, self._quote_params(params)
 | 
			
		||||
                ),
 | 
			
		||||
                "duration": duration,
 | 
			
		||||
                "raw_sql": sql,
 | 
			
		||||
                "params": _params,
 | 
			
		||||
                "start_time": start_time,
 | 
			
		||||
                "stop_time": stop_time,
 | 
			
		||||
                "is_slow": duration > 10,
 | 
			
		||||
                "is_select": sql.lower().strip().startswith("select"),
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if vendor == 'postgresql':
 | 
			
		||||
            if vendor == "postgresql":
 | 
			
		||||
                # If an erroneous query was ran on the connection, it might
 | 
			
		||||
                # be in a state where checking isolation_level raises an
 | 
			
		||||
                # exception.
 | 
			
		||||
                try:
 | 
			
		||||
                    iso_level = conn.isolation_level
 | 
			
		||||
                except conn.InternalError:
 | 
			
		||||
                    iso_level = 'unknown'
 | 
			
		||||
                params.update({
 | 
			
		||||
                    'trans_id': self.logger.get_transaction_id(alias),
 | 
			
		||||
                    'trans_status': conn.get_transaction_status(),
 | 
			
		||||
                    'iso_level': iso_level,
 | 
			
		||||
                    'encoding': conn.encoding,
 | 
			
		||||
                })
 | 
			
		||||
                    iso_level = "unknown"
 | 
			
		||||
                params.update(
 | 
			
		||||
                    {
 | 
			
		||||
                        "trans_id": self.logger.get_transaction_id(alias),
 | 
			
		||||
                        "trans_status": conn.get_transaction_status(),
 | 
			
		||||
                        "iso_level": iso_level,
 | 
			
		||||
                        "encoding": conn.encoding,
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            _sql = DjangoDebugSQL(**params)
 | 
			
		||||
            # We keep `sql` to maintain backwards compatibility
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,31 +12,31 @@ from ..types import DjangoDebug
 | 
			
		|||
class context(object):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# from examples.starwars_django.models import Character
 | 
			
		||||
 | 
			
		||||
pytestmark = pytest.mark.django_db
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_query_field():
 | 
			
		||||
    r1 = Reporter(last_name='ABA')
 | 
			
		||||
    r1 = Reporter(last_name="ABA")
 | 
			
		||||
    r1.save()
 | 
			
		||||
    r2 = Reporter(last_name='Griffin')
 | 
			
		||||
    r2 = Reporter(last_name="Griffin")
 | 
			
		||||
    r2.save()
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        reporter = graphene.Field(ReporterType)
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name="__debug")
 | 
			
		||||
 | 
			
		||||
        def resolve_reporter(self, info, **args):
 | 
			
		||||
            return Reporter.objects.first()
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query ReporterQuery {
 | 
			
		||||
          reporter {
 | 
			
		||||
            lastName
 | 
			
		||||
| 
						 | 
				
			
			@ -47,43 +47,40 @@ def test_should_query_field():
 | 
			
		|||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
    expected = {
 | 
			
		||||
        'reporter': {
 | 
			
		||||
            'lastName': 'ABA',
 | 
			
		||||
        "reporter": {"lastName": "ABA"},
 | 
			
		||||
        "__debug": {
 | 
			
		||||
            "sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
 | 
			
		||||
        },
 | 
			
		||||
        '__debug': {
 | 
			
		||||
            'sql': [{
 | 
			
		||||
                'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
 | 
			
		||||
            }]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    schema = graphene.Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(
 | 
			
		||||
        query, context_value=context(), middleware=[DjangoDebugMiddleware()]
 | 
			
		||||
    )
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_query_list():
 | 
			
		||||
    r1 = Reporter(last_name='ABA')
 | 
			
		||||
    r1 = Reporter(last_name="ABA")
 | 
			
		||||
    r1.save()
 | 
			
		||||
    r2 = Reporter(last_name='Griffin')
 | 
			
		||||
    r2 = Reporter(last_name="Griffin")
 | 
			
		||||
    r2.save()
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = graphene.List(ReporterType)
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name="__debug")
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.all()
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query ReporterQuery {
 | 
			
		||||
          allReporters {
 | 
			
		||||
            lastName
 | 
			
		||||
| 
						 | 
				
			
			@ -94,45 +91,38 @@ def test_should_query_list():
 | 
			
		|||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allReporters': [{
 | 
			
		||||
            'lastName': 'ABA',
 | 
			
		||||
        }, {
 | 
			
		||||
            'lastName': 'Griffin',
 | 
			
		||||
        }],
 | 
			
		||||
        '__debug': {
 | 
			
		||||
            'sql': [{
 | 
			
		||||
                'rawSql': str(Reporter.objects.all().query)
 | 
			
		||||
            }]
 | 
			
		||||
        }
 | 
			
		||||
        "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
 | 
			
		||||
        "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
 | 
			
		||||
    }
 | 
			
		||||
    schema = graphene.Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(
 | 
			
		||||
        query, context_value=context(), middleware=[DjangoDebugMiddleware()]
 | 
			
		||||
    )
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_query_connection():
 | 
			
		||||
    r1 = Reporter(last_name='ABA')
 | 
			
		||||
    r1 = Reporter(last_name="ABA")
 | 
			
		||||
    r1.save()
 | 
			
		||||
    r2 = Reporter(last_name='Griffin')
 | 
			
		||||
    r2 = Reporter(last_name="Griffin")
 | 
			
		||||
    r2.save()
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = DjangoConnectionField(ReporterType)
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name="__debug")
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.all()
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query ReporterQuery {
 | 
			
		||||
          allReporters(first:1) {
 | 
			
		||||
            edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -147,48 +137,41 @@ def test_should_query_connection():
 | 
			
		|||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allReporters': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'lastName': 'ABA',
 | 
			
		||||
                }
 | 
			
		||||
            }]
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
 | 
			
		||||
    schema = graphene.Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(
 | 
			
		||||
        query, context_value=context(), middleware=[DjangoDebugMiddleware()]
 | 
			
		||||
    )
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data['allReporters'] == expected['allReporters']
 | 
			
		||||
    assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
 | 
			
		||||
    assert result.data["allReporters"] == expected["allReporters"]
 | 
			
		||||
    assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
 | 
			
		||||
    query = str(Reporter.objects.all()[:1].query)
 | 
			
		||||
    assert result.data['__debug']['sql'][1]['rawSql'] == query
 | 
			
		||||
    assert result.data["__debug"]["sql"][1]["rawSql"] == query
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_query_connectionfilter():
 | 
			
		||||
    from ...filter import DjangoFilterConnectionField
 | 
			
		||||
 | 
			
		||||
    r1 = Reporter(last_name='ABA')
 | 
			
		||||
    r1 = Reporter(last_name="ABA")
 | 
			
		||||
    r1.save()
 | 
			
		||||
    r2 = Reporter(last_name='Griffin')
 | 
			
		||||
    r2 = Reporter(last_name="Griffin")
 | 
			
		||||
    r2.save()
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
 | 
			
		||||
        s = graphene.String(resolver=lambda *_: "S")
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name="__debug")
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.all()
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query ReporterQuery {
 | 
			
		||||
          allReporters(first:1) {
 | 
			
		||||
            edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
 | 
			
		|||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allReporters': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'lastName': 'ABA',
 | 
			
		||||
                }
 | 
			
		||||
            }]
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
 | 
			
		||||
    schema = graphene.Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(
 | 
			
		||||
        query, context_value=context(), middleware=[DjangoDebugMiddleware()]
 | 
			
		||||
    )
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data['allReporters'] == expected['allReporters']
 | 
			
		||||
    assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
 | 
			
		||||
    assert result.data["allReporters"] == expected["allReporters"]
 | 
			
		||||
    assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
 | 
			
		||||
    query = str(Reporter.objects.all()[:1].query)
 | 
			
		||||
    assert result.data['__debug']['sql'][1]['rawSql'] == query
 | 
			
		||||
    assert result.data["__debug"]["sql"][1]["rawSql"] == query
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,6 @@ from .utils import maybe_queryset
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class DjangoListField(Field):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, _type, *args, **kwargs):
 | 
			
		||||
        super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -30,25 +29,28 @@ class DjangoListField(Field):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class DjangoConnectionField(ConnectionField):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        self.on = kwargs.pop('on', False)
 | 
			
		||||
        self.on = kwargs.pop("on", False)
 | 
			
		||||
        self.max_limit = kwargs.pop(
 | 
			
		||||
            'max_limit',
 | 
			
		||||
            graphene_settings.RELAY_CONNECTION_MAX_LIMIT
 | 
			
		||||
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
 | 
			
		||||
        )
 | 
			
		||||
        self.enforce_first_or_last = kwargs.pop(
 | 
			
		||||
            'enforce_first_or_last',
 | 
			
		||||
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
 | 
			
		||||
            "enforce_first_or_last",
 | 
			
		||||
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
 | 
			
		||||
        )
 | 
			
		||||
        super(DjangoConnectionField, self).__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def type(self):
 | 
			
		||||
        from .types import DjangoObjectType
 | 
			
		||||
 | 
			
		||||
        _type = super(ConnectionField, self).type
 | 
			
		||||
        assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
 | 
			
		||||
        assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
 | 
			
		||||
        assert issubclass(
 | 
			
		||||
            _type, DjangoObjectType
 | 
			
		||||
        ), "DjangoConnectionField only accepts DjangoObjectType types"
 | 
			
		||||
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
 | 
			
		||||
            _type.__name__
 | 
			
		||||
        )
 | 
			
		||||
        return _type._meta.connection
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
| 
						 | 
				
			
			@ -100,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
 | 
			
		|||
        return connection
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def connection_resolver(cls, resolver, connection, default_manager, max_limit,
 | 
			
		||||
                            enforce_first_or_last, root, info, **args):
 | 
			
		||||
        first = args.get('first')
 | 
			
		||||
        last = args.get('last')
 | 
			
		||||
    def connection_resolver(
 | 
			
		||||
        cls,
 | 
			
		||||
        resolver,
 | 
			
		||||
        connection,
 | 
			
		||||
        default_manager,
 | 
			
		||||
        max_limit,
 | 
			
		||||
        enforce_first_or_last,
 | 
			
		||||
        root,
 | 
			
		||||
        info,
 | 
			
		||||
        **args
 | 
			
		||||
    ):
 | 
			
		||||
        first = args.get("first")
 | 
			
		||||
        last = args.get("last")
 | 
			
		||||
 | 
			
		||||
        if enforce_first_or_last:
 | 
			
		||||
            assert first or last, (
 | 
			
		||||
                'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
 | 
			
		||||
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
 | 
			
		||||
            ).format(info.field_name)
 | 
			
		||||
 | 
			
		||||
        if max_limit:
 | 
			
		||||
            if first:
 | 
			
		||||
                assert first <= max_limit, (
 | 
			
		||||
                    'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
 | 
			
		||||
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
 | 
			
		||||
                ).format(first, info.field_name, max_limit)
 | 
			
		||||
                args['first'] = min(first, max_limit)
 | 
			
		||||
                args["first"] = min(first, max_limit)
 | 
			
		||||
 | 
			
		||||
            if last:
 | 
			
		||||
                assert last <= max_limit, (
 | 
			
		||||
                    'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
 | 
			
		||||
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
 | 
			
		||||
                ).format(last, info.field_name, max_limit)
 | 
			
		||||
                args['last'] = min(last, max_limit)
 | 
			
		||||
                args["last"] = min(last, max_limit)
 | 
			
		||||
 | 
			
		||||
        iterable = resolver(root, info, **args)
 | 
			
		||||
        on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
 | 
			
		||||
| 
						 | 
				
			
			@ -138,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
 | 
			
		|||
            self.type,
 | 
			
		||||
            self.get_manager(),
 | 
			
		||||
            self.max_limit,
 | 
			
		||||
            self.enforce_first_or_last
 | 
			
		||||
            self.enforce_first_or_last,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
 | 
			
		|||
if not DJANGO_FILTER_INSTALLED:
 | 
			
		||||
    warnings.warn(
 | 
			
		||||
        "Use of django filtering requires the django-filter package "
 | 
			
		||||
        "be installed. You can do so using `pip install django-filter`", ImportWarning
 | 
			
		||||
        "be installed. You can do so using `pip install django-filter`",
 | 
			
		||||
        ImportWarning,
 | 
			
		||||
    )
 | 
			
		||||
else:
 | 
			
		||||
    from .fields import DjangoFilterConnectionField
 | 
			
		||||
    from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
			
		||||
 | 
			
		||||
    __all__ = ['DjangoFilterConnectionField',
 | 
			
		||||
               'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter']
 | 
			
		||||
    __all__ = [
 | 
			
		||||
        "DjangoFilterConnectionField",
 | 
			
		||||
        "GlobalIDFilter",
 | 
			
		||||
        "GlobalIDMultipleChoiceFilter",
 | 
			
		||||
    ]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, type, fields=None, order_by=None,
 | 
			
		||||
                 extra_filter_meta=None, filterset_class=None,
 | 
			
		||||
                 *args, **kwargs):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        type,
 | 
			
		||||
        fields=None,
 | 
			
		||||
        order_by=None,
 | 
			
		||||
        extra_filter_meta=None,
 | 
			
		||||
        filterset_class=None,
 | 
			
		||||
        *args,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        self._fields = fields
 | 
			
		||||
        self._provided_filterset_class = filterset_class
 | 
			
		||||
        self._filterset_class = None
 | 
			
		||||
| 
						 | 
				
			
			@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		|||
    def filterset_class(self):
 | 
			
		||||
        if not self._filterset_class:
 | 
			
		||||
            fields = self._fields or self.node_type._meta.filter_fields
 | 
			
		||||
            meta = dict(model=self.model,
 | 
			
		||||
                        fields=fields)
 | 
			
		||||
            meta = dict(model=self.model, fields=fields)
 | 
			
		||||
            if self._extra_filter_meta:
 | 
			
		||||
                meta.update(self._extra_filter_meta)
 | 
			
		||||
 | 
			
		||||
            self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta)
 | 
			
		||||
            self._filterset_class = get_filterset_class(
 | 
			
		||||
                self._provided_filterset_class, **meta
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return self._filterset_class
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		|||
 | 
			
		||||
        # See related PR: https://github.com/graphql-python/graphene-django/pull/126
 | 
			
		||||
 | 
			
		||||
        assert not (default_queryset.query.low_mark and queryset.query.low_mark), (
 | 
			
		||||
            'Received two sliced querysets (low mark) in the connection, please slice only in one.'
 | 
			
		||||
        )
 | 
			
		||||
        assert not (default_queryset.query.high_mark and queryset.query.high_mark), (
 | 
			
		||||
            'Received two sliced querysets (high mark) in the connection, please slice only in one.'
 | 
			
		||||
        )
 | 
			
		||||
        assert not (
 | 
			
		||||
            default_queryset.query.low_mark and queryset.query.low_mark
 | 
			
		||||
        ), "Received two sliced querysets (low mark) in the connection, please slice only in one."
 | 
			
		||||
        assert not (
 | 
			
		||||
            default_queryset.query.high_mark and queryset.query.high_mark
 | 
			
		||||
        ), "Received two sliced querysets (high mark) in the connection, please slice only in one."
 | 
			
		||||
        low = default_queryset.query.low_mark or queryset.query.low_mark
 | 
			
		||||
        high = default_queryset.query.high_mark or queryset.query.high_mark
 | 
			
		||||
        default_queryset.query.clear_limits()
 | 
			
		||||
        queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset)
 | 
			
		||||
        queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
 | 
			
		||||
            default_queryset, queryset
 | 
			
		||||
        )
 | 
			
		||||
        queryset.query.set_limits(low, high)
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def connection_resolver(cls, resolver, connection, default_manager, max_limit,
 | 
			
		||||
                            enforce_first_or_last, filterset_class, filtering_args,
 | 
			
		||||
                            root, info, **args):
 | 
			
		||||
    def connection_resolver(
 | 
			
		||||
        cls,
 | 
			
		||||
        resolver,
 | 
			
		||||
        connection,
 | 
			
		||||
        default_manager,
 | 
			
		||||
        max_limit,
 | 
			
		||||
        enforce_first_or_last,
 | 
			
		||||
        filterset_class,
 | 
			
		||||
        filtering_args,
 | 
			
		||||
        root,
 | 
			
		||||
        info,
 | 
			
		||||
        **args
 | 
			
		||||
    ):
 | 
			
		||||
        filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
 | 
			
		||||
        qs = filterset_class(
 | 
			
		||||
            data=filter_kwargs,
 | 
			
		||||
            queryset=default_manager.get_queryset(),
 | 
			
		||||
            request=info.context
 | 
			
		||||
            request=info.context,
 | 
			
		||||
        ).qs
 | 
			
		||||
 | 
			
		||||
        return super(DjangoFilterConnectionField, cls).connection_resolver(
 | 
			
		||||
| 
						 | 
				
			
			@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		|||
            self.max_limit,
 | 
			
		||||
            self.enforce_first_or_last,
 | 
			
		||||
            self.filterset_class,
 | 
			
		||||
            self.filtering_args
 | 
			
		||||
            self.filtering_args,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,26 +28,19 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
GRAPHENE_FILTER_SET_OVERRIDES = {
 | 
			
		||||
    models.AutoField: {
 | 
			
		||||
        'filter_class': GlobalIDFilter,
 | 
			
		||||
    },
 | 
			
		||||
    models.OneToOneField: {
 | 
			
		||||
        'filter_class': GlobalIDFilter,
 | 
			
		||||
    },
 | 
			
		||||
    models.ForeignKey: {
 | 
			
		||||
        'filter_class': GlobalIDFilter,
 | 
			
		||||
    },
 | 
			
		||||
    models.ManyToManyField: {
 | 
			
		||||
        'filter_class': GlobalIDMultipleChoiceFilter,
 | 
			
		||||
    }
 | 
			
		||||
    models.AutoField: {"filter_class": GlobalIDFilter},
 | 
			
		||||
    models.OneToOneField: {"filter_class": GlobalIDFilter},
 | 
			
		||||
    models.ForeignKey: {"filter_class": GlobalIDFilter},
 | 
			
		||||
    models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GrapheneFilterSetMixin(BaseFilterSet):
 | 
			
		||||
    FILTER_DEFAULTS = dict(itertools.chain(
 | 
			
		||||
        FILTER_FOR_DBFIELD_DEFAULTS.items(),
 | 
			
		||||
        GRAPHENE_FILTER_SET_OVERRIDES.items()
 | 
			
		||||
    ))
 | 
			
		||||
    FILTER_DEFAULTS = dict(
 | 
			
		||||
        itertools.chain(
 | 
			
		||||
            FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def filter_for_reverse_field(cls, f, name):
 | 
			
		||||
| 
						 | 
				
			
			@ -62,10 +55,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
 | 
			
		|||
        except AttributeError:
 | 
			
		||||
            rel = f.field.rel
 | 
			
		||||
 | 
			
		||||
        default = {
 | 
			
		||||
            'name': name,
 | 
			
		||||
            'label': capfirst(rel.related_name)
 | 
			
		||||
        }
 | 
			
		||||
        default = {"name": name, "label": capfirst(rel.related_name)}
 | 
			
		||||
        if rel.multiple:
 | 
			
		||||
            # For to-many relationships
 | 
			
		||||
            return GlobalIDMultipleChoiceFilter(**default)
 | 
			
		||||
| 
						 | 
				
			
			@ -78,25 +68,20 @@ def setup_filterset(filterset_class):
 | 
			
		|||
    """ Wrap a provided filterset in Graphene-specific functionality
 | 
			
		||||
    """
 | 
			
		||||
    return type(
 | 
			
		||||
        'Graphene{}'.format(filterset_class.__name__),
 | 
			
		||||
        "Graphene{}".format(filterset_class.__name__),
 | 
			
		||||
        (filterset_class, GrapheneFilterSetMixin),
 | 
			
		||||
        {},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def custom_filterset_factory(model, filterset_base_class=FilterSet,
 | 
			
		||||
                             **meta):
 | 
			
		||||
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
 | 
			
		||||
    """ Create a filterset for the given model using the provided meta data
 | 
			
		||||
    """
 | 
			
		||||
    meta.update({
 | 
			
		||||
        'model': model,
 | 
			
		||||
    })
 | 
			
		||||
    meta_class = type(str('Meta'), (object,), meta)
 | 
			
		||||
    meta.update({"model": model})
 | 
			
		||||
    meta_class = type(str("Meta"), (object,), meta)
 | 
			
		||||
    filterset = type(
 | 
			
		||||
        str('%sFilterSet' % model._meta.object_name),
 | 
			
		||||
        str("%sFilterSet" % model._meta.object_name),
 | 
			
		||||
        (filterset_base_class, GrapheneFilterSetMixin),
 | 
			
		||||
        {
 | 
			
		||||
            'Meta': meta_class
 | 
			
		||||
        }
 | 
			
		||||
        {"Meta": meta_class},
 | 
			
		||||
    )
 | 
			
		||||
    return filterset
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ArticleFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Article
 | 
			
		||||
        fields = {
 | 
			
		||||
            'headline': ['exact', 'icontains'],
 | 
			
		||||
            'pub_date': ['gt', 'lt', 'exact'],
 | 
			
		||||
            'reporter': ['exact'],
 | 
			
		||||
            "headline": ["exact", "icontains"],
 | 
			
		||||
            "pub_date": ["gt", "lt", "exact"],
 | 
			
		||||
            "reporter": ["exact"],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    order_by = OrderingFilter(fields=('pub_date',))
 | 
			
		||||
    order_by = OrderingFilter(fields=("pub_date",))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReporterFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Reporter
 | 
			
		||||
        fields = ['first_name', 'last_name', 'email', 'pets']
 | 
			
		||||
        fields = ["first_name", "last_name", "email", "pets"]
 | 
			
		||||
 | 
			
		||||
    order_by = OrderingFilter(fields=('pub_date',))
 | 
			
		||||
    order_by = OrderingFilter(fields=("pub_date",))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PetFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Pet
 | 
			
		||||
        fields = ['name']
 | 
			
		||||
        fields = ["name"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,8 +5,7 @@ import pytest
 | 
			
		|||
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
 | 
			
		||||
from graphene.relay import Node
 | 
			
		||||
from graphene_django import DjangoObjectType
 | 
			
		||||
from graphene_django.forms import (GlobalIDFormField,
 | 
			
		||||
                                   GlobalIDMultipleChoiceField)
 | 
			
		||||
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
 | 
			
		||||
from graphene_django.tests.models import Article, Pet, Reporter
 | 
			
		||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED:
 | 
			
		|||
    import django_filters
 | 
			
		||||
    from django_filters import FilterSet, NumberFilter
 | 
			
		||||
 | 
			
		||||
    from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
 | 
			
		||||
                                        GlobalIDMultipleChoiceFilter)
 | 
			
		||||
    from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter
 | 
			
		||||
    from graphene_django.filter import (
 | 
			
		||||
        GlobalIDFilter,
 | 
			
		||||
        DjangoFilterConnectionField,
 | 
			
		||||
        GlobalIDMultipleChoiceFilter,
 | 
			
		||||
    )
 | 
			
		||||
    from graphene_django.filter.tests.filters import (
 | 
			
		||||
        ArticleFilter,
 | 
			
		||||
        PetFilter,
 | 
			
		||||
        ReporterFilter,
 | 
			
		||||
    )
 | 
			
		||||
else:
 | 
			
		||||
    pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible'))
 | 
			
		||||
    pytestmark.append(
 | 
			
		||||
        pytest.mark.skipif(
 | 
			
		||||
            True, reason="django_filters not installed or not compatible"
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
pytestmark.append(pytest.mark.django_db)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if DJANGO_FILTER_INSTALLED:
 | 
			
		||||
    class ArticleNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
    class ArticleNode(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = ('headline', )
 | 
			
		||||
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ("headline",)
 | 
			
		||||
 | 
			
		||||
    class ReporterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class PetNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Pet
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    # schema = Schema()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -59,58 +65,47 @@ def get_args(field):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def assert_arguments(field, *arguments):
 | 
			
		||||
    ignore = ('after', 'before', 'first', 'last', 'order_by')
 | 
			
		||||
    ignore = ("after", "before", "first", "last", "order_by")
 | 
			
		||||
    args = get_args(field)
 | 
			
		||||
    actual = [
 | 
			
		||||
        name
 | 
			
		||||
        for name in args
 | 
			
		||||
        if name not in ignore and not name.startswith('_')
 | 
			
		||||
    ]
 | 
			
		||||
    assert set(arguments) == set(actual), \
 | 
			
		||||
        'Expected arguments ({}) did not match actual ({})'.format(
 | 
			
		||||
            arguments,
 | 
			
		||||
            actual
 | 
			
		||||
    )
 | 
			
		||||
    actual = [name for name in args if name not in ignore and not name.startswith("_")]
 | 
			
		||||
    assert set(arguments) == set(
 | 
			
		||||
        actual
 | 
			
		||||
    ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def assert_orderable(field):
 | 
			
		||||
    args = get_args(field)
 | 
			
		||||
    assert 'order_by' in args, \
 | 
			
		||||
        'Field cannot be ordered'
 | 
			
		||||
    assert "order_by" in args, "Field cannot be ordered"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def assert_not_orderable(field):
 | 
			
		||||
    args = get_args(field)
 | 
			
		||||
    assert 'order_by' not in args, \
 | 
			
		||||
        'Field can be ordered'
 | 
			
		||||
    assert "order_by" not in args, "Field can be ordered"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_explicit_filterset_arguments():
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
 | 
			
		||||
    assert_arguments(field,
 | 
			
		||||
                     'headline', 'headline__icontains',
 | 
			
		||||
                     'pub_date', 'pub_date__gt', 'pub_date__lt',
 | 
			
		||||
                     'reporter',
 | 
			
		||||
                     )
 | 
			
		||||
    assert_arguments(
 | 
			
		||||
        field,
 | 
			
		||||
        "headline",
 | 
			
		||||
        "headline__icontains",
 | 
			
		||||
        "pub_date",
 | 
			
		||||
        "pub_date__gt",
 | 
			
		||||
        "pub_date__lt",
 | 
			
		||||
        "reporter",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_shortcut_filterset_arguments_list():
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter'])
 | 
			
		||||
    assert_arguments(field,
 | 
			
		||||
                     'pub_date',
 | 
			
		||||
                     'reporter',
 | 
			
		||||
                     )
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
 | 
			
		||||
    assert_arguments(field, "pub_date", "reporter")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_shortcut_filterset_arguments_dict():
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields={
 | 
			
		||||
        'headline': ['exact', 'icontains'],
 | 
			
		||||
        'reporter': ['exact'],
 | 
			
		||||
    })
 | 
			
		||||
    assert_arguments(field,
 | 
			
		||||
                     'headline', 'headline__icontains',
 | 
			
		||||
                     'reporter',
 | 
			
		||||
                     )
 | 
			
		||||
    field = DjangoFilterConnectionField(
 | 
			
		||||
        ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
 | 
			
		||||
    )
 | 
			
		||||
    assert_arguments(field, "headline", "headline__icontains", "reporter")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_explicit_filterset_orderable():
 | 
			
		||||
| 
						 | 
				
			
			@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_filter_shortcut_filterset_extra_meta():
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
 | 
			
		||||
        'exclude': ('headline', )
 | 
			
		||||
    })
 | 
			
		||||
    assert 'headline' not in field.filterset_class.get_fields()
 | 
			
		||||
    field = DjangoFilterConnectionField(
 | 
			
		||||
        ArticleNode, extra_filter_meta={"exclude": ("headline",)}
 | 
			
		||||
    )
 | 
			
		||||
    assert "headline" not in field.filterset_class.get_fields()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_shortcut_filterset_context():
 | 
			
		||||
    class ArticleContextFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
            exclude = set()
 | 
			
		||||
| 
						 | 
				
			
			@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context():
 | 
			
		|||
            return qs.filter(reporter=self.request.reporter)
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter)
 | 
			
		||||
        context_articles = DjangoFilterConnectionField(
 | 
			
		||||
            ArticleNode, filterset_class=ArticleContextFilter
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
 | 
			
		||||
    r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
 | 
			
		||||
    Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1)
 | 
			
		||||
    Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2)
 | 
			
		||||
    r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
 | 
			
		||||
    r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
 | 
			
		||||
    Article.objects.create(
 | 
			
		||||
        headline="a1",
 | 
			
		||||
        pub_date=datetime.now(),
 | 
			
		||||
        pub_date_time=datetime.now(),
 | 
			
		||||
        reporter=r1,
 | 
			
		||||
        editor=r1,
 | 
			
		||||
    )
 | 
			
		||||
    Article.objects.create(
 | 
			
		||||
        headline="a2",
 | 
			
		||||
        pub_date=datetime.now(),
 | 
			
		||||
        pub_date_time=datetime.now(),
 | 
			
		||||
        reporter=r2,
 | 
			
		||||
        editor=r2,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class context(object):
 | 
			
		||||
        reporter = r2
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        contextArticles {
 | 
			
		||||
            edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context():
 | 
			
		|||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query, context_value=context())
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
 | 
			
		||||
    assert len(result.data['contextArticles']['edges']) == 1
 | 
			
		||||
    assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2'
 | 
			
		||||
    assert len(result.data["contextArticles"]["edges"]) == 1
 | 
			
		||||
    assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_filterset_information_on_meta():
 | 
			
		||||
    class ReporterFilterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = ['first_name', 'articles']
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ["first_name", "articles"]
 | 
			
		||||
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterFilterNode)
 | 
			
		||||
    assert_arguments(field, 'first_name', 'articles')
 | 
			
		||||
    assert_arguments(field, "first_name", "articles")
 | 
			
		||||
    assert_not_orderable(field)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_filterset_information_on_meta_related():
 | 
			
		||||
    class ReporterFilterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = ['first_name', 'articles']
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ["first_name", "articles"]
 | 
			
		||||
 | 
			
		||||
    class ArticleFilterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = ['headline', 'reporter']
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ["headline", "reporter"]
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
 | 
			
		||||
| 
						 | 
				
			
			@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
 | 
			
		|||
        article = Field(ArticleFilterNode)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    articles_field = ReporterFilterNode._meta.fields['articles'].get_type()
 | 
			
		||||
    assert_arguments(articles_field, 'headline', 'reporter')
 | 
			
		||||
    articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
 | 
			
		||||
    assert_arguments(articles_field, "headline", "reporter")
 | 
			
		||||
    assert_not_orderable(articles_field)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_filterset_related_results():
 | 
			
		||||
    class ReporterFilterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = ['first_name', 'articles']
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ["first_name", "articles"]
 | 
			
		||||
 | 
			
		||||
    class ArticleFilterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            model = Article
 | 
			
		||||
            filter_fields = ['headline', 'reporter']
 | 
			
		||||
            filter_fields = ["headline", "reporter"]
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
 | 
			
		||||
| 
						 | 
				
			
			@ -243,12 +246,22 @@ def test_filter_filterset_related_results():
 | 
			
		|||
        reporter = Field(ReporterFilterNode)
 | 
			
		||||
        article = Field(ArticleFilterNode)
 | 
			
		||||
 | 
			
		||||
    r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
 | 
			
		||||
    r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
 | 
			
		||||
    Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1)
 | 
			
		||||
    Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2)
 | 
			
		||||
    r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
 | 
			
		||||
    r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
 | 
			
		||||
    Article.objects.create(
 | 
			
		||||
        headline="a1",
 | 
			
		||||
        pub_date=datetime.now(),
 | 
			
		||||
        pub_date_time=datetime.now(),
 | 
			
		||||
        reporter=r1,
 | 
			
		||||
    )
 | 
			
		||||
    Article.objects.create(
 | 
			
		||||
        headline="a2",
 | 
			
		||||
        pub_date=datetime.now(),
 | 
			
		||||
        pub_date_time=datetime.now(),
 | 
			
		||||
        reporter=r2,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        allReporters {
 | 
			
		||||
            edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -264,123 +277,134 @@ def test_filter_filterset_related_results():
 | 
			
		|||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    # We should only get back a single article for each reporter
 | 
			
		||||
    assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1
 | 
			
		||||
    assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1
 | 
			
		||||
    assert (
 | 
			
		||||
        len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
 | 
			
		||||
    )
 | 
			
		||||
    assert (
 | 
			
		||||
        len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_field_implicit():
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    id_filter = filterset_class.base_filters['id']
 | 
			
		||||
    id_filter = filterset_class.base_filters["id"]
 | 
			
		||||
    assert isinstance(id_filter, GlobalIDFilter)
 | 
			
		||||
    assert id_filter.field_class == GlobalIDFormField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_field_explicit():
 | 
			
		||||
    class ArticleIdFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
            fields = ['id']
 | 
			
		||||
            fields = ["id"]
 | 
			
		||||
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    id_filter = filterset_class.base_filters['id']
 | 
			
		||||
    id_filter = filterset_class.base_filters["id"]
 | 
			
		||||
    assert isinstance(id_filter, GlobalIDFilter)
 | 
			
		||||
    assert id_filter.field_class == GlobalIDFormField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filterset_descriptions():
 | 
			
		||||
    class ArticleIdFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
            fields = ['id']
 | 
			
		||||
            fields = ["id"]
 | 
			
		||||
 | 
			
		||||
        max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time")
 | 
			
		||||
        max_time = django_filters.NumberFilter(
 | 
			
		||||
            method="filter_max_time", label="The maximum time"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
 | 
			
		||||
    max_time = field.args['max_time']
 | 
			
		||||
    max_time = field.args["max_time"]
 | 
			
		||||
    assert isinstance(max_time, Argument)
 | 
			
		||||
    assert max_time.type == Float
 | 
			
		||||
    assert max_time.description == 'The maximum time'
 | 
			
		||||
    assert max_time.description == "The maximum time"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_field_relation():
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields=['reporter'])
 | 
			
		||||
    field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    id_filter = filterset_class.base_filters['reporter']
 | 
			
		||||
    id_filter = filterset_class.base_filters["reporter"]
 | 
			
		||||
    assert isinstance(id_filter, GlobalIDFilter)
 | 
			
		||||
    assert id_filter.field_class == GlobalIDFormField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_multiple_field_implicit():
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterNode, fields=['pets'])
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    multiple_filter = filterset_class.base_filters['pets']
 | 
			
		||||
    multiple_filter = filterset_class.base_filters["pets"]
 | 
			
		||||
    assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
 | 
			
		||||
    assert multiple_filter.field_class == GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_multiple_field_explicit():
 | 
			
		||||
    class ReporterPetsFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            fields = ['pets']
 | 
			
		||||
            fields = ["pets"]
 | 
			
		||||
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
 | 
			
		||||
    field = DjangoFilterConnectionField(
 | 
			
		||||
        ReporterNode, filterset_class=ReporterPetsFilter
 | 
			
		||||
    )
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    multiple_filter = filterset_class.base_filters['pets']
 | 
			
		||||
    multiple_filter = filterset_class.base_filters["pets"]
 | 
			
		||||
    assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
 | 
			
		||||
    assert multiple_filter.field_class == GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_multiple_field_implicit_reverse():
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterNode, fields=['articles'])
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    multiple_filter = filterset_class.base_filters['articles']
 | 
			
		||||
    multiple_filter = filterset_class.base_filters["articles"]
 | 
			
		||||
    assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
 | 
			
		||||
    assert multiple_filter.field_class == GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_multiple_field_explicit_reverse():
 | 
			
		||||
    class ReporterPetsFilter(django_filters.FilterSet):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            fields = ['articles']
 | 
			
		||||
            fields = ["articles"]
 | 
			
		||||
 | 
			
		||||
    field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
 | 
			
		||||
    field = DjangoFilterConnectionField(
 | 
			
		||||
        ReporterNode, filterset_class=ReporterPetsFilter
 | 
			
		||||
    )
 | 
			
		||||
    filterset_class = field.filterset_class
 | 
			
		||||
    multiple_filter = filterset_class.base_filters['articles']
 | 
			
		||||
    multiple_filter = filterset_class.base_filters["articles"]
 | 
			
		||||
    assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
 | 
			
		||||
    assert multiple_filter.field_class == GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_filter_filterset_related_results():
 | 
			
		||||
    class ReporterFilterNode(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = {
 | 
			
		||||
                'first_name': ['icontains']
 | 
			
		||||
            }
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = {"first_name": ["icontains"]}
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
 | 
			
		||||
 | 
			
		||||
    r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com')
 | 
			
		||||
    r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com')
 | 
			
		||||
    r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com')
 | 
			
		||||
    r1 = Reporter.objects.create(
 | 
			
		||||
        first_name="A test user", last_name="Last Name", email="test1@test.com"
 | 
			
		||||
    )
 | 
			
		||||
    r2 = Reporter.objects.create(
 | 
			
		||||
        first_name="Other test user",
 | 
			
		||||
        last_name="Other Last Name",
 | 
			
		||||
        email="test2@test.com",
 | 
			
		||||
    )
 | 
			
		||||
    r3 = Reporter.objects.create(
 | 
			
		||||
        first_name="Random", last_name="RandomLast", email="random@test.com"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        allReporters(firstName_Icontains: "test") {
 | 
			
		||||
            edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -390,12 +414,12 @@ def test_filter_filterset_related_results():
 | 
			
		|||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    # We should only get two reporters
 | 
			
		||||
    assert len(result.data['allReporters']['edges']) == 2
 | 
			
		||||
    assert len(result.data["allReporters"]["edges"]) == 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_recursive_filter_connection():
 | 
			
		||||
| 
						 | 
				
			
			@ -407,79 +431,73 @@ def test_recursive_filter_connection():
 | 
			
		|||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
 | 
			
		||||
 | 
			
		||||
    assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode
 | 
			
		||||
    assert (
 | 
			
		||||
        ReporterFilterNode._meta.fields["child_reporters"].node_type
 | 
			
		||||
        == ReporterFilterNode
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_query_filter_node_limit():
 | 
			
		||||
    class ReporterFilter(FilterSet):
 | 
			
		||||
        limit = NumberFilter(method='filter_limit')
 | 
			
		||||
        limit = NumberFilter(method="filter_limit")
 | 
			
		||||
 | 
			
		||||
        def filter_limit(self, queryset, name, value):
 | 
			
		||||
            return queryset[:value]
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            fields = ['first_name', ]
 | 
			
		||||
            fields = ["first_name"]
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class ArticleType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            filter_fields = ('lang', )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ("lang",)
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(
 | 
			
		||||
            ReporterType,
 | 
			
		||||
            filterset_class=ReporterFilter
 | 
			
		||||
            ReporterType, filterset_class=ReporterFilter
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.order_by('a_choice')
 | 
			
		||||
            return Reporter.objects.order_by("a_choice")
 | 
			
		||||
 | 
			
		||||
    Reporter.objects.create(
 | 
			
		||||
        first_name='Bob',
 | 
			
		||||
        last_name='Doe',
 | 
			
		||||
        email='bobdoe@example.com',
 | 
			
		||||
        a_choice=2
 | 
			
		||||
        first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
 | 
			
		||||
    )
 | 
			
		||||
    r = Reporter.objects.create(
 | 
			
		||||
        first_name='John',
 | 
			
		||||
        last_name='Doe',
 | 
			
		||||
        email='johndoe@example.com',
 | 
			
		||||
        a_choice=1
 | 
			
		||||
        first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    Article.objects.create(
 | 
			
		||||
        headline='Article Node 1',
 | 
			
		||||
        headline="Article Node 1",
 | 
			
		||||
        pub_date=datetime.now(),
 | 
			
		||||
        pub_date_time=datetime.now(),
 | 
			
		||||
        reporter=r,
 | 
			
		||||
        editor=r,
 | 
			
		||||
        lang='es'
 | 
			
		||||
        lang="es",
 | 
			
		||||
    )
 | 
			
		||||
    Article.objects.create(
 | 
			
		||||
        headline='Article Node 2',
 | 
			
		||||
        headline="Article Node 2",
 | 
			
		||||
        pub_date=datetime.now(),
 | 
			
		||||
        pub_date_time=datetime.now(),
 | 
			
		||||
        reporter=r,
 | 
			
		||||
        editor=r,
 | 
			
		||||
        lang='en'
 | 
			
		||||
        lang="en",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(limit: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -498,24 +516,23 @@ def test_should_query_filter_node_limit():
 | 
			
		|||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allReporters': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'id': 'UmVwb3J0ZXJUeXBlOjI=',
 | 
			
		||||
                    'firstName': 'John',
 | 
			
		||||
                    'articles': {
 | 
			
		||||
                        'edges': [{
 | 
			
		||||
                            'node': {
 | 
			
		||||
                                'id': 'QXJ0aWNsZVR5cGU6MQ==',
 | 
			
		||||
                                'lang': 'ES'
 | 
			
		||||
                            }
 | 
			
		||||
                        }]
 | 
			
		||||
        "allReporters": {
 | 
			
		||||
            "edges": [
 | 
			
		||||
                {
 | 
			
		||||
                    "node": {
 | 
			
		||||
                        "id": "UmVwb3J0ZXJUeXBlOjI=",
 | 
			
		||||
                        "firstName": "John",
 | 
			
		||||
                        "articles": {
 | 
			
		||||
                            "edges": [
 | 
			
		||||
                                {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}}
 | 
			
		||||
                            ]
 | 
			
		||||
                        },
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }]
 | 
			
		||||
            ]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -526,45 +543,37 @@ def test_should_query_filter_node_limit():
 | 
			
		|||
 | 
			
		||||
def test_should_query_filter_node_double_limit_raises():
 | 
			
		||||
    class ReporterFilter(FilterSet):
 | 
			
		||||
        limit = NumberFilter(method='filter_limit')
 | 
			
		||||
        limit = NumberFilter(method="filter_limit")
 | 
			
		||||
 | 
			
		||||
        def filter_limit(self, queryset, name, value):
 | 
			
		||||
            return queryset[:value]
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            fields = ['first_name', ]
 | 
			
		||||
            fields = ["first_name"]
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(
 | 
			
		||||
            ReporterType,
 | 
			
		||||
            filterset_class=ReporterFilter
 | 
			
		||||
            ReporterType, filterset_class=ReporterFilter
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.order_by('a_choice')[:2]
 | 
			
		||||
            return Reporter.objects.order_by("a_choice")[:2]
 | 
			
		||||
 | 
			
		||||
    Reporter.objects.create(
 | 
			
		||||
        first_name='Bob',
 | 
			
		||||
        last_name='Doe',
 | 
			
		||||
        email='bobdoe@example.com',
 | 
			
		||||
        a_choice=2
 | 
			
		||||
        first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
 | 
			
		||||
    )
 | 
			
		||||
    r = Reporter.objects.create(
 | 
			
		||||
        first_name='John',
 | 
			
		||||
        last_name='Doe',
 | 
			
		||||
        email='johndoe@example.com',
 | 
			
		||||
        a_choice=1
 | 
			
		||||
        first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(limit: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -575,41 +584,40 @@ def test_should_query_filter_node_double_limit_raises():
 | 
			
		|||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert len(result.errors) == 1
 | 
			
		||||
    assert str(result.errors[0]) == (
 | 
			
		||||
        'Received two sliced querysets (high mark) in the connection, please slice only in one.'
 | 
			
		||||
        "Received two sliced querysets (high mark) in the connection, please slice only in one."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_order_by_is_perserved():
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ()
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean())
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(
 | 
			
		||||
            ReporterType, reverse_order=Boolean()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, reverse_order=False, **args):
 | 
			
		||||
            reporters = Reporter.objects.order_by('first_name')
 | 
			
		||||
            reporters = Reporter.objects.order_by("first_name")
 | 
			
		||||
 | 
			
		||||
            if reverse_order:
 | 
			
		||||
                return reporters.reverse()
 | 
			
		||||
 | 
			
		||||
            return reporters
 | 
			
		||||
 | 
			
		||||
    Reporter.objects.create(
 | 
			
		||||
        first_name='b',
 | 
			
		||||
    )
 | 
			
		||||
    r = Reporter.objects.create(
 | 
			
		||||
        first_name='a',
 | 
			
		||||
    )
 | 
			
		||||
    Reporter.objects.create(first_name="b")
 | 
			
		||||
    r = Reporter.objects.create(first_name="a")
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(first: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -619,23 +627,14 @@ def test_order_by_is_perserved():
 | 
			
		|||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allReporters': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'firstName': 'a',
 | 
			
		||||
                }
 | 
			
		||||
            }]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}}
 | 
			
		||||
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    reverse_query = '''
 | 
			
		||||
    reverse_query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(first: 1, reverseOrder: true) {
 | 
			
		||||
                edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -645,23 +644,16 @@ def test_order_by_is_perserved():
 | 
			
		|||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    reverse_expected = {
 | 
			
		||||
        'allReporters': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'firstName': 'b',
 | 
			
		||||
                }
 | 
			
		||||
            }]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}}
 | 
			
		||||
 | 
			
		||||
    reverse_result = schema.execute(reverse_query)
 | 
			
		||||
 | 
			
		||||
    assert not reverse_result.errors
 | 
			
		||||
    assert reverse_result.data == reverse_expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_annotation_is_perserved():
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        full_name = String()
 | 
			
		||||
| 
						 | 
				
			
			@ -671,7 +663,7 @@ def test_annotation_is_perserved():
 | 
			
		|||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ()
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
| 
						 | 
				
			
			@ -679,17 +671,16 @@ def test_annotation_is_perserved():
 | 
			
		|||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.annotate(
 | 
			
		||||
                full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField())
 | 
			
		||||
                full_name=Concat(
 | 
			
		||||
                    "first_name", Value(" "), "last_name", output_field=TextField()
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    Reporter.objects.create(
 | 
			
		||||
        first_name='John',
 | 
			
		||||
        last_name='Doe',
 | 
			
		||||
    )
 | 
			
		||||
    Reporter.objects.create(first_name="John", last_name="Doe")
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = '''
 | 
			
		||||
    query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(first: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
| 
						 | 
				
			
			@ -699,16 +690,8 @@ def test_annotation_is_perserved():
 | 
			
		|||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    '''
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allReporters': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'fullName': 'John Doe',
 | 
			
		||||
                }
 | 
			
		||||
            }]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
 | 
			
		||||
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,8 +14,7 @@ singledispatch = import_single_dispatch()
 | 
			
		|||
def convert_form_field(field):
 | 
			
		||||
    raise ImproperlyConfigured(
 | 
			
		||||
        "Don't know how to convert the Django form field %s (%s) "
 | 
			
		||||
        "to Graphene type" %
 | 
			
		||||
        (field, field.__class__)
 | 
			
		||||
        "to Graphene type" % (field, field.__class__)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,9 +8,7 @@ from graphql_relay import from_global_id
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class GlobalIDFormField(Field):
 | 
			
		||||
    default_error_messages = {
 | 
			
		||||
        'invalid': _('Invalid ID specified.'),
 | 
			
		||||
    }
 | 
			
		||||
    default_error_messages = {"invalid": _("Invalid ID specified.")}
 | 
			
		||||
 | 
			
		||||
    def clean(self, value):
 | 
			
		||||
        if not value and not self.required:
 | 
			
		||||
| 
						 | 
				
			
			@ -19,21 +17,21 @@ class GlobalIDFormField(Field):
 | 
			
		|||
        try:
 | 
			
		||||
            _type, _id = from_global_id(value)
 | 
			
		||||
        except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
 | 
			
		||||
            raise ValidationError(self.error_messages['invalid'])
 | 
			
		||||
            raise ValidationError(self.error_messages["invalid"])
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            CharField().clean(_id)
 | 
			
		||||
            CharField().clean(_type)
 | 
			
		||||
        except ValidationError:
 | 
			
		||||
            raise ValidationError(self.error_messages['invalid'])
 | 
			
		||||
            raise ValidationError(self.error_messages["invalid"])
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GlobalIDMultipleChoiceField(MultipleChoiceField):
 | 
			
		||||
    default_error_messages = {
 | 
			
		||||
        'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'),
 | 
			
		||||
        'invalid_list': _('Enter a list of values.'),
 | 
			
		||||
        "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
 | 
			
		||||
        "invalid_list": _("Enter a list of values."),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def valid_value(self, value):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,6 +5,7 @@ import graphene
 | 
			
		|||
from graphene import Field, InputField
 | 
			
		||||
from graphene.relay.mutation import ClientIDMutation
 | 
			
		||||
from graphene.types.mutation import MutationOptions
 | 
			
		||||
 | 
			
		||||
# from graphene.types.inputobjecttype import (
 | 
			
		||||
#     InputObjectTypeOptions,
 | 
			
		||||
#     InputObjectType,
 | 
			
		||||
| 
						 | 
				
			
			@ -21,7 +22,8 @@ def fields_for_form(form, only_fields, exclude_fields):
 | 
			
		|||
    for name, field in form.fields.items():
 | 
			
		||||
        is_not_in_only = only_fields and name not in only_fields
 | 
			
		||||
        is_excluded = (
 | 
			
		||||
            name in exclude_fields  # or
 | 
			
		||||
            name
 | 
			
		||||
            in exclude_fields  # or
 | 
			
		||||
            # name in already_created_fields
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -57,12 +59,12 @@ class BaseDjangoFormMutation(ClientIDMutation):
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_form_kwargs(cls, root, info, **input):
 | 
			
		||||
        kwargs = {'data': input}
 | 
			
		||||
        kwargs = {"data": input}
 | 
			
		||||
 | 
			
		||||
        pk = input.pop('id', None)
 | 
			
		||||
        pk = input.pop("id", None)
 | 
			
		||||
        if pk:
 | 
			
		||||
            instance = cls._meta.model._default_manager.get(pk=pk)
 | 
			
		||||
            kwargs['instance'] = instance
 | 
			
		||||
            kwargs["instance"] = instance
 | 
			
		||||
 | 
			
		||||
        return kwargs
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -100,11 +102,12 @@ class DjangoFormMutation(BaseDjangoFormMutation):
 | 
			
		|||
    errors = graphene.List(ErrorType)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __init_subclass_with_meta__(cls, form_class=None,
 | 
			
		||||
                                    only_fields=(), exclude_fields=(), **options):
 | 
			
		||||
    def __init_subclass_with_meta__(
 | 
			
		||||
        cls, form_class=None, only_fields=(), exclude_fields=(), **options
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        if not form_class:
 | 
			
		||||
            raise Exception('form_class is required for DjangoFormMutation')
 | 
			
		||||
            raise Exception("form_class is required for DjangoFormMutation")
 | 
			
		||||
 | 
			
		||||
        form = form_class()
 | 
			
		||||
        input_fields = fields_for_form(form, only_fields, exclude_fields)
 | 
			
		||||
| 
						 | 
				
			
			@ -112,16 +115,12 @@ class DjangoFormMutation(BaseDjangoFormMutation):
 | 
			
		|||
 | 
			
		||||
        _meta = DjangoFormMutationOptions(cls)
 | 
			
		||||
        _meta.form_class = form_class
 | 
			
		||||
        _meta.fields = yank_fields_from_attrs(
 | 
			
		||||
            output_fields,
 | 
			
		||||
            _as=Field,
 | 
			
		||||
        )
 | 
			
		||||
        _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
 | 
			
		||||
 | 
			
		||||
        input_fields = yank_fields_from_attrs(
 | 
			
		||||
            input_fields,
 | 
			
		||||
            _as=InputField,
 | 
			
		||||
        input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
 | 
			
		||||
        super(DjangoFormMutation, cls).__init_subclass_with_meta__(
 | 
			
		||||
            _meta=_meta, input_fields=input_fields, **options
 | 
			
		||||
        )
 | 
			
		||||
        super(DjangoFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def perform_mutate(cls, form, info):
 | 
			
		||||
| 
						 | 
				
			
			@ -141,21 +140,28 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
 | 
			
		|||
    errors = graphene.List(ErrorType)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __init_subclass_with_meta__(cls, form_class=None, model=None, return_field_name=None,
 | 
			
		||||
                                    only_fields=(), exclude_fields=(), **options):
 | 
			
		||||
    def __init_subclass_with_meta__(
 | 
			
		||||
        cls,
 | 
			
		||||
        form_class=None,
 | 
			
		||||
        model=None,
 | 
			
		||||
        return_field_name=None,
 | 
			
		||||
        only_fields=(),
 | 
			
		||||
        exclude_fields=(),
 | 
			
		||||
        **options
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        if not form_class:
 | 
			
		||||
            raise Exception('form_class is required for DjangoModelFormMutation')
 | 
			
		||||
            raise Exception("form_class is required for DjangoModelFormMutation")
 | 
			
		||||
 | 
			
		||||
        if not model:
 | 
			
		||||
            model = form_class._meta.model
 | 
			
		||||
 | 
			
		||||
        if not model:
 | 
			
		||||
            raise Exception('model is required for DjangoModelFormMutation')
 | 
			
		||||
            raise Exception("model is required for DjangoModelFormMutation")
 | 
			
		||||
 | 
			
		||||
        form = form_class()
 | 
			
		||||
        input_fields = fields_for_form(form, only_fields, exclude_fields)
 | 
			
		||||
        input_fields['id'] = graphene.ID()
 | 
			
		||||
        input_fields["id"] = graphene.ID()
 | 
			
		||||
 | 
			
		||||
        registry = get_global_registry()
 | 
			
		||||
        model_type = registry.get_type_for_model(model)
 | 
			
		||||
| 
						 | 
				
			
			@ -171,19 +177,11 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
 | 
			
		|||
        _meta.form_class = form_class
 | 
			
		||||
        _meta.model = model
 | 
			
		||||
        _meta.return_field_name = return_field_name
 | 
			
		||||
        _meta.fields = yank_fields_from_attrs(
 | 
			
		||||
            output_fields,
 | 
			
		||||
            _as=Field,
 | 
			
		||||
        )
 | 
			
		||||
        _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
 | 
			
		||||
 | 
			
		||||
        input_fields = yank_fields_from_attrs(
 | 
			
		||||
            input_fields,
 | 
			
		||||
            _as=InputField,
 | 
			
		||||
        )
 | 
			
		||||
        input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
 | 
			
		||||
        super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
 | 
			
		||||
            _meta=_meta,
 | 
			
		||||
            input_fields=input_fields,
 | 
			
		||||
            **options
 | 
			
		||||
            _meta=_meta, input_fields=input_fields, **options
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,24 +2,36 @@ from django import forms
 | 
			
		|||
from py.test import raises
 | 
			
		||||
 | 
			
		||||
import graphene
 | 
			
		||||
from graphene import String, Int, Boolean, Float, ID, UUID, List, NonNull, DateTime, Date, Time
 | 
			
		||||
from graphene import (
 | 
			
		||||
    String,
 | 
			
		||||
    Int,
 | 
			
		||||
    Boolean,
 | 
			
		||||
    Float,
 | 
			
		||||
    ID,
 | 
			
		||||
    UUID,
 | 
			
		||||
    List,
 | 
			
		||||
    NonNull,
 | 
			
		||||
    DateTime,
 | 
			
		||||
    Date,
 | 
			
		||||
    Time,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from ..converter import convert_form_field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def assert_conversion(django_field, graphene_field, *args):
 | 
			
		||||
    field = django_field(*args, help_text='Custom Help Text')
 | 
			
		||||
    field = django_field(*args, help_text="Custom Help Text")
 | 
			
		||||
    graphene_type = convert_form_field(field)
 | 
			
		||||
    assert isinstance(graphene_type, graphene_field)
 | 
			
		||||
    field = graphene_type.Field()
 | 
			
		||||
    assert field.description == 'Custom Help Text'
 | 
			
		||||
    assert field.description == "Custom Help Text"
 | 
			
		||||
    return field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_unknown_django_field_raise_exception():
 | 
			
		||||
    with raises(Exception) as excinfo:
 | 
			
		||||
        convert_form_field(None)
 | 
			
		||||
    assert 'Don\'t know how to convert the Django form field' in str(excinfo.value)
 | 
			
		||||
    assert "Don't know how to convert the Django form field" in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_date_convert_date():
 | 
			
		||||
| 
						 | 
				
			
			@ -59,11 +71,11 @@ def test_should_base_field_convert_string():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_should_regex_convert_string():
 | 
			
		||||
    assert_conversion(forms.RegexField, String, '[0-9]+')
 | 
			
		||||
    assert_conversion(forms.RegexField, String, "[0-9]+")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_uuid_convert_string():
 | 
			
		||||
    if hasattr(forms, 'UUIDField'):
 | 
			
		||||
    if hasattr(forms, "UUIDField"):
 | 
			
		||||
        assert_conversion(forms.UUIDField, UUID)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,18 +11,18 @@ class MyForm(forms.Form):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class PetForm(forms.ModelForm):
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Pet
 | 
			
		||||
        fields = ('name',)
 | 
			
		||||
        fields = ("name",)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_needs_form_class():
 | 
			
		||||
    with raises(Exception) as exc:
 | 
			
		||||
 | 
			
		||||
        class MyMutation(DjangoFormMutation):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    assert exc.value.args[0] == 'form_class is required for DjangoFormMutation'
 | 
			
		||||
    assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_has_output_fields():
 | 
			
		||||
| 
						 | 
				
			
			@ -30,7 +30,7 @@ def test_has_output_fields():
 | 
			
		|||
        class Meta:
 | 
			
		||||
            form_class = MyForm
 | 
			
		||||
 | 
			
		||||
    assert 'errors' in MyMutation._meta.fields
 | 
			
		||||
    assert "errors" in MyMutation._meta.fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_has_input_fields():
 | 
			
		||||
| 
						 | 
				
			
			@ -38,19 +38,18 @@ def test_has_input_fields():
 | 
			
		|||
        class Meta:
 | 
			
		||||
            form_class = MyForm
 | 
			
		||||
 | 
			
		||||
    assert 'text' in MyMutation.Input._meta.fields
 | 
			
		||||
    assert "text" in MyMutation.Input._meta.fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelFormMutationTests(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_default_meta_fields(self):
 | 
			
		||||
        class PetMutation(DjangoModelFormMutation):
 | 
			
		||||
            class Meta:
 | 
			
		||||
                form_class = PetForm
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(PetMutation._meta.model, Pet)
 | 
			
		||||
        self.assertEqual(PetMutation._meta.return_field_name, 'pet')
 | 
			
		||||
        self.assertIn('pet', PetMutation._meta.fields)
 | 
			
		||||
        self.assertEqual(PetMutation._meta.return_field_name, "pet")
 | 
			
		||||
        self.assertIn("pet", PetMutation._meta.fields)
 | 
			
		||||
 | 
			
		||||
    def test_return_field_name_is_camelcased(self):
 | 
			
		||||
        class PetMutation(DjangoModelFormMutation):
 | 
			
		||||
| 
						 | 
				
			
			@ -59,31 +58,31 @@ class ModelFormMutationTests(TestCase):
 | 
			
		|||
                model = FilmDetails
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(PetMutation._meta.model, FilmDetails)
 | 
			
		||||
        self.assertEqual(PetMutation._meta.return_field_name, 'filmDetails')
 | 
			
		||||
        self.assertEqual(PetMutation._meta.return_field_name, "filmDetails")
 | 
			
		||||
 | 
			
		||||
    def test_custom_return_field_name(self):
 | 
			
		||||
        class PetMutation(DjangoModelFormMutation):
 | 
			
		||||
            class Meta:
 | 
			
		||||
                form_class = PetForm
 | 
			
		||||
                model = Film
 | 
			
		||||
                return_field_name = 'animal'
 | 
			
		||||
                return_field_name = "animal"
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(PetMutation._meta.model, Film)
 | 
			
		||||
        self.assertEqual(PetMutation._meta.return_field_name, 'animal')
 | 
			
		||||
        self.assertIn('animal', PetMutation._meta.fields)
 | 
			
		||||
        self.assertEqual(PetMutation._meta.return_field_name, "animal")
 | 
			
		||||
        self.assertIn("animal", PetMutation._meta.fields)
 | 
			
		||||
 | 
			
		||||
    def test_model_form_mutation_mutate(self):
 | 
			
		||||
        class PetMutation(DjangoModelFormMutation):
 | 
			
		||||
            class Meta:
 | 
			
		||||
                form_class = PetForm
 | 
			
		||||
 | 
			
		||||
        pet = Pet.objects.create(name='Axel')
 | 
			
		||||
        pet = Pet.objects.create(name="Axel")
 | 
			
		||||
 | 
			
		||||
        result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name='Mia')
 | 
			
		||||
        result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia")
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(Pet.objects.count(), 1)
 | 
			
		||||
        pet.refresh_from_db()
 | 
			
		||||
        self.assertEqual(pet.name, 'Mia')
 | 
			
		||||
        self.assertEqual(pet.name, "Mia")
 | 
			
		||||
        self.assertEqual(result.errors, [])
 | 
			
		||||
 | 
			
		||||
    def test_model_form_mutation_updates_existing_(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -91,11 +90,11 @@ class ModelFormMutationTests(TestCase):
 | 
			
		|||
            class Meta:
 | 
			
		||||
                form_class = PetForm
 | 
			
		||||
 | 
			
		||||
        result = PetMutation.mutate_and_get_payload(None, None, name='Mia')
 | 
			
		||||
        result = PetMutation.mutate_and_get_payload(None, None, name="Mia")
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(Pet.objects.count(), 1)
 | 
			
		||||
        pet = Pet.objects.get()
 | 
			
		||||
        self.assertEqual(pet.name, 'Mia')
 | 
			
		||||
        self.assertEqual(pet.name, "Mia")
 | 
			
		||||
        self.assertEqual(result.errors, [])
 | 
			
		||||
 | 
			
		||||
    def test_model_form_mutation_mutate_invalid_form(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -109,5 +108,5 @@ class ModelFormMutationTests(TestCase):
 | 
			
		|||
        self.assertEqual(Pet.objects.count(), 0)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(result.errors), 1)
 | 
			
		||||
        self.assertEqual(result.errors[0].field, 'name')
 | 
			
		||||
        self.assertEqual(result.errors[0].messages, ['This field is required.'])
 | 
			
		||||
        self.assertEqual(result.errors[0].field, "name")
 | 
			
		||||
        self.assertEqual(result.errors[0].messages, ["This field is required."])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,43 +7,45 @@ from graphene_django.settings import graphene_settings
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class CommandArguments(BaseCommand):
 | 
			
		||||
 | 
			
		||||
    def add_arguments(self, parser):
 | 
			
		||||
        parser.add_argument(
 | 
			
		||||
            '--schema',
 | 
			
		||||
            "--schema",
 | 
			
		||||
            type=str,
 | 
			
		||||
            dest='schema',
 | 
			
		||||
            dest="schema",
 | 
			
		||||
            default=graphene_settings.SCHEMA,
 | 
			
		||||
            help='Django app containing schema to dump, e.g. myproject.core.schema.schema')
 | 
			
		||||
            help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        parser.add_argument(
 | 
			
		||||
            '--out',
 | 
			
		||||
            "--out",
 | 
			
		||||
            type=str,
 | 
			
		||||
            dest='out',
 | 
			
		||||
            dest="out",
 | 
			
		||||
            default=graphene_settings.SCHEMA_OUTPUT,
 | 
			
		||||
            help='Output file (default: schema.json)')
 | 
			
		||||
            help="Output file (default: schema.json)",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        parser.add_argument(
 | 
			
		||||
            '--indent',
 | 
			
		||||
            "--indent",
 | 
			
		||||
            type=int,
 | 
			
		||||
            dest='indent',
 | 
			
		||||
            dest="indent",
 | 
			
		||||
            default=graphene_settings.SCHEMA_INDENT,
 | 
			
		||||
            help='Output file indent (default: None)')
 | 
			
		||||
            help="Output file indent (default: None)",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Command(CommandArguments):
 | 
			
		||||
    help = 'Dump Graphene schema JSON to file'
 | 
			
		||||
    help = "Dump Graphene schema JSON to file"
 | 
			
		||||
    can_import_settings = True
 | 
			
		||||
 | 
			
		||||
    def save_file(self, out, schema_dict, indent):
 | 
			
		||||
        with open(out, 'w') as outfile:
 | 
			
		||||
        with open(out, "w") as outfile:
 | 
			
		||||
            json.dump(schema_dict, outfile, indent=indent)
 | 
			
		||||
 | 
			
		||||
    def handle(self, *args, **options):
 | 
			
		||||
        options_schema = options.get('schema')
 | 
			
		||||
        options_schema = options.get("schema")
 | 
			
		||||
 | 
			
		||||
        if options_schema and type(options_schema) is str:
 | 
			
		||||
            module_str, schema_name = options_schema.rsplit('.', 1)
 | 
			
		||||
            module_str, schema_name = options_schema.rsplit(".", 1)
 | 
			
		||||
            mod = importlib.import_module(module_str)
 | 
			
		||||
            schema = getattr(mod, schema_name)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -53,16 +55,18 @@ class Command(CommandArguments):
 | 
			
		|||
        else:
 | 
			
		||||
            schema = graphene_settings.SCHEMA
 | 
			
		||||
 | 
			
		||||
        out = options.get('out') or graphene_settings.SCHEMA_OUTPUT
 | 
			
		||||
        out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
 | 
			
		||||
 | 
			
		||||
        if not schema:
 | 
			
		||||
            raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema')
 | 
			
		||||
            raise CommandError(
 | 
			
		||||
                "Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        indent = options.get('indent')
 | 
			
		||||
        schema_dict = {'data': schema.introspect()}
 | 
			
		||||
        indent = options.get("indent")
 | 
			
		||||
        schema_dict = {"data": schema.introspect()}
 | 
			
		||||
        self.save_file(out, schema_dict, indent)
 | 
			
		||||
 | 
			
		||||
        style = getattr(self, 'style', None)
 | 
			
		||||
        success = getattr(style, 'SUCCESS', lambda x: x)
 | 
			
		||||
        style = getattr(self, "style", None)
 | 
			
		||||
        success = getattr(style, "SUCCESS", lambda x: x)
 | 
			
		||||
 | 
			
		||||
        self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out))
 | 
			
		||||
        self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,20 +1,21 @@
 | 
			
		|||
 | 
			
		||||
class Registry(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._registry = {}
 | 
			
		||||
        self._field_registry = {}
 | 
			
		||||
 | 
			
		||||
    def register(self, cls):
 | 
			
		||||
        from .types import DjangoObjectType
 | 
			
		||||
 | 
			
		||||
        assert issubclass(
 | 
			
		||||
            cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
 | 
			
		||||
            cls.__name__)
 | 
			
		||||
        assert cls._meta.registry == self, 'Registry for a Model have to match.'
 | 
			
		||||
            cls, DjangoObjectType
 | 
			
		||||
        ), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
 | 
			
		||||
            cls.__name__
 | 
			
		||||
        )
 | 
			
		||||
        assert cls._meta.registry == self, "Registry for a Model have to match."
 | 
			
		||||
        # assert self.get_type_for_model(cls._meta.model) == cls, (
 | 
			
		||||
        #     'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
 | 
			
		||||
        # )
 | 
			
		||||
        if not getattr(cls._meta, 'skip_registry', False):
 | 
			
		||||
        if not getattr(cls._meta, "skip_registry", False):
 | 
			
		||||
            self._registry[cls._meta.model] = cls
 | 
			
		||||
 | 
			
		||||
    def get_type_for_model(self, model):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,20 +6,16 @@ import graphene
 | 
			
		|||
from graphene.types import Field, InputField
 | 
			
		||||
from graphene.types.mutation import MutationOptions
 | 
			
		||||
from graphene.relay.mutation import ClientIDMutation
 | 
			
		||||
from graphene.types.objecttype import (
 | 
			
		||||
    yank_fields_from_attrs
 | 
			
		||||
)
 | 
			
		||||
from graphene.types.objecttype import yank_fields_from_attrs
 | 
			
		||||
 | 
			
		||||
from .serializer_converter import (
 | 
			
		||||
    convert_serializer_field
 | 
			
		||||
)
 | 
			
		||||
from .serializer_converter import convert_serializer_field
 | 
			
		||||
from .types import ErrorType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SerializerMutationOptions(MutationOptions):
 | 
			
		||||
    lookup_field = None
 | 
			
		||||
    model_class = None
 | 
			
		||||
    model_operations = ['create', 'update']
 | 
			
		||||
    model_operations = ["create", "update"]
 | 
			
		||||
    serializer_class = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -28,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals
 | 
			
		|||
    for name, field in serializer.fields.items():
 | 
			
		||||
        is_not_in_only = only_fields and name not in only_fields
 | 
			
		||||
        is_excluded = (
 | 
			
		||||
            name in exclude_fields  # or
 | 
			
		||||
            name
 | 
			
		||||
            in exclude_fields  # or
 | 
			
		||||
            # name in already_created_fields
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -44,49 +41,54 @@ class SerializerMutation(ClientIDMutation):
 | 
			
		|||
        abstract = True
 | 
			
		||||
 | 
			
		||||
    errors = graphene.List(
 | 
			
		||||
        ErrorType,
 | 
			
		||||
        description='May contain more than one error for same field.'
 | 
			
		||||
        ErrorType, description="May contain more than one error for same field."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __init_subclass_with_meta__(cls, lookup_field=None,
 | 
			
		||||
                                    serializer_class=None, model_class=None,
 | 
			
		||||
                                    model_operations=['create', 'update'],
 | 
			
		||||
                                    only_fields=(), exclude_fields=(), **options):
 | 
			
		||||
    def __init_subclass_with_meta__(
 | 
			
		||||
        cls,
 | 
			
		||||
        lookup_field=None,
 | 
			
		||||
        serializer_class=None,
 | 
			
		||||
        model_class=None,
 | 
			
		||||
        model_operations=["create", "update"],
 | 
			
		||||
        only_fields=(),
 | 
			
		||||
        exclude_fields=(),
 | 
			
		||||
        **options
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        if not serializer_class:
 | 
			
		||||
            raise Exception('serializer_class is required for the SerializerMutation')
 | 
			
		||||
            raise Exception("serializer_class is required for the SerializerMutation")
 | 
			
		||||
 | 
			
		||||
        if 'update' not in model_operations and 'create' not in model_operations:
 | 
			
		||||
        if "update" not in model_operations and "create" not in model_operations:
 | 
			
		||||
            raise Exception('model_operations must contain "create" and/or "update"')
 | 
			
		||||
 | 
			
		||||
        serializer = serializer_class()
 | 
			
		||||
        if model_class is None:
 | 
			
		||||
            serializer_meta = getattr(serializer_class, 'Meta', None)
 | 
			
		||||
            serializer_meta = getattr(serializer_class, "Meta", None)
 | 
			
		||||
            if serializer_meta:
 | 
			
		||||
                model_class = getattr(serializer_meta, 'model', None)
 | 
			
		||||
                model_class = getattr(serializer_meta, "model", None)
 | 
			
		||||
 | 
			
		||||
        if lookup_field is None and model_class:
 | 
			
		||||
            lookup_field = model_class._meta.pk.name
 | 
			
		||||
 | 
			
		||||
        input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
 | 
			
		||||
        output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
 | 
			
		||||
        input_fields = fields_for_serializer(
 | 
			
		||||
            serializer, only_fields, exclude_fields, is_input=True
 | 
			
		||||
        )
 | 
			
		||||
        output_fields = fields_for_serializer(
 | 
			
		||||
            serializer, only_fields, exclude_fields, is_input=False
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        _meta = SerializerMutationOptions(cls)
 | 
			
		||||
        _meta.lookup_field = lookup_field
 | 
			
		||||
        _meta.model_operations = model_operations
 | 
			
		||||
        _meta.serializer_class = serializer_class
 | 
			
		||||
        _meta.model_class = model_class
 | 
			
		||||
        _meta.fields = yank_fields_from_attrs(
 | 
			
		||||
            output_fields,
 | 
			
		||||
            _as=Field,
 | 
			
		||||
        )
 | 
			
		||||
        _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
 | 
			
		||||
 | 
			
		||||
        input_fields = yank_fields_from_attrs(
 | 
			
		||||
            input_fields,
 | 
			
		||||
            _as=InputField,
 | 
			
		||||
        input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
 | 
			
		||||
        super(SerializerMutation, cls).__init_subclass_with_meta__(
 | 
			
		||||
            _meta=_meta, input_fields=input_fields, **options
 | 
			
		||||
        )
 | 
			
		||||
        super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_serializer_kwargs(cls, root, info, **input):
 | 
			
		||||
| 
						 | 
				
			
			@ -94,24 +96,26 @@ class SerializerMutation(ClientIDMutation):
 | 
			
		|||
        model_class = cls._meta.model_class
 | 
			
		||||
 | 
			
		||||
        if model_class:
 | 
			
		||||
            if 'update' in cls._meta.model_operations and lookup_field in input:
 | 
			
		||||
                instance = get_object_or_404(model_class, **{
 | 
			
		||||
                    lookup_field: input[lookup_field]})
 | 
			
		||||
            elif 'create' in cls._meta.model_operations:
 | 
			
		||||
            if "update" in cls._meta.model_operations and lookup_field in input:
 | 
			
		||||
                instance = get_object_or_404(
 | 
			
		||||
                    model_class, **{lookup_field: input[lookup_field]}
 | 
			
		||||
                )
 | 
			
		||||
            elif "create" in cls._meta.model_operations:
 | 
			
		||||
                instance = None
 | 
			
		||||
            else:
 | 
			
		||||
                raise Exception(
 | 
			
		||||
                    'Invalid update operation. Input parameter "{}" required.'.format(
 | 
			
		||||
                        lookup_field
 | 
			
		||||
                    ))
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            return {
 | 
			
		||||
                'instance': instance,
 | 
			
		||||
                'data': input,
 | 
			
		||||
                'context': {'request': info.context}
 | 
			
		||||
                "instance": instance,
 | 
			
		||||
                "data": input,
 | 
			
		||||
                "context": {"request": info.context},
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return {'data': input, 'context': {'request': info.context}}
 | 
			
		||||
        return {"data": input, "context": {"request": info.context}}
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def mutate_and_get_payload(cls, root, info, **input):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True):
 | 
			
		|||
    graphql_type = get_graphene_type_from_serializer_field(field)
 | 
			
		||||
 | 
			
		||||
    args = []
 | 
			
		||||
    kwargs = {
 | 
			
		||||
        'description': field.help_text,
 | 
			
		||||
        'required': is_input and field.required,
 | 
			
		||||
    }
 | 
			
		||||
    kwargs = {"description": field.help_text, "required": is_input and field.required}
 | 
			
		||||
 | 
			
		||||
    # if it is a tuple or a list it means that we are returning
 | 
			
		||||
    # the graphql type and the child type
 | 
			
		||||
    if isinstance(graphql_type, (list, tuple)):
 | 
			
		||||
        kwargs['of_type'] = graphql_type[1]
 | 
			
		||||
        kwargs["of_type"] = graphql_type[1]
 | 
			
		||||
        graphql_type = graphql_type[0]
 | 
			
		||||
 | 
			
		||||
    if isinstance(field, serializers.ModelSerializer):
 | 
			
		||||
| 
						 | 
				
			
			@ -49,9 +46,9 @@ def convert_serializer_field(field, is_input=True):
 | 
			
		|||
    elif isinstance(field, serializers.ListSerializer):
 | 
			
		||||
        field = field.child
 | 
			
		||||
        if is_input:
 | 
			
		||||
            kwargs['of_type'] = convert_serializer_to_input_type(field.__class__)
 | 
			
		||||
            kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
 | 
			
		||||
        else:
 | 
			
		||||
            del kwargs['of_type']
 | 
			
		||||
            del kwargs["of_type"]
 | 
			
		||||
            global_registry = get_global_registry()
 | 
			
		||||
            field_model = field.Meta.model
 | 
			
		||||
            args = [global_registry.get_type_for_model(field_model)]
 | 
			
		||||
| 
						 | 
				
			
			@ -68,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class):
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    return type(
 | 
			
		||||
        '{}Input'.format(serializer.__class__.__name__),
 | 
			
		||||
        "{}Input".format(serializer.__class__.__name__),
 | 
			
		||||
        (graphene.InputObjectType,),
 | 
			
		||||
        items
 | 
			
		||||
        items,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
 | 
			
		|||
    # Remove `source=` from the field declaration.
 | 
			
		||||
    # since we are reusing the same child in when testing the required attribute
 | 
			
		||||
 | 
			
		||||
    if 'child' in kwargs:
 | 
			
		||||
        kwargs['child'] = copy.deepcopy(kwargs['child'])
 | 
			
		||||
    if "child" in kwargs:
 | 
			
		||||
        kwargs["child"] = copy.deepcopy(kwargs["child"])
 | 
			
		||||
 | 
			
		||||
    field = rest_framework_field(**kwargs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -25,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def assert_conversion(rest_framework_field, graphene_field, **kwargs):
 | 
			
		||||
    graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs)
 | 
			
		||||
    graphene_type = _get_type(
 | 
			
		||||
        rest_framework_field, help_text="Custom Help Text", **kwargs
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(graphene_type, graphene_field)
 | 
			
		||||
 | 
			
		||||
    graphene_type_required = _get_type(
 | 
			
		||||
        rest_framework_field, help_text='Custom Help Text', required=True, **kwargs
 | 
			
		||||
        rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(graphene_type_required, graphene_field)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -39,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs):
 | 
			
		|||
def test_should_unknown_rest_framework_field_raise_exception():
 | 
			
		||||
    with raises(Exception) as excinfo:
 | 
			
		||||
        convert_serializer_field(None)
 | 
			
		||||
    assert 'Don\'t know how to convert the serializer field' in str(excinfo.value)
 | 
			
		||||
    assert "Don't know how to convert the serializer field" in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_char_convert_string():
 | 
			
		||||
| 
						 | 
				
			
			@ -67,11 +69,11 @@ def test_should_base_field_convert_string():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_should_regex_convert_string():
 | 
			
		||||
    assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+')
 | 
			
		||||
    assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_uuid_convert_string():
 | 
			
		||||
    if hasattr(serializers, 'UUIDField'):
 | 
			
		||||
    if hasattr(serializers, "UUIDField"):
 | 
			
		||||
        assert_conversion(serializers.UUIDField, graphene.String)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -79,7 +81,7 @@ def test_should_model_convert_field():
 | 
			
		|||
    class MyModelSerializer(serializers.ModelSerializer):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = None
 | 
			
		||||
            fields = '__all__'
 | 
			
		||||
            fields = "__all__"
 | 
			
		||||
 | 
			
		||||
    assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -109,7 +111,9 @@ def test_should_float_convert_float():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_should_decimal_convert_float():
 | 
			
		||||
    assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2)
 | 
			
		||||
    assert_conversion(
 | 
			
		||||
        serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_list_convert_to_list():
 | 
			
		||||
| 
						 | 
				
			
			@ -119,7 +123,7 @@ def test_should_list_convert_to_list():
 | 
			
		|||
    field_a = assert_conversion(
 | 
			
		||||
        serializers.ListField,
 | 
			
		||||
        graphene.List,
 | 
			
		||||
        child=serializers.IntegerField(min_value=0, max_value=100)
 | 
			
		||||
        child=serializers.IntegerField(min_value=0, max_value=100),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert field_a.of_type == graphene.Int
 | 
			
		||||
| 
						 | 
				
			
			@ -136,19 +140,23 @@ def test_should_list_serializer_convert_to_list():
 | 
			
		|||
    class ChildSerializer(serializers.ModelSerializer):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = FooModel
 | 
			
		||||
            fields = '__all__'
 | 
			
		||||
            fields = "__all__"
 | 
			
		||||
 | 
			
		||||
    class ParentSerializer(serializers.ModelSerializer):
 | 
			
		||||
        child = ChildSerializer(many=True)
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = FooModel
 | 
			
		||||
            fields = '__all__'
 | 
			
		||||
            fields = "__all__"
 | 
			
		||||
 | 
			
		||||
    converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=True)
 | 
			
		||||
    converted_type = convert_serializer_field(
 | 
			
		||||
        ParentSerializer().get_fields()["child"], is_input=True
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(converted_type, graphene.List)
 | 
			
		||||
 | 
			
		||||
    converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=False)
 | 
			
		||||
    converted_type = convert_serializer_field(
 | 
			
		||||
        ParentSerializer().get_fields()["child"], is_input=False
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(converted_type, graphene.List)
 | 
			
		||||
    assert converted_type.of_type is None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -166,7 +174,7 @@ def test_should_file_convert_string():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_should_filepath_convert_string():
 | 
			
		||||
    assert_conversion(serializers.FilePathField, graphene.String, path='/')
 | 
			
		||||
    assert_conversion(serializers.FilePathField, graphene.String, path="/")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_ip_convert_string():
 | 
			
		||||
| 
						 | 
				
			
			@ -182,6 +190,8 @@ def test_should_json_convert_jsonstring():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_should_multiplechoicefield_convert_to_list_of_string():
 | 
			
		||||
    field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3])
 | 
			
		||||
    field = assert_conversion(
 | 
			
		||||
        serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert field.of_type == graphene.String
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,30 +10,33 @@ from ...types import DjangoObjectType
 | 
			
		|||
from ..models import MyFakeModel
 | 
			
		||||
from ..mutation import SerializerMutation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mock_info():
 | 
			
		||||
  return ResolveInfo(
 | 
			
		||||
    None,
 | 
			
		||||
    None,
 | 
			
		||||
    None,
 | 
			
		||||
    None,
 | 
			
		||||
    schema=None,
 | 
			
		||||
    fragments=None,
 | 
			
		||||
    root_value=None,
 | 
			
		||||
    operation=None,
 | 
			
		||||
    variable_values=None,
 | 
			
		||||
    context=None
 | 
			
		||||
  )
 | 
			
		||||
    return ResolveInfo(
 | 
			
		||||
        None,
 | 
			
		||||
        None,
 | 
			
		||||
        None,
 | 
			
		||||
        None,
 | 
			
		||||
        schema=None,
 | 
			
		||||
        fragments=None,
 | 
			
		||||
        root_value=None,
 | 
			
		||||
        operation=None,
 | 
			
		||||
        variable_values=None,
 | 
			
		||||
        context=None,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MyModelSerializer(serializers.ModelSerializer):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = MyFakeModel
 | 
			
		||||
        fields = '__all__'
 | 
			
		||||
        fields = "__all__"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MyModelMutation(SerializerMutation):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        serializer_class = MyModelSerializer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MySerializer(serializers.Serializer):
 | 
			
		||||
    text = serializers.CharField()
 | 
			
		||||
    model = MyModelSerializer()
 | 
			
		||||
| 
						 | 
				
			
			@ -44,10 +47,11 @@ class MySerializer(serializers.Serializer):
 | 
			
		|||
 | 
			
		||||
def test_needs_serializer_class():
 | 
			
		||||
    with raises(Exception) as exc:
 | 
			
		||||
 | 
			
		||||
        class MyMutation(SerializerMutation):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    assert str(exc.value) == 'serializer_class is required for the SerializerMutation'
 | 
			
		||||
    assert str(exc.value) == "serializer_class is required for the SerializerMutation"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_has_fields():
 | 
			
		||||
| 
						 | 
				
			
			@ -55,9 +59,9 @@ def test_has_fields():
 | 
			
		|||
        class Meta:
 | 
			
		||||
            serializer_class = MySerializer
 | 
			
		||||
 | 
			
		||||
    assert 'text' in MyMutation._meta.fields
 | 
			
		||||
    assert 'model' in MyMutation._meta.fields
 | 
			
		||||
    assert 'errors' in MyMutation._meta.fields
 | 
			
		||||
    assert "text" in MyMutation._meta.fields
 | 
			
		||||
    assert "model" in MyMutation._meta.fields
 | 
			
		||||
    assert "errors" in MyMutation._meta.fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_has_input_fields():
 | 
			
		||||
| 
						 | 
				
			
			@ -65,25 +69,24 @@ def test_has_input_fields():
 | 
			
		|||
        class Meta:
 | 
			
		||||
            serializer_class = MySerializer
 | 
			
		||||
 | 
			
		||||
    assert 'text' in MyMutation.Input._meta.fields
 | 
			
		||||
    assert 'model' in MyMutation.Input._meta.fields
 | 
			
		||||
    assert "text" in MyMutation.Input._meta.fields
 | 
			
		||||
    assert "model" in MyMutation.Input._meta.fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_exclude_fields():
 | 
			
		||||
    class MyMutation(SerializerMutation):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            serializer_class = MyModelSerializer
 | 
			
		||||
            exclude_fields = ['created']
 | 
			
		||||
            exclude_fields = ["created"]
 | 
			
		||||
 | 
			
		||||
    assert 'cool_name' in MyMutation._meta.fields
 | 
			
		||||
    assert 'created' not in MyMutation._meta.fields
 | 
			
		||||
    assert 'errors' in MyMutation._meta.fields
 | 
			
		||||
    assert 'cool_name' in MyMutation.Input._meta.fields
 | 
			
		||||
    assert 'created' not in MyMutation.Input._meta.fields
 | 
			
		||||
    assert "cool_name" in MyMutation._meta.fields
 | 
			
		||||
    assert "created" not in MyMutation._meta.fields
 | 
			
		||||
    assert "errors" in MyMutation._meta.fields
 | 
			
		||||
    assert "cool_name" in MyMutation.Input._meta.fields
 | 
			
		||||
    assert "created" not in MyMutation.Input._meta.fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_nested_model():
 | 
			
		||||
 | 
			
		||||
    class MyFakeModelGrapheneType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = MyFakeModel
 | 
			
		||||
| 
						 | 
				
			
			@ -92,67 +95,64 @@ def test_nested_model():
 | 
			
		|||
        class Meta:
 | 
			
		||||
            serializer_class = MySerializer
 | 
			
		||||
 | 
			
		||||
    model_field = MyMutation._meta.fields['model']
 | 
			
		||||
    model_field = MyMutation._meta.fields["model"]
 | 
			
		||||
    assert isinstance(model_field, Field)
 | 
			
		||||
    assert model_field.type == MyFakeModelGrapheneType
 | 
			
		||||
 | 
			
		||||
    model_input = MyMutation.Input._meta.fields['model']
 | 
			
		||||
    model_input = MyMutation.Input._meta.fields["model"]
 | 
			
		||||
    model_input_type = model_input._type.of_type
 | 
			
		||||
    assert issubclass(model_input_type, InputObjectType)
 | 
			
		||||
    assert 'cool_name' in model_input_type._meta.fields
 | 
			
		||||
    assert 'created' in model_input_type._meta.fields
 | 
			
		||||
    assert "cool_name" in model_input_type._meta.fields
 | 
			
		||||
    assert "created" in model_input_type._meta.fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_mutate_and_get_payload_success():
 | 
			
		||||
 | 
			
		||||
    class MyMutation(SerializerMutation):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            serializer_class = MySerializer
 | 
			
		||||
 | 
			
		||||
    result = MyMutation.mutate_and_get_payload(None, mock_info(), **{
 | 
			
		||||
        'text': 'value',
 | 
			
		||||
        'model': {
 | 
			
		||||
            'cool_name': 'other_value'
 | 
			
		||||
        }
 | 
			
		||||
    })
 | 
			
		||||
    result = MyMutation.mutate_and_get_payload(
 | 
			
		||||
        None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
 | 
			
		||||
    )
 | 
			
		||||
    assert result.errors is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.django_db
 | 
			
		||||
def test_model_add_mutate_and_get_payload_success():
 | 
			
		||||
    result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{
 | 
			
		||||
        'cool_name': 'Narf',
 | 
			
		||||
    })
 | 
			
		||||
    result = MyModelMutation.mutate_and_get_payload(
 | 
			
		||||
        None, mock_info(), **{"cool_name": "Narf"}
 | 
			
		||||
    )
 | 
			
		||||
    assert result.errors is None
 | 
			
		||||
    assert result.cool_name == 'Narf'
 | 
			
		||||
    assert result.cool_name == "Narf"
 | 
			
		||||
    assert isinstance(result.created, datetime.datetime)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.django_db
 | 
			
		||||
def test_model_update_mutate_and_get_payload_success():
 | 
			
		||||
    instance = MyFakeModel.objects.create(cool_name="Narf")
 | 
			
		||||
    result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{
 | 
			
		||||
        'id': instance.id,
 | 
			
		||||
        'cool_name': 'New Narf',
 | 
			
		||||
    })
 | 
			
		||||
    result = MyModelMutation.mutate_and_get_payload(
 | 
			
		||||
        None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
 | 
			
		||||
    )
 | 
			
		||||
    assert result.errors is None
 | 
			
		||||
    assert result.cool_name == 'New Narf'
 | 
			
		||||
    assert result.cool_name == "New Narf"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.django_db
 | 
			
		||||
def test_model_invalid_update_mutate_and_get_payload_success():
 | 
			
		||||
    class InvalidModelMutation(SerializerMutation):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            serializer_class = MyModelSerializer
 | 
			
		||||
            model_operations = ['update']
 | 
			
		||||
            model_operations = ["update"]
 | 
			
		||||
 | 
			
		||||
    with raises(Exception) as exc:
 | 
			
		||||
        result = InvalidModelMutation.mutate_and_get_payload(None, mock_info(), **{
 | 
			
		||||
            'cool_name': 'Narf',
 | 
			
		||||
        })
 | 
			
		||||
        result = InvalidModelMutation.mutate_and_get_payload(
 | 
			
		||||
            None, mock_info(), **{"cool_name": "Narf"}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    assert '"id" required' in str(exc.value)
 | 
			
		||||
 | 
			
		||||
def test_mutate_and_get_payload_error():
 | 
			
		||||
 | 
			
		||||
def test_mutate_and_get_payload_error():
 | 
			
		||||
    class MyMutation(SerializerMutation):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            serializer_class = MySerializer
 | 
			
		||||
| 
						 | 
				
			
			@ -161,16 +161,19 @@ def test_mutate_and_get_payload_error():
 | 
			
		|||
    result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
 | 
			
		||||
    assert len(result.errors) > 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_model_mutate_and_get_payload_error():
 | 
			
		||||
    # missing required fields
 | 
			
		||||
    result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
 | 
			
		||||
    assert len(result.errors) > 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_invalid_serializer_operations():
 | 
			
		||||
    with raises(Exception) as exc:
 | 
			
		||||
 | 
			
		||||
        class MyModelMutation(SerializerMutation):
 | 
			
		||||
            class Meta:
 | 
			
		||||
                serializer_class = MyModelSerializer
 | 
			
		||||
                model_operations = ['Add']
 | 
			
		||||
                model_operations = ["Add"]
 | 
			
		||||
 | 
			
		||||
    assert 'model_operations' in str(exc.value)
 | 
			
		||||
    assert "model_operations" in str(exc.value)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,27 +26,22 @@ except ImportError:
 | 
			
		|||
# Copied shamelessly from Django REST Framework
 | 
			
		||||
 | 
			
		||||
DEFAULTS = {
 | 
			
		||||
    'SCHEMA': None,
 | 
			
		||||
    'SCHEMA_OUTPUT': 'schema.json',
 | 
			
		||||
    'SCHEMA_INDENT': None,
 | 
			
		||||
    'MIDDLEWARE': (),
 | 
			
		||||
    "SCHEMA": None,
 | 
			
		||||
    "SCHEMA_OUTPUT": "schema.json",
 | 
			
		||||
    "SCHEMA_INDENT": None,
 | 
			
		||||
    "MIDDLEWARE": (),
 | 
			
		||||
    # Set to True if the connection fields must have
 | 
			
		||||
    # either the first or last argument
 | 
			
		||||
    'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False,
 | 
			
		||||
    "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
 | 
			
		||||
    # Max items returned in ConnectionFields / FilterConnectionFields
 | 
			
		||||
    'RELAY_CONNECTION_MAX_LIMIT': 100,
 | 
			
		||||
    "RELAY_CONNECTION_MAX_LIMIT": 100,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
if settings.DEBUG:
 | 
			
		||||
    DEFAULTS['MIDDLEWARE'] += (
 | 
			
		||||
        'graphene_django.debug.DjangoDebugMiddleware',
 | 
			
		||||
    )
 | 
			
		||||
    DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
 | 
			
		||||
 | 
			
		||||
# List of settings that may be in string import notation.
 | 
			
		||||
IMPORT_STRINGS = (
 | 
			
		||||
    'MIDDLEWARE',
 | 
			
		||||
    'SCHEMA',
 | 
			
		||||
)
 | 
			
		||||
IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def perform_import(val, setting_name):
 | 
			
		||||
| 
						 | 
				
			
			@ -69,12 +64,17 @@ def import_from_string(val, setting_name):
 | 
			
		|||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        # Nod to tastypie's use of importlib.
 | 
			
		||||
        parts = val.split('.')
 | 
			
		||||
        module_path, class_name = '.'.join(parts[:-1]), parts[-1]
 | 
			
		||||
        parts = val.split(".")
 | 
			
		||||
        module_path, class_name = ".".join(parts[:-1]), parts[-1]
 | 
			
		||||
        module = importlib.import_module(module_path)
 | 
			
		||||
        return getattr(module, class_name)
 | 
			
		||||
    except (ImportError, AttributeError) as e:
 | 
			
		||||
        msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
 | 
			
		||||
        msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
 | 
			
		||||
            val,
 | 
			
		||||
            setting_name,
 | 
			
		||||
            e.__class__.__name__,
 | 
			
		||||
            e,
 | 
			
		||||
        )
 | 
			
		||||
        raise ImportError(msg)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -96,8 +96,8 @@ class GrapheneSettings(object):
 | 
			
		|||
 | 
			
		||||
    @property
 | 
			
		||||
    def user_settings(self):
 | 
			
		||||
        if not hasattr(self, '_user_settings'):
 | 
			
		||||
            self._user_settings = getattr(settings, 'GRAPHENE', {})
 | 
			
		||||
        if not hasattr(self, "_user_settings"):
 | 
			
		||||
            self._user_settings = getattr(settings, "GRAPHENE", {})
 | 
			
		||||
        return self._user_settings
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, attr):
 | 
			
		||||
| 
						 | 
				
			
			@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
 | 
			
		|||
 | 
			
		||||
def reload_graphene_settings(*args, **kwargs):
 | 
			
		||||
    global graphene_settings
 | 
			
		||||
    setting, value = kwargs['setting'], kwargs['value']
 | 
			
		||||
    if setting == 'GRAPHENE':
 | 
			
		||||
    setting, value = kwargs["setting"], kwargs["value"]
 | 
			
		||||
    if setting == "GRAPHENE":
 | 
			
		||||
        graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,10 +3,7 @@ from __future__ import absolute_import
 | 
			
		|||
from django.db import models
 | 
			
		||||
from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
 | 
			
		||||
CHOICES = (
 | 
			
		||||
    (1, 'this'),
 | 
			
		||||
    (2, _('that'))
 | 
			
		||||
)
 | 
			
		||||
CHOICES = ((1, "this"), (2, _("that")))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Pet(models.Model):
 | 
			
		||||
| 
						 | 
				
			
			@ -15,38 +12,43 @@ class Pet(models.Model):
 | 
			
		|||
 | 
			
		||||
class FilmDetails(models.Model):
 | 
			
		||||
    location = models.CharField(max_length=30)
 | 
			
		||||
    film = models.OneToOneField('Film', on_delete=models.CASCADE, related_name='details')
 | 
			
		||||
    film = models.OneToOneField(
 | 
			
		||||
        "Film", on_delete=models.CASCADE, related_name="details"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Film(models.Model):
 | 
			
		||||
    genre = models.CharField(max_length=2, help_text='Genre', choices=[
 | 
			
		||||
        ('do', 'Documentary'),
 | 
			
		||||
        ('ot', 'Other')
 | 
			
		||||
    ], default='ot')
 | 
			
		||||
    reporters = models.ManyToManyField('Reporter',
 | 
			
		||||
                                       related_name='films')
 | 
			
		||||
    genre = models.CharField(
 | 
			
		||||
        max_length=2,
 | 
			
		||||
        help_text="Genre",
 | 
			
		||||
        choices=[("do", "Documentary"), ("ot", "Other")],
 | 
			
		||||
        default="ot",
 | 
			
		||||
    )
 | 
			
		||||
    reporters = models.ManyToManyField("Reporter", related_name="films")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DoeReporterManager(models.Manager):
 | 
			
		||||
    def get_queryset(self):
 | 
			
		||||
        return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Reporter(models.Model):
 | 
			
		||||
    first_name = models.CharField(max_length=30)
 | 
			
		||||
    last_name = models.CharField(max_length=30)
 | 
			
		||||
    email = models.EmailField()
 | 
			
		||||
    pets = models.ManyToManyField('self')
 | 
			
		||||
    pets = models.ManyToManyField("self")
 | 
			
		||||
    a_choice = models.CharField(max_length=30, choices=CHOICES)
 | 
			
		||||
    objects = models.Manager()
 | 
			
		||||
    doe_objects = DoeReporterManager()
 | 
			
		||||
 | 
			
		||||
    reporter_type = models.IntegerField(
 | 
			
		||||
        'Reporter Type',
 | 
			
		||||
        "Reporter Type",
 | 
			
		||||
        null=True,
 | 
			
		||||
        blank=True,
 | 
			
		||||
        choices=[(1, u'Regular'), (2, u'CNN Reporter')]
 | 
			
		||||
        choices=[(1, u"Regular"), (2, u"CNN Reporter")],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __str__(self):              # __unicode__ on Python 2
 | 
			
		||||
    def __str__(self):  # __unicode__ on Python 2
 | 
			
		||||
        return "%s %s" % (self.first_name, self.last_name)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			@ -61,11 +63,13 @@ class Reporter(models.Model):
 | 
			
		|||
        if self.reporter_type == 2:  # quick and dirty way without enums
 | 
			
		||||
            self.__class__ = CNNReporter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CNNReporter(Reporter):
 | 
			
		||||
    """
 | 
			
		||||
    This class is a proxy model for Reporter, used for testing
 | 
			
		||||
    proxy model support
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        proxy = True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -74,17 +78,27 @@ class Article(models.Model):
 | 
			
		|||
    headline = models.CharField(max_length=100)
 | 
			
		||||
    pub_date = models.DateField()
 | 
			
		||||
    pub_date_time = models.DateTimeField()
 | 
			
		||||
    reporter = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='articles')
 | 
			
		||||
    editor = models.ForeignKey(Reporter, on_delete=models.CASCADE, related_name='edited_articles_+')
 | 
			
		||||
    lang = models.CharField(max_length=2, help_text='Language', choices=[
 | 
			
		||||
        ('es', 'Spanish'),
 | 
			
		||||
        ('en', 'English')
 | 
			
		||||
    ], default='es')
 | 
			
		||||
    importance = models.IntegerField('Importance', null=True, blank=True,
 | 
			
		||||
                                     choices=[(1, u'Very important'), (2, u'Not as important')])
 | 
			
		||||
    reporter = models.ForeignKey(
 | 
			
		||||
        Reporter, on_delete=models.CASCADE, related_name="articles"
 | 
			
		||||
    )
 | 
			
		||||
    editor = models.ForeignKey(
 | 
			
		||||
        Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
 | 
			
		||||
    )
 | 
			
		||||
    lang = models.CharField(
 | 
			
		||||
        max_length=2,
 | 
			
		||||
        help_text="Language",
 | 
			
		||||
        choices=[("es", "Spanish"), ("en", "English")],
 | 
			
		||||
        default="es",
 | 
			
		||||
    )
 | 
			
		||||
    importance = models.IntegerField(
 | 
			
		||||
        "Importance",
 | 
			
		||||
        null=True,
 | 
			
		||||
        blank=True,
 | 
			
		||||
        choices=[(1, u"Very important"), (2, u"Not as important")],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __str__(self):              # __unicode__ on Python 2
 | 
			
		||||
    def __str__(self):  # __unicode__ on Python 2
 | 
			
		||||
        return self.headline
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        ordering = ('headline',)
 | 
			
		||||
        ordering = ("headline",)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,10 +6,9 @@ from .models import Article, Reporter
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class Character(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Reporter
 | 
			
		||||
        interfaces = (relay.Node, )
 | 
			
		||||
        interfaces = (relay.Node,)
 | 
			
		||||
 | 
			
		||||
    def get_node(self, info, id):
 | 
			
		||||
        pass
 | 
			
		||||
| 
						 | 
				
			
			@ -20,7 +19,7 @@ class Human(DjangoObjectType):
 | 
			
		|||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Article
 | 
			
		||||
        interfaces = (relay.Node, )
 | 
			
		||||
        interfaces = (relay.Node,)
 | 
			
		||||
 | 
			
		||||
    def resolve_raises(self, info):
 | 
			
		||||
        raise Exception("This field should raise exception")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,10 +12,10 @@ class QueryRoot(ObjectType):
 | 
			
		|||
        raise Exception("Throws!")
 | 
			
		||||
 | 
			
		||||
    def resolve_request(self, info):
 | 
			
		||||
        return info.context.GET.get('q')
 | 
			
		||||
        return info.context.GET.get("q")
 | 
			
		||||
 | 
			
		||||
    def resolve_test(self, info, who=None):
 | 
			
		||||
        return 'Hello %s' % (who or 'World')
 | 
			
		||||
        return "Hello %s" % (who or "World")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MutationRoot(ObjectType):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,8 +3,8 @@ from mock import patch
 | 
			
		|||
from six import StringIO
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@patch('graphene_django.management.commands.graphql_schema.Command.save_file')
 | 
			
		||||
@patch("graphene_django.management.commands.graphql_schema.Command.save_file")
 | 
			
		||||
def test_generate_file_on_call_graphql_schema(savefile_mock, settings):
 | 
			
		||||
    out = StringIO()
 | 
			
		||||
    management.call_command('graphql_schema', schema='', stdout=out)
 | 
			
		||||
    management.call_command("graphql_schema", schema="", stdout=out)
 | 
			
		||||
    assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,11 +19,11 @@ from .models import Article, Film, FilmDetails, Reporter
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def assert_conversion(django_field, graphene_field, *args, **kwargs):
 | 
			
		||||
    field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs)
 | 
			
		||||
    field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs)
 | 
			
		||||
    graphene_type = convert_django_field(field)
 | 
			
		||||
    assert isinstance(graphene_type, graphene_field)
 | 
			
		||||
    field = graphene_type.Field()
 | 
			
		||||
    assert field.description == 'Custom Help Text'
 | 
			
		||||
    assert field.description == "Custom Help Text"
 | 
			
		||||
    nonnull_field = django_field(null=False, *args, **kwargs)
 | 
			
		||||
    if not nonnull_field.null:
 | 
			
		||||
        nonnull_graphene_type = convert_django_field(nonnull_field)
 | 
			
		||||
| 
						 | 
				
			
			@ -36,7 +36,8 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
 | 
			
		|||
def test_should_unknown_django_field_raise_exception():
 | 
			
		||||
    with raises(Exception) as excinfo:
 | 
			
		||||
        convert_django_field(None)
 | 
			
		||||
    assert 'Don\'t know how to convert the Django field' in str(excinfo.value)
 | 
			
		||||
    assert "Don't know how to convert the Django field" in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_date_time_convert_string():
 | 
			
		||||
    assert_conversion(models.DateTimeField, DateTime)
 | 
			
		||||
| 
						 | 
				
			
			@ -128,70 +129,69 @@ def test_should_nullboolean_convert_boolean():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_field_with_choices_convert_enum():
 | 
			
		||||
    field = models.CharField(help_text='Language', choices=(
 | 
			
		||||
        ('es', 'Spanish'),
 | 
			
		||||
        ('en', 'English')
 | 
			
		||||
    ))
 | 
			
		||||
    field = models.CharField(
 | 
			
		||||
        help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class TranslatedModel(models.Model):
 | 
			
		||||
        language = field
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            app_label = 'test'
 | 
			
		||||
            app_label = "test"
 | 
			
		||||
 | 
			
		||||
    graphene_type = convert_django_field_with_choices(field)
 | 
			
		||||
    assert isinstance(graphene_type, graphene.Enum)
 | 
			
		||||
    assert graphene_type._meta.name == 'TranslatedModelLanguage'
 | 
			
		||||
    assert graphene_type._meta.enum.__members__['ES'].value == 'es'
 | 
			
		||||
    assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish'
 | 
			
		||||
    assert graphene_type._meta.enum.__members__['EN'].value == 'en'
 | 
			
		||||
    assert graphene_type._meta.enum.__members__['EN'].description == 'English'
 | 
			
		||||
    assert graphene_type._meta.name == "TranslatedModelLanguage"
 | 
			
		||||
    assert graphene_type._meta.enum.__members__["ES"].value == "es"
 | 
			
		||||
    assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
 | 
			
		||||
    assert graphene_type._meta.enum.__members__["EN"].value == "en"
 | 
			
		||||
    assert graphene_type._meta.enum.__members__["EN"].description == "English"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_with_grouped_choices():
 | 
			
		||||
    field = models.CharField(help_text='Language', choices=(
 | 
			
		||||
        ('Europe', (
 | 
			
		||||
            ('es', 'Spanish'),
 | 
			
		||||
            ('en', 'English'),
 | 
			
		||||
        )),
 | 
			
		||||
    ))
 | 
			
		||||
    field = models.CharField(
 | 
			
		||||
        help_text="Language",
 | 
			
		||||
        choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class GroupedChoicesModel(models.Model):
 | 
			
		||||
        language = field
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            app_label = 'test'
 | 
			
		||||
            app_label = "test"
 | 
			
		||||
 | 
			
		||||
    convert_django_field_with_choices(field)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_with_choices_gettext():
 | 
			
		||||
    field = models.CharField(help_text='Language', choices=(
 | 
			
		||||
        ('es', _('Spanish')),
 | 
			
		||||
        ('en', _('English'))
 | 
			
		||||
    ))
 | 
			
		||||
    field = models.CharField(
 | 
			
		||||
        help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class TranslatedChoicesModel(models.Model):
 | 
			
		||||
        language = field
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            app_label = 'test'
 | 
			
		||||
            app_label = "test"
 | 
			
		||||
 | 
			
		||||
    convert_django_field_with_choices(field)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_with_choices_collision():
 | 
			
		||||
    field = models.CharField(help_text='Timezone', choices=(
 | 
			
		||||
        ('Etc/GMT+1+2', 'Fake choice to produce double collision'),
 | 
			
		||||
        ('Etc/GMT+1', 'Greenwich Mean Time +1'),
 | 
			
		||||
        ('Etc/GMT-1', 'Greenwich Mean Time -1'),
 | 
			
		||||
    ))
 | 
			
		||||
    field = models.CharField(
 | 
			
		||||
        help_text="Timezone",
 | 
			
		||||
        choices=(
 | 
			
		||||
            ("Etc/GMT+1+2", "Fake choice to produce double collision"),
 | 
			
		||||
            ("Etc/GMT+1", "Greenwich Mean Time +1"),
 | 
			
		||||
            ("Etc/GMT-1", "Greenwich Mean Time -1"),
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class CollisionChoicesModel(models.Model):
 | 
			
		||||
        timezone = field
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            app_label = 'test'
 | 
			
		||||
            app_label = "test"
 | 
			
		||||
 | 
			
		||||
    convert_django_field_with_choices(field)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -208,11 +208,12 @@ def test_should_manytomany_convert_connectionorlist():
 | 
			
		|||
 | 
			
		||||
def test_should_manytomany_convert_connectionorlist_list():
 | 
			
		||||
    class A(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
 | 
			
		||||
    graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry)
 | 
			
		||||
    graphene_field = convert_django_field(
 | 
			
		||||
        Reporter._meta.local_many_to_many[0], A._meta.registry
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(graphene_field, graphene.Dynamic)
 | 
			
		||||
    dynamic_field = graphene_field.get_type()
 | 
			
		||||
    assert isinstance(dynamic_field, graphene.Field)
 | 
			
		||||
| 
						 | 
				
			
			@ -222,12 +223,13 @@ def test_should_manytomany_convert_connectionorlist_list():
 | 
			
		|||
 | 
			
		||||
def test_should_manytomany_convert_connectionorlist_connection():
 | 
			
		||||
    class A(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node, )
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry)
 | 
			
		||||
    graphene_field = convert_django_field(
 | 
			
		||||
        Reporter._meta.local_many_to_many[0], A._meta.registry
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(graphene_field, graphene.Dynamic)
 | 
			
		||||
    dynamic_field = graphene_field.get_type()
 | 
			
		||||
    assert isinstance(dynamic_field, ConnectionField)
 | 
			
		||||
| 
						 | 
				
			
			@ -236,11 +238,11 @@ def test_should_manytomany_convert_connectionorlist_connection():
 | 
			
		|||
 | 
			
		||||
def test_should_manytoone_convert_connectionorlist():
 | 
			
		||||
    # Django 1.9 uses 'rel', <1.9 uses 'related
 | 
			
		||||
    related = getattr(Reporter.articles, 'rel', None) or \
 | 
			
		||||
        getattr(Reporter.articles, 'related')
 | 
			
		||||
    related = getattr(Reporter.articles, "rel", None) or getattr(
 | 
			
		||||
        Reporter.articles, "related"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class A(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Article
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -254,11 +256,9 @@ def test_should_manytoone_convert_connectionorlist():
 | 
			
		|||
 | 
			
		||||
def test_should_onetoone_reverse_convert_model():
 | 
			
		||||
    # Django 1.9 uses 'rel', <1.9 uses 'related
 | 
			
		||||
    related = getattr(Film.details, 'rel', None) or \
 | 
			
		||||
        getattr(Film.details, 'related')
 | 
			
		||||
    related = getattr(Film.details, "rel", None) or getattr(Film.details, "related")
 | 
			
		||||
 | 
			
		||||
    class A(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = FilmDetails
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -269,41 +269,41 @@ def test_should_onetoone_reverse_convert_model():
 | 
			
		|||
    assert dynamic_field.type == A
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(ArrayField is MissingType,
 | 
			
		||||
                    reason="ArrayField should exist")
 | 
			
		||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
 | 
			
		||||
def test_should_postgres_array_convert_list():
 | 
			
		||||
    field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100))
 | 
			
		||||
    field = assert_conversion(
 | 
			
		||||
        ArrayField, graphene.List, models.CharField(max_length=100)
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(field.type, graphene.NonNull)
 | 
			
		||||
    assert isinstance(field.type.of_type, graphene.List)
 | 
			
		||||
    assert field.type.of_type.of_type == graphene.String
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(ArrayField is MissingType,
 | 
			
		||||
                    reason="ArrayField should exist")
 | 
			
		||||
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
 | 
			
		||||
def test_should_postgres_array_multiple_convert_list():
 | 
			
		||||
    field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)))
 | 
			
		||||
    field = assert_conversion(
 | 
			
		||||
        ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(field.type, graphene.NonNull)
 | 
			
		||||
    assert isinstance(field.type.of_type, graphene.List)
 | 
			
		||||
    assert isinstance(field.type.of_type.of_type, graphene.List)
 | 
			
		||||
    assert field.type.of_type.of_type.of_type == graphene.String
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(HStoreField is MissingType,
 | 
			
		||||
                    reason="HStoreField should exist")
 | 
			
		||||
@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
 | 
			
		||||
def test_should_postgres_hstore_convert_string():
 | 
			
		||||
    assert_conversion(HStoreField, JSONString)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(JSONField is MissingType,
 | 
			
		||||
                    reason="JSONField should exist")
 | 
			
		||||
@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
 | 
			
		||||
def test_should_postgres_json_convert_string():
 | 
			
		||||
    assert_conversion(JSONField, JSONString)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(RangeField is MissingType,
 | 
			
		||||
                    reason="RangeField should exist")
 | 
			
		||||
@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
 | 
			
		||||
def test_should_postgres_range_convert_list():
 | 
			
		||||
    from django.contrib.postgres.fields import IntegerRangeField
 | 
			
		||||
 | 
			
		||||
    field = assert_conversion(IntegerRangeField, graphene.List)
 | 
			
		||||
    assert isinstance(field.type, graphene.NonNull)
 | 
			
		||||
    assert isinstance(field.type.of_type, graphene.List)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
from django.core.exceptions import ValidationError
 | 
			
		||||
from py.test import raises
 | 
			
		||||
 | 
			
		||||
from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
 | 
			
		||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
 | 
			
		||||
| 
						 | 
				
			
			@ -9,24 +9,24 @@ from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
 | 
			
		|||
 | 
			
		||||
def test_global_id_valid():
 | 
			
		||||
    field = GlobalIDFormField()
 | 
			
		||||
    field.clean('TXlUeXBlOmFiYw==')
 | 
			
		||||
    field.clean("TXlUeXBlOmFiYw==")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_invalid():
 | 
			
		||||
    field = GlobalIDFormField()
 | 
			
		||||
    with raises(ValidationError):
 | 
			
		||||
        field.clean('badvalue')
 | 
			
		||||
        field.clean("badvalue")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_multiple_valid():
 | 
			
		||||
    field = GlobalIDMultipleChoiceField()
 | 
			
		||||
    field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw=='])
 | 
			
		||||
    field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_multiple_invalid():
 | 
			
		||||
    field = GlobalIDMultipleChoiceField()
 | 
			
		||||
    with raises(ValidationError):
 | 
			
		||||
        field.clean(['badvalue', 'another bad avue'])
 | 
			
		||||
        field.clean(["badvalue", "another bad avue"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_global_id_none():
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| 
						 | 
				
			
			@ -7,48 +7,47 @@ from .models import Reporter
 | 
			
		|||
 | 
			
		||||
def test_should_raise_if_no_model():
 | 
			
		||||
    with raises(Exception) as excinfo:
 | 
			
		||||
 | 
			
		||||
        class Character1(DjangoObjectType):
 | 
			
		||||
            pass
 | 
			
		||||
    assert 'valid Django Model' in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
    assert "valid Django Model" in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_raise_if_model_is_invalid():
 | 
			
		||||
    with raises(Exception) as excinfo:
 | 
			
		||||
        class Character2(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Character2(DjangoObjectType):
 | 
			
		||||
            class Meta:
 | 
			
		||||
                model = 1
 | 
			
		||||
    assert 'valid Django Model' in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
    assert "valid Django Model" in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_map_fields_correctly():
 | 
			
		||||
    class ReporterType2(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            registry = Registry()
 | 
			
		||||
 | 
			
		||||
    fields = list(ReporterType2._meta.fields.keys())
 | 
			
		||||
    assert fields[:-2] == [
 | 
			
		||||
        'id',
 | 
			
		||||
        'first_name',
 | 
			
		||||
        'last_name',
 | 
			
		||||
        'email',
 | 
			
		||||
        'pets',
 | 
			
		||||
        'a_choice',
 | 
			
		||||
        'reporter_type'
 | 
			
		||||
        "id",
 | 
			
		||||
        "first_name",
 | 
			
		||||
        "last_name",
 | 
			
		||||
        "email",
 | 
			
		||||
        "pets",
 | 
			
		||||
        "a_choice",
 | 
			
		||||
        "reporter_type",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    assert sorted(fields[-2:]) == [
 | 
			
		||||
        'articles',
 | 
			
		||||
        'films',
 | 
			
		||||
    ]
 | 
			
		||||
    assert sorted(fields[-2:]) == ["articles", "films"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_map_only_few_fields():
 | 
			
		||||
    class Reporter2(DjangoObjectType):
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            only_fields = ('id', 'email')
 | 
			
		||||
            only_fields = ("id", "email")
 | 
			
		||||
 | 
			
		||||
    assert list(Reporter2._meta.fields.keys()) == ['id', 'email']
 | 
			
		||||
    assert list(Reporter2._meta.fields.keys()) == ["id", "email"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,27 +12,30 @@ registry.reset_global_registry()
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class Reporter(DjangoObjectType):
 | 
			
		||||
    '''Reporter description'''
 | 
			
		||||
    """Reporter description"""
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = ReporterModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ArticleConnection(Connection):
 | 
			
		||||
    '''Article Connection'''
 | 
			
		||||
    """Article Connection"""
 | 
			
		||||
 | 
			
		||||
    test = String()
 | 
			
		||||
 | 
			
		||||
    def resolve_test():
 | 
			
		||||
        return 'test'
 | 
			
		||||
        return "test"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        abstract = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Article(DjangoObjectType):
 | 
			
		||||
    '''Article description'''
 | 
			
		||||
    """Article description"""
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = ArticleModel
 | 
			
		||||
        interfaces = (Node, )
 | 
			
		||||
        interfaces = (Node,)
 | 
			
		||||
        connection_class = ArticleConnection
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -48,7 +51,7 @@ def test_django_interface():
 | 
			
		|||
    assert issubclass(Node, Node)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
 | 
			
		||||
@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
 | 
			
		||||
def test_django_get_node(get):
 | 
			
		||||
    article = Article.get_node(None, 1)
 | 
			
		||||
    get.assert_called_with(pk=1)
 | 
			
		||||
| 
						 | 
				
			
			@ -58,18 +61,35 @@ def test_django_get_node(get):
 | 
			
		|||
def test_django_objecttype_map_correct_fields():
 | 
			
		||||
    fields = Reporter._meta.fields
 | 
			
		||||
    fields = list(fields.keys())
 | 
			
		||||
    assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice', 'reporter_type']
 | 
			
		||||
    assert sorted(fields[-2:]) == ['articles', 'films']
 | 
			
		||||
    assert fields[:-2] == [
 | 
			
		||||
        "id",
 | 
			
		||||
        "first_name",
 | 
			
		||||
        "last_name",
 | 
			
		||||
        "email",
 | 
			
		||||
        "pets",
 | 
			
		||||
        "a_choice",
 | 
			
		||||
        "reporter_type",
 | 
			
		||||
    ]
 | 
			
		||||
    assert sorted(fields[-2:]) == ["articles", "films"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_django_objecttype_with_node_have_correct_fields():
 | 
			
		||||
    fields = Article._meta.fields
 | 
			
		||||
    assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'pub_date_time', 'reporter', 'editor', 'lang', 'importance']
 | 
			
		||||
    assert list(fields.keys()) == [
 | 
			
		||||
        "id",
 | 
			
		||||
        "headline",
 | 
			
		||||
        "pub_date",
 | 
			
		||||
        "pub_date_time",
 | 
			
		||||
        "reporter",
 | 
			
		||||
        "editor",
 | 
			
		||||
        "lang",
 | 
			
		||||
        "importance",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_django_objecttype_with_custom_meta():
 | 
			
		||||
    class ArticleTypeOptions(DjangoObjectTypeOptions):
 | 
			
		||||
        '''Article Type Options'''
 | 
			
		||||
        """Article Type Options"""
 | 
			
		||||
 | 
			
		||||
    class ArticleType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
| 
						 | 
				
			
			@ -77,7 +97,7 @@ def test_django_objecttype_with_custom_meta():
 | 
			
		|||
 | 
			
		||||
        @classmethod
 | 
			
		||||
        def __init_subclass_with_meta__(cls, **options):
 | 
			
		||||
            options.setdefault('_meta', ArticleTypeOptions(cls))
 | 
			
		||||
            options.setdefault("_meta", ArticleTypeOptions(cls))
 | 
			
		||||
            super(ArticleType, cls).__init_subclass_with_meta__(**options)
 | 
			
		||||
 | 
			
		||||
    class Article(ArticleType):
 | 
			
		||||
| 
						 | 
				
			
			@ -180,6 +200,7 @@ def with_local_registry(func):
 | 
			
		|||
        else:
 | 
			
		||||
            registry.registry = old
 | 
			
		||||
            return retval
 | 
			
		||||
 | 
			
		||||
    return inner
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -188,11 +209,10 @@ def test_django_objecttype_only_fields():
 | 
			
		|||
    class Reporter(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = ReporterModel
 | 
			
		||||
            only_fields = ('id', 'email', 'films')
 | 
			
		||||
 | 
			
		||||
            only_fields = ("id", "email", "films")
 | 
			
		||||
 | 
			
		||||
    fields = list(Reporter._meta.fields.keys())
 | 
			
		||||
    assert fields == ['id', 'email', 'films']
 | 
			
		||||
    assert fields == ["id", "email", "films"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@with_local_registry
 | 
			
		||||
| 
						 | 
				
			
			@ -200,8 +220,7 @@ def test_django_objecttype_exclude_fields():
 | 
			
		|||
    class Reporter(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = ReporterModel
 | 
			
		||||
            exclude_fields = ('email')
 | 
			
		||||
 | 
			
		||||
            exclude_fields = "email"
 | 
			
		||||
 | 
			
		||||
    fields = list(Reporter._meta.fields.keys())
 | 
			
		||||
    assert 'email' not in fields
 | 
			
		||||
    assert "email" not in fields
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,15 +8,15 @@ except ImportError:
 | 
			
		|||
    from urllib.parse import urlencode
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def url_string(string='/graphql', **url_params):
 | 
			
		||||
def url_string(string="/graphql", **url_params):
 | 
			
		||||
    if url_params:
 | 
			
		||||
        string += '?' + urlencode(url_params)
 | 
			
		||||
        string += "?" + urlencode(url_params)
 | 
			
		||||
 | 
			
		||||
    return string
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def batch_url_string(**url_params):
 | 
			
		||||
    return url_string('/graphql/batch', **url_params)
 | 
			
		||||
    return url_string("/graphql/batch", **url_params)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def response_json(response):
 | 
			
		||||
| 
						 | 
				
			
			@ -28,441 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs])
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_graphiql_is_enabled(client):
 | 
			
		||||
    response = client.get(url_string(), HTTP_ACCEPT='text/html')
 | 
			
		||||
    response = client.get(url_string(), HTTP_ACCEPT="text/html")
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response['Content-Type'].split(';')[0] == 'text/html'
 | 
			
		||||
    assert response["Content-Type"].split(";")[0] == "text/html"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_qfactor_graphiql(client):
 | 
			
		||||
    response = client.get(url_string(query='{test}'), HTTP_ACCEPT='application/json;q=0.8, text/html;q=0.9')
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(query="{test}"),
 | 
			
		||||
        HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
 | 
			
		||||
    )
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response['Content-Type'].split(';')[0] == 'text/html'
 | 
			
		||||
    assert response["Content-Type"].split(";")[0] == "text/html"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_qfactor_json(client):
 | 
			
		||||
    response = client.get(url_string(query='{test}'), HTTP_ACCEPT='text/html;q=0.8, application/json;q=0.9')
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(query="{test}"),
 | 
			
		||||
        HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
 | 
			
		||||
    )
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response['Content-Type'].split(';')[0] == 'application/json'
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello World"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response["Content-Type"].split(";")[0] == "application/json"
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello World"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_get_with_query_param(client):
 | 
			
		||||
    response = client.get(url_string(query='{test}'))
 | 
			
		||||
    response = client.get(url_string(query="{test}"))
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello World"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello World"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_get_with_variable_values(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ))
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
            variables=json.dumps({"who": "Dolly"}),
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_get_with_operation_name(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='''
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="""
 | 
			
		||||
        query helloYou { test(who: "You"), ...shared }
 | 
			
		||||
        query helloWorld { test(who: "World"), ...shared }
 | 
			
		||||
        query helloDolly { test(who: "Dolly"), ...shared }
 | 
			
		||||
        fragment shared on QueryRoot {
 | 
			
		||||
          shared: test(who: "Everyone")
 | 
			
		||||
        }
 | 
			
		||||
        ''',
 | 
			
		||||
        operationName='helloWorld'
 | 
			
		||||
    ))
 | 
			
		||||
        """,
 | 
			
		||||
            operationName="helloWorld",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {
 | 
			
		||||
            'test': 'Hello World',
 | 
			
		||||
            'shared': 'Hello Everyone'
 | 
			
		||||
        }
 | 
			
		||||
        "data": {"test": "Hello World", "shared": "Hello Everyone"}
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_reports_validation_errors(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='{ test, unknownOne, unknownTwo }'
 | 
			
		||||
    ))
 | 
			
		||||
    response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [
 | 
			
		||||
        "errors": [
 | 
			
		||||
            {
 | 
			
		||||
                'message': 'Cannot query field "unknownOne" on type "QueryRoot".',
 | 
			
		||||
                'locations': [{'line': 1, 'column': 9}]
 | 
			
		||||
                "message": 'Cannot query field "unknownOne" on type "QueryRoot".',
 | 
			
		||||
                "locations": [{"line": 1, "column": 9}],
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                'message': 'Cannot query field "unknownTwo" on type "QueryRoot".',
 | 
			
		||||
                'locations': [{'line': 1, 'column': 21}]
 | 
			
		||||
            }
 | 
			
		||||
                "message": 'Cannot query field "unknownTwo" on type "QueryRoot".',
 | 
			
		||||
                "locations": [{"line": 1, "column": 21}],
 | 
			
		||||
            },
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_errors_when_missing_operation_name(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='''
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="""
 | 
			
		||||
        query TestQuery { test }
 | 
			
		||||
        mutation TestMutation { writeTest { test } }
 | 
			
		||||
        '''
 | 
			
		||||
    ))
 | 
			
		||||
        """
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [
 | 
			
		||||
        "errors": [
 | 
			
		||||
            {
 | 
			
		||||
                'message': 'Must provide operation name if query contains multiple operations.'
 | 
			
		||||
                "message": "Must provide operation name if query contains multiple operations."
 | 
			
		||||
            }
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_errors_when_sending_a_mutation_via_get(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='''
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="""
 | 
			
		||||
        mutation TestMutation { writeTest { test } }
 | 
			
		||||
        '''
 | 
			
		||||
    ))
 | 
			
		||||
        """
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    assert response.status_code == 405
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [
 | 
			
		||||
            {
 | 
			
		||||
                'message': 'Can only perform a mutation operation from a POST request.'
 | 
			
		||||
            }
 | 
			
		||||
        "errors": [
 | 
			
		||||
            {"message": "Can only perform a mutation operation from a POST request."}
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_errors_when_selecting_a_mutation_within_a_get(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='''
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="""
 | 
			
		||||
        query TestQuery { test }
 | 
			
		||||
        mutation TestMutation { writeTest { test } }
 | 
			
		||||
        ''',
 | 
			
		||||
        operationName='TestMutation'
 | 
			
		||||
    ))
 | 
			
		||||
        """,
 | 
			
		||||
            operationName="TestMutation",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 405
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [
 | 
			
		||||
            {
 | 
			
		||||
                'message': 'Can only perform a mutation operation from a POST request.'
 | 
			
		||||
            }
 | 
			
		||||
        "errors": [
 | 
			
		||||
            {"message": "Can only perform a mutation operation from a POST request."}
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_mutation_to_exist_within_a_get(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='''
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="""
 | 
			
		||||
        query TestQuery { test }
 | 
			
		||||
        mutation TestMutation { writeTest { test } }
 | 
			
		||||
        ''',
 | 
			
		||||
        operationName='TestQuery'
 | 
			
		||||
    ))
 | 
			
		||||
        """,
 | 
			
		||||
            operationName="TestQuery",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello World"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello World"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_post_with_json_encoding(client):
 | 
			
		||||
    response = client.post(url_string(), j(query='{test}'), 'application/json')
 | 
			
		||||
    response = client.post(url_string(), j(query="{test}"), "application/json")
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello World"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello World"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_allows_post_with_json_encoding(client):
 | 
			
		||||
    response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        batch_url_string(), jl(id=1, query="{test}"), "application/json"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == [{
 | 
			
		||||
        'id': 1,
 | 
			
		||||
        'data': {'test': "Hello World"},
 | 
			
		||||
        'status': 200,
 | 
			
		||||
    }]
 | 
			
		||||
    assert response_json(response) == [
 | 
			
		||||
        {"id": 1, "data": {"test": "Hello World"}, "status": 200}
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_fails_if_is_empty(client):
 | 
			
		||||
    response = client.post(batch_url_string(), '[]', 'application/json')
 | 
			
		||||
    response = client.post(batch_url_string(), "[]", "application/json")
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'Received an empty list in the batch request.'}]
 | 
			
		||||
        "errors": [{"message": "Received an empty list in the batch request."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_sending_a_mutation_via_post(client):
 | 
			
		||||
    response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(),
 | 
			
		||||
        j(query="mutation TestMutation { writeTest { test } }"),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'writeTest': {'test': 'Hello World'}}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_post_with_url_encoding(client):
 | 
			
		||||
    response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(),
 | 
			
		||||
        urlencode(dict(query="{test}")),
 | 
			
		||||
        "application/x-www-form-urlencoded",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello World"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello World"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_supports_post_json_query_with_string_variables(client):
 | 
			
		||||
    response = client.post(url_string(), j(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(),
 | 
			
		||||
        j(
 | 
			
		||||
            query="query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
            variables=json.dumps({"who": "Dolly"}),
 | 
			
		||||
        ),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_supports_post_json_query_with_string_variables(client):
 | 
			
		||||
    response = client.post(batch_url_string(), jl(
 | 
			
		||||
        id=1,
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        batch_url_string(),
 | 
			
		||||
        jl(
 | 
			
		||||
            id=1,
 | 
			
		||||
            query="query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
            variables=json.dumps({"who": "Dolly"}),
 | 
			
		||||
        ),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == [{
 | 
			
		||||
        'id': 1,
 | 
			
		||||
        'data': {'test': "Hello Dolly"},
 | 
			
		||||
        'status': 200,
 | 
			
		||||
    }]
 | 
			
		||||
    assert response_json(response) == [
 | 
			
		||||
        {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_supports_post_json_query_with_json_variables(client):
 | 
			
		||||
    response = client.post(url_string(), j(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables={'who': "Dolly"}
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(),
 | 
			
		||||
        j(
 | 
			
		||||
            query="query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
            variables={"who": "Dolly"},
 | 
			
		||||
        ),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_supports_post_json_query_with_json_variables(client):
 | 
			
		||||
    response = client.post(batch_url_string(), jl(
 | 
			
		||||
        id=1,
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables={'who': "Dolly"}
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        batch_url_string(),
 | 
			
		||||
        jl(
 | 
			
		||||
            id=1,
 | 
			
		||||
            query="query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
            variables={"who": "Dolly"},
 | 
			
		||||
        ),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == [{
 | 
			
		||||
        'id': 1,
 | 
			
		||||
        'data': {'test': "Hello Dolly"},
 | 
			
		||||
        'status': 200,
 | 
			
		||||
    }]
 | 
			
		||||
    assert response_json(response) == [
 | 
			
		||||
        {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_supports_post_url_encoded_query_with_string_variables(client):
 | 
			
		||||
    response = client.post(url_string(), urlencode(dict(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    )), 'application/x-www-form-urlencoded')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(),
 | 
			
		||||
        urlencode(
 | 
			
		||||
            dict(
 | 
			
		||||
                query="query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
                variables=json.dumps({"who": "Dolly"}),
 | 
			
		||||
            )
 | 
			
		||||
        ),
 | 
			
		||||
        "application/x-www-form-urlencoded",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_supports_post_json_quey_with_get_variable_values(client):
 | 
			
		||||
    response = client.post(url_string(
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ), j(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(variables=json.dumps({"who": "Dolly"})),
 | 
			
		||||
        j(query="query helloWho($who: String){ test(who: $who) }"),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_post_url_encoded_query_with_get_variable_values(client):
 | 
			
		||||
    response = client.post(url_string(
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ), urlencode(dict(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
    )), 'application/x-www-form-urlencoded')
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(variables=json.dumps({"who": "Dolly"})),
 | 
			
		||||
        urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
 | 
			
		||||
        "application/x-www-form-urlencoded",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_supports_post_raw_text_query_with_get_variable_values(client):
 | 
			
		||||
    response = client.post(url_string(
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ),
 | 
			
		||||
        'query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        'application/graphql'
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(variables=json.dumps({"who": "Dolly"})),
 | 
			
		||||
        "query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
        "application/graphql",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {"data": {"test": "Hello Dolly"}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_post_with_operation_name(client):
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(),
 | 
			
		||||
        j(
 | 
			
		||||
            query="""
 | 
			
		||||
        query helloYou { test(who: "You"), ...shared }
 | 
			
		||||
        query helloWorld { test(who: "World"), ...shared }
 | 
			
		||||
        query helloDolly { test(who: "Dolly"), ...shared }
 | 
			
		||||
        fragment shared on QueryRoot {
 | 
			
		||||
          shared: test(who: "Everyone")
 | 
			
		||||
        }
 | 
			
		||||
        """,
 | 
			
		||||
            operationName="helloWorld",
 | 
			
		||||
        ),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {'test': "Hello Dolly"}
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_post_with_operation_name(client):
 | 
			
		||||
    response = client.post(url_string(), j(
 | 
			
		||||
        query='''
 | 
			
		||||
        query helloYou { test(who: "You"), ...shared }
 | 
			
		||||
        query helloWorld { test(who: "World"), ...shared }
 | 
			
		||||
        query helloDolly { test(who: "Dolly"), ...shared }
 | 
			
		||||
        fragment shared on QueryRoot {
 | 
			
		||||
          shared: test(who: "Everyone")
 | 
			
		||||
        }
 | 
			
		||||
        ''',
 | 
			
		||||
        operationName='helloWorld'
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {
 | 
			
		||||
            'test': 'Hello World',
 | 
			
		||||
            'shared': 'Hello Everyone'
 | 
			
		||||
        }
 | 
			
		||||
        "data": {"test": "Hello World", "shared": "Hello Everyone"}
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_allows_post_with_operation_name(client):
 | 
			
		||||
    response = client.post(batch_url_string(), jl(
 | 
			
		||||
        id=1,
 | 
			
		||||
        query='''
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        batch_url_string(),
 | 
			
		||||
        jl(
 | 
			
		||||
            id=1,
 | 
			
		||||
            query="""
 | 
			
		||||
        query helloYou { test(who: "You"), ...shared }
 | 
			
		||||
        query helloWorld { test(who: "World"), ...shared }
 | 
			
		||||
        query helloDolly { test(who: "Dolly"), ...shared }
 | 
			
		||||
        fragment shared on QueryRoot {
 | 
			
		||||
          shared: test(who: "Everyone")
 | 
			
		||||
        }
 | 
			
		||||
        ''',
 | 
			
		||||
        operationName='helloWorld'
 | 
			
		||||
    ), 'application/json')
 | 
			
		||||
        """,
 | 
			
		||||
            operationName="helloWorld",
 | 
			
		||||
        ),
 | 
			
		||||
        "application/json",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == [{
 | 
			
		||||
        'id': 1,
 | 
			
		||||
        'data': {
 | 
			
		||||
            'test': 'Hello World',
 | 
			
		||||
            'shared': 'Hello Everyone'
 | 
			
		||||
        },
 | 
			
		||||
        'status': 200,
 | 
			
		||||
    }]
 | 
			
		||||
    assert response_json(response) == [
 | 
			
		||||
        {
 | 
			
		||||
            "id": 1,
 | 
			
		||||
            "data": {"test": "Hello World", "shared": "Hello Everyone"},
 | 
			
		||||
            "status": 200,
 | 
			
		||||
        }
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_allows_post_with_get_operation_name(client):
 | 
			
		||||
    response = client.post(url_string(
 | 
			
		||||
        operationName='helloWorld'
 | 
			
		||||
    ), '''
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(operationName="helloWorld"),
 | 
			
		||||
        """
 | 
			
		||||
    query helloYou { test(who: "You"), ...shared }
 | 
			
		||||
    query helloWorld { test(who: "World"), ...shared }
 | 
			
		||||
    query helloDolly { test(who: "Dolly"), ...shared }
 | 
			
		||||
    fragment shared on QueryRoot {
 | 
			
		||||
      shared: test(who: "Everyone")
 | 
			
		||||
    }
 | 
			
		||||
    ''',
 | 
			
		||||
        'application/graphql')
 | 
			
		||||
    """,
 | 
			
		||||
        "application/graphql",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {
 | 
			
		||||
            'test': 'Hello World',
 | 
			
		||||
            'shared': 'Hello Everyone'
 | 
			
		||||
        }
 | 
			
		||||
        "data": {"test": "Hello World", "shared": "Hello Everyone"}
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.urls('graphene_django.tests.urls_inherited')
 | 
			
		||||
@pytest.mark.urls("graphene_django.tests.urls_inherited")
 | 
			
		||||
def test_inherited_class_with_attributes_works(client):
 | 
			
		||||
    inherited_url = '/graphql/inherited/'
 | 
			
		||||
    inherited_url = "/graphql/inherited/"
 | 
			
		||||
    # Check schema and pretty attributes work
 | 
			
		||||
    response = client.post(url_string(inherited_url, query='{test}'))
 | 
			
		||||
    response = client.post(url_string(inherited_url, query="{test}"))
 | 
			
		||||
    assert response.content.decode() == (
 | 
			
		||||
        '{\n'
 | 
			
		||||
        '  "data": {\n'
 | 
			
		||||
        '    "test": "Hello World"\n'
 | 
			
		||||
        '  }\n'
 | 
			
		||||
        '}'
 | 
			
		||||
        "{\n" '  "data": {\n' '    "test": "Hello World"\n' "  }\n" "}"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Check graphiql works
 | 
			
		||||
    response = client.get(url_string(inherited_url), HTTP_ACCEPT='text/html')
 | 
			
		||||
    response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.urls('graphene_django.tests.urls_pretty')
 | 
			
		||||
@pytest.mark.urls("graphene_django.tests.urls_pretty")
 | 
			
		||||
def test_supports_pretty_printing(client):
 | 
			
		||||
    response = client.get(url_string(query='{test}'))
 | 
			
		||||
    response = client.get(url_string(query="{test}"))
 | 
			
		||||
 | 
			
		||||
    assert response.content.decode() == (
 | 
			
		||||
        '{\n'
 | 
			
		||||
        '  "data": {\n'
 | 
			
		||||
        '    "test": "Hello World"\n'
 | 
			
		||||
        '  }\n'
 | 
			
		||||
        '}'
 | 
			
		||||
        "{\n" '  "data": {\n' '    "test": "Hello World"\n' "  }\n" "}"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_supports_pretty_printing_by_request(client):
 | 
			
		||||
    response = client.get(url_string(query='{test}', pretty='1'))
 | 
			
		||||
    response = client.get(url_string(query="{test}", pretty="1"))
 | 
			
		||||
 | 
			
		||||
    assert response.content.decode() == (
 | 
			
		||||
        '{\n'
 | 
			
		||||
        '  "data": {\n'
 | 
			
		||||
        '    "test": "Hello World"\n'
 | 
			
		||||
        '  }\n'
 | 
			
		||||
        '}'
 | 
			
		||||
        "{\n" '  "data": {\n' '    "test": "Hello World"\n' "  }\n" "}"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_field_errors_caught_by_graphql(client):
 | 
			
		||||
    response = client.get(url_string(query='{thrower}'))
 | 
			
		||||
    response = client.get(url_string(query="{thrower}"))
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': None,
 | 
			
		||||
        'errors': [{
 | 
			
		||||
            'locations': [{'column': 2, 'line': 1}],
 | 
			
		||||
            'path': ['thrower'],
 | 
			
		||||
            'message': 'Throws!',
 | 
			
		||||
        }]
 | 
			
		||||
        "data": None,
 | 
			
		||||
        "errors": [
 | 
			
		||||
            {
 | 
			
		||||
                "locations": [{"column": 2, "line": 1}],
 | 
			
		||||
                "path": ["thrower"],
 | 
			
		||||
                "message": "Throws!",
 | 
			
		||||
            }
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_syntax_errors_caught_by_graphql(client):
 | 
			
		||||
    response = client.get(url_string(query='syntaxerror'))
 | 
			
		||||
    response = client.get(url_string(query="syntaxerror"))
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'locations': [{'column': 1, 'line': 1}],
 | 
			
		||||
                    'message': 'Syntax Error GraphQL (1:1) '
 | 
			
		||||
                               'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n   ^\n'}]
 | 
			
		||||
        "errors": [
 | 
			
		||||
            {
 | 
			
		||||
                "locations": [{"column": 1, "line": 1}],
 | 
			
		||||
                "message": "Syntax Error GraphQL (1:1) "
 | 
			
		||||
                'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n   ^\n',
 | 
			
		||||
            }
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -471,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
 | 
			
		|||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'Must provide query string.'}]
 | 
			
		||||
        "errors": [{"message": "Must provide query string."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_not_expected_json_bodies(client):
 | 
			
		||||
    response = client.post(url_string(), '[]', 'application/json')
 | 
			
		||||
    response = client.post(url_string(), "[]", "application/json")
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'The received data is not a valid JSON query.'}]
 | 
			
		||||
        "errors": [{"message": "The received data is not a valid JSON query."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_invalid_json_bodies(client):
 | 
			
		||||
    response = client.post(url_string(), '[oh}', 'application/json')
 | 
			
		||||
    response = client.post(url_string(), "[oh}", "application/json")
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'POST body sent invalid JSON.'}]
 | 
			
		||||
        "errors": [{"message": "POST body sent invalid JSON."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -499,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch):
 | 
			
		|||
 | 
			
		||||
    monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
 | 
			
		||||
 | 
			
		||||
    valid_json = json.dumps(dict(foo='bar'))
 | 
			
		||||
    response = client.post(url_string(), valid_json, 'application/json')
 | 
			
		||||
    valid_json = json.dumps(dict(foo="bar"))
 | 
			
		||||
    response = client.post(url_string(), valid_json, "application/json")
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'foo-bar'}]
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_incomplete_json_bodies(client):
 | 
			
		||||
    response = client.post(url_string(), '{"query":', 'application/json')
 | 
			
		||||
    response = client.post(url_string(), '{"query":', "application/json")
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'POST body sent invalid JSON.'}]
 | 
			
		||||
        "errors": [{"message": "POST body sent invalid JSON."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_plain_post_text(client):
 | 
			
		||||
    response = client.post(url_string(
 | 
			
		||||
        variables=json.dumps({'who': "Dolly"})
 | 
			
		||||
    ),
 | 
			
		||||
        'query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        'text/plain'
 | 
			
		||||
    response = client.post(
 | 
			
		||||
        url_string(variables=json.dumps({"who": "Dolly"})),
 | 
			
		||||
        "query helloWho($who: String){ test(who: $who) }",
 | 
			
		||||
        "text/plain",
 | 
			
		||||
    )
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'Must provide query string.'}]
 | 
			
		||||
        "errors": [{"message": "Must provide query string."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_poorly_formed_variables(client):
 | 
			
		||||
    response = client.get(url_string(
 | 
			
		||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
			
		||||
        variables='who:You'
 | 
			
		||||
    ))
 | 
			
		||||
    response = client.get(
 | 
			
		||||
        url_string(
 | 
			
		||||
            query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    assert response.status_code == 400
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'Variables are invalid JSON.'}]
 | 
			
		||||
        "errors": [{"message": "Variables are invalid JSON."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_handles_unsupported_http_methods(client):
 | 
			
		||||
    response = client.put(url_string(query='{test}'))
 | 
			
		||||
    response = client.put(url_string(query="{test}"))
 | 
			
		||||
    assert response.status_code == 405
 | 
			
		||||
    assert response['Allow'] == 'GET, POST'
 | 
			
		||||
    assert response["Allow"] == "GET, POST"
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'errors': [{'message': 'GraphQL only supports GET and POST requests.'}]
 | 
			
		||||
        "errors": [{"message": "GraphQL only supports GET and POST requests."}]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_passes_request_into_context_request(client):
 | 
			
		||||
    response = client.get(url_string(query='{request}', q='testing'))
 | 
			
		||||
    response = client.get(url_string(query="{request}", q="testing"))
 | 
			
		||||
 | 
			
		||||
    assert response.status_code == 200
 | 
			
		||||
    assert response_json(response) == {
 | 
			
		||||
        'data': {
 | 
			
		||||
            'request': 'testing'
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    assert response_json(response) == {"data": {"request": "testing"}}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,6 @@ from django.conf.urls import url
 | 
			
		|||
from ..views import GraphQLView
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
    url(r'^graphql/batch', GraphQLView.as_view(batch=True)),
 | 
			
		||||
    url(r'^graphql', GraphQLView.as_view(graphiql=True)),
 | 
			
		||||
    url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
 | 
			
		||||
    url(r"^graphql", GraphQLView.as_view(graphiql=True)),
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,12 +3,11 @@ from django.conf.urls import url
 | 
			
		|||
from ..views import GraphQLView
 | 
			
		||||
from .schema_view import schema
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CustomGraphQLView(GraphQLView):
 | 
			
		||||
    schema = schema
 | 
			
		||||
    graphiql = True
 | 
			
		||||
    pretty = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
    url(r'^graphql/inherited/$', CustomGraphQLView.as_view()),
 | 
			
		||||
]
 | 
			
		||||
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,4 @@ from django.conf.urls import url
 | 
			
		|||
from ..views import GraphQLView
 | 
			
		||||
from .schema_view import schema
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
    url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)),
 | 
			
		||||
]
 | 
			
		||||
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,8 +8,7 @@ from graphene.types.utils import yank_fields_from_attrs
 | 
			
		|||
 | 
			
		||||
from .converter import convert_django_field_with_choices
 | 
			
		||||
from .registry import Registry, get_global_registry
 | 
			
		||||
from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields,
 | 
			
		||||
                    is_valid_django_model)
 | 
			
		||||
from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def construct_fields(model, registry, only_fields, exclude_fields):
 | 
			
		||||
| 
						 | 
				
			
			@ -21,7 +20,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
 | 
			
		|||
        # is_already_created = name in options.fields
 | 
			
		||||
        is_excluded = name in exclude_fields  # or is_already_created
 | 
			
		||||
        # https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
 | 
			
		||||
        is_no_backref = str(name).endswith('+')
 | 
			
		||||
        is_no_backref = str(name).endswith("+")
 | 
			
		||||
        if is_not_in_only or is_excluded or is_no_backref:
 | 
			
		||||
            # We skip this field if we specify only_fields and is not
 | 
			
		||||
            # in there. Or when we exclude this field in exclude_fields.
 | 
			
		||||
| 
						 | 
				
			
			@ -43,9 +42,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions):
 | 
			
		|||
 | 
			
		||||
class DjangoObjectType(ObjectType):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
 | 
			
		||||
                                    only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
 | 
			
		||||
                                    connection_class=None, use_connection=None, interfaces=(), _meta=None, **options):
 | 
			
		||||
    def __init_subclass_with_meta__(
 | 
			
		||||
        cls,
 | 
			
		||||
        model=None,
 | 
			
		||||
        registry=None,
 | 
			
		||||
        skip_registry=False,
 | 
			
		||||
        only_fields=(),
 | 
			
		||||
        exclude_fields=(),
 | 
			
		||||
        filter_fields=None,
 | 
			
		||||
        connection=None,
 | 
			
		||||
        connection_class=None,
 | 
			
		||||
        use_connection=None,
 | 
			
		||||
        interfaces=(),
 | 
			
		||||
        _meta=None,
 | 
			
		||||
        **options
 | 
			
		||||
    ):
 | 
			
		||||
        assert is_valid_django_model(model), (
 | 
			
		||||
            'You need to pass a valid Django Model in {}.Meta, received "{}".'
 | 
			
		||||
        ).format(cls.__name__, model)
 | 
			
		||||
| 
						 | 
				
			
			@ -54,7 +65,7 @@ class DjangoObjectType(ObjectType):
 | 
			
		|||
            registry = get_global_registry()
 | 
			
		||||
 | 
			
		||||
        assert isinstance(registry, Registry), (
 | 
			
		||||
            'The attribute registry in {} needs to be an instance of '
 | 
			
		||||
            "The attribute registry in {} needs to be an instance of "
 | 
			
		||||
            'Registry, received "{}".'
 | 
			
		||||
        ).format(cls.__name__, registry)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -62,12 +73,13 @@ class DjangoObjectType(ObjectType):
 | 
			
		|||
            raise Exception("Can only set filter_fields if Django-Filter is installed")
 | 
			
		||||
 | 
			
		||||
        django_fields = yank_fields_from_attrs(
 | 
			
		||||
            construct_fields(model, registry, only_fields, exclude_fields),
 | 
			
		||||
            _as=Field,
 | 
			
		||||
            construct_fields(model, registry, only_fields, exclude_fields), _as=Field
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if use_connection is None and interfaces:
 | 
			
		||||
            use_connection = any((issubclass(interface, Node) for interface in interfaces))
 | 
			
		||||
            use_connection = any(
 | 
			
		||||
                (issubclass(interface, Node) for interface in interfaces)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if use_connection and not connection:
 | 
			
		||||
            # We create the connection automatically
 | 
			
		||||
| 
						 | 
				
			
			@ -75,7 +87,8 @@ class DjangoObjectType(ObjectType):
 | 
			
		|||
                connection_class = Connection
 | 
			
		||||
 | 
			
		||||
            connection = connection_class.create_type(
 | 
			
		||||
                '{}Connection'.format(cls.__name__), node=cls)
 | 
			
		||||
                "{}Connection".format(cls.__name__), node=cls
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if connection is not None:
 | 
			
		||||
            assert issubclass(connection, Connection), (
 | 
			
		||||
| 
						 | 
				
			
			@ -91,7 +104,9 @@ class DjangoObjectType(ObjectType):
 | 
			
		|||
        _meta.fields = django_fields
 | 
			
		||||
        _meta.connection = connection
 | 
			
		||||
 | 
			
		||||
        super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options)
 | 
			
		||||
        super(DjangoObjectType, cls).__init_subclass_with_meta__(
 | 
			
		||||
            _meta=_meta, interfaces=interfaces, **options
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not skip_registry:
 | 
			
		||||
            registry.register(cls)
 | 
			
		||||
| 
						 | 
				
			
			@ -107,9 +122,7 @@ class DjangoObjectType(ObjectType):
 | 
			
		|||
        if isinstance(root, cls):
 | 
			
		||||
            return True
 | 
			
		||||
        if not is_valid_django_model(type(root)):
 | 
			
		||||
            raise Exception((
 | 
			
		||||
                'Received incompatible instance "{}".'
 | 
			
		||||
            ).format(root))
 | 
			
		||||
            raise Exception(('Received incompatible instance "{}".').format(root))
 | 
			
		||||
 | 
			
		||||
        model = root._meta.model._meta.concrete_model
 | 
			
		||||
        return model == cls._meta.model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,6 +13,7 @@ class LazyList(object):
 | 
			
		|||
 | 
			
		||||
try:
 | 
			
		||||
    import django_filters  # noqa
 | 
			
		||||
 | 
			
		||||
    DJANGO_FILTER_INSTALLED = True
 | 
			
		||||
except ImportError:
 | 
			
		||||
    DJANGO_FILTER_INSTALLED = False
 | 
			
		||||
| 
						 | 
				
			
			@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names):
 | 
			
		|||
            continue
 | 
			
		||||
 | 
			
		||||
        # Django =>1.9 uses 'rel', django <1.9 uses 'related'
 | 
			
		||||
        related = getattr(attr, 'rel', None) or \
 | 
			
		||||
            getattr(attr, 'related', None)
 | 
			
		||||
        related = getattr(attr, "rel", None) or getattr(attr, "related", None)
 | 
			
		||||
        if isinstance(related, models.ManyToOneRel):
 | 
			
		||||
            yield (name, related)
 | 
			
		||||
        elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
 | 
			
		||||
| 
						 | 
				
			
			@ -42,9 +42,9 @@ def maybe_queryset(value):
 | 
			
		|||
def get_model_fields(model):
 | 
			
		||||
    local_fields = [
 | 
			
		||||
        (field.name, field)
 | 
			
		||||
        for field
 | 
			
		||||
        in sorted(list(model._meta.fields) +
 | 
			
		||||
                  list(model._meta.local_many_to_many))
 | 
			
		||||
        for field in sorted(
 | 
			
		||||
            list(model._meta.fields) + list(model._meta.local_many_to_many)
 | 
			
		||||
        )
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    # Make sure we don't duplicate local fields with "reverse" version
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,7 +20,6 @@ from .settings import graphene_settings
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class HttpError(Exception):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, response, message=None, *args, **kwargs):
 | 
			
		||||
        self.response = response
 | 
			
		||||
        self.message = message = message or response.content.decode()
 | 
			
		||||
| 
						 | 
				
			
			@ -29,18 +28,18 @@ class HttpError(Exception):
 | 
			
		|||
 | 
			
		||||
def get_accepted_content_types(request):
 | 
			
		||||
    def qualify(x):
 | 
			
		||||
        parts = x.split(';', 1)
 | 
			
		||||
        parts = x.split(";", 1)
 | 
			
		||||
        if len(parts) == 2:
 | 
			
		||||
            match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)',
 | 
			
		||||
                             parts[1])
 | 
			
		||||
            match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
 | 
			
		||||
            if match:
 | 
			
		||||
                return parts[0].strip(), float(match.group(2))
 | 
			
		||||
        return parts[0].strip(), 1
 | 
			
		||||
 | 
			
		||||
    raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',')
 | 
			
		||||
    raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
 | 
			
		||||
    qualified_content_types = map(qualify, raw_content_types)
 | 
			
		||||
    return list(x[0] for x in sorted(qualified_content_types,
 | 
			
		||||
                                     key=lambda x: x[1], reverse=True))
 | 
			
		||||
    return list(
 | 
			
		||||
        x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def instantiate_middleware(middlewares):
 | 
			
		||||
| 
						 | 
				
			
			@ -52,8 +51,8 @@ def instantiate_middleware(middlewares):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class GraphQLView(View):
 | 
			
		||||
    graphiql_version = '0.11.10'
 | 
			
		||||
    graphiql_template = 'graphene/graphiql.html'
 | 
			
		||||
    graphiql_version = "0.11.10"
 | 
			
		||||
    graphiql_template = "graphene/graphiql.html"
 | 
			
		||||
 | 
			
		||||
    schema = None
 | 
			
		||||
    graphiql = False
 | 
			
		||||
| 
						 | 
				
			
			@ -64,8 +63,17 @@ class GraphQLView(View):
 | 
			
		|||
    pretty = False
 | 
			
		||||
    batch = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
 | 
			
		||||
                 batch=False, backend=None):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        schema=None,
 | 
			
		||||
        executor=None,
 | 
			
		||||
        middleware=None,
 | 
			
		||||
        root_value=None,
 | 
			
		||||
        graphiql=False,
 | 
			
		||||
        pretty=False,
 | 
			
		||||
        batch=False,
 | 
			
		||||
        backend=None,
 | 
			
		||||
    ):
 | 
			
		||||
        if not schema:
 | 
			
		||||
            schema = graphene_settings.SCHEMA
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -86,9 +94,9 @@ class GraphQLView(View):
 | 
			
		|||
        self.backend = backend
 | 
			
		||||
 | 
			
		||||
        assert isinstance(
 | 
			
		||||
            self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
 | 
			
		||||
        assert not all((graphiql, batch)
 | 
			
		||||
                       ), 'Use either graphiql or batch processing'
 | 
			
		||||
            self.schema, GraphQLSchema
 | 
			
		||||
        ), "A Schema is required to be provided to GraphQLView."
 | 
			
		||||
        assert not all((graphiql, batch)), "Use either graphiql or batch processing"
 | 
			
		||||
 | 
			
		||||
    # noinspection PyUnusedLocal
 | 
			
		||||
    def get_root_value(self, request):
 | 
			
		||||
| 
						 | 
				
			
			@ -106,59 +114,59 @@ class GraphQLView(View):
 | 
			
		|||
    @method_decorator(ensure_csrf_cookie)
 | 
			
		||||
    def dispatch(self, request, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            if request.method.lower() not in ('get', 'post'):
 | 
			
		||||
                raise HttpError(HttpResponseNotAllowed(
 | 
			
		||||
                    ['GET', 'POST'], 'GraphQL only supports GET and POST requests.'))
 | 
			
		||||
            if request.method.lower() not in ("get", "post"):
 | 
			
		||||
                raise HttpError(
 | 
			
		||||
                    HttpResponseNotAllowed(
 | 
			
		||||
                        ["GET", "POST"], "GraphQL only supports GET and POST requests."
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            data = self.parse_body(request)
 | 
			
		||||
            show_graphiql = self.graphiql and self.can_display_graphiql(
 | 
			
		||||
                request, data)
 | 
			
		||||
            show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
 | 
			
		||||
 | 
			
		||||
            if self.batch:
 | 
			
		||||
                responses = [self.get_response(request, entry) for entry in data]
 | 
			
		||||
                result = '[{}]'.format(','.join([response[0] for response in responses]))
 | 
			
		||||
                status_code = responses and max(responses, key=lambda response: response[1])[1] or 200
 | 
			
		||||
                result = "[{}]".format(
 | 
			
		||||
                    ",".join([response[0] for response in responses])
 | 
			
		||||
                )
 | 
			
		||||
                status_code = (
 | 
			
		||||
                    responses
 | 
			
		||||
                    and max(responses, key=lambda response: response[1])[1]
 | 
			
		||||
                    or 200
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                result, status_code = self.get_response(
 | 
			
		||||
                    request, data, show_graphiql)
 | 
			
		||||
                result, status_code = self.get_response(request, data, show_graphiql)
 | 
			
		||||
 | 
			
		||||
            if show_graphiql:
 | 
			
		||||
                query, variables, operation_name, id = self.get_graphql_params(
 | 
			
		||||
                    request, data)
 | 
			
		||||
                    request, data
 | 
			
		||||
                )
 | 
			
		||||
                return self.render_graphiql(
 | 
			
		||||
                    request,
 | 
			
		||||
                    graphiql_version=self.graphiql_version,
 | 
			
		||||
                    query=query or '',
 | 
			
		||||
                    variables=json.dumps(variables) or '',
 | 
			
		||||
                    operation_name=operation_name or '',
 | 
			
		||||
                    result=result or ''
 | 
			
		||||
                    query=query or "",
 | 
			
		||||
                    variables=json.dumps(variables) or "",
 | 
			
		||||
                    operation_name=operation_name or "",
 | 
			
		||||
                    result=result or "",
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            return HttpResponse(
 | 
			
		||||
                status=status_code,
 | 
			
		||||
                content=result,
 | 
			
		||||
                content_type='application/json'
 | 
			
		||||
                status=status_code, content=result, content_type="application/json"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except HttpError as e:
 | 
			
		||||
            response = e.response
 | 
			
		||||
            response['Content-Type'] = 'application/json'
 | 
			
		||||
            response.content = self.json_encode(request, {
 | 
			
		||||
                'errors': [self.format_error(e)]
 | 
			
		||||
            })
 | 
			
		||||
            response["Content-Type"] = "application/json"
 | 
			
		||||
            response.content = self.json_encode(
 | 
			
		||||
                request, {"errors": [self.format_error(e)]}
 | 
			
		||||
            )
 | 
			
		||||
            return response
 | 
			
		||||
 | 
			
		||||
    def get_response(self, request, data, show_graphiql=False):
 | 
			
		||||
        query, variables, operation_name, id = self.get_graphql_params(
 | 
			
		||||
            request, data)
 | 
			
		||||
        query, variables, operation_name, id = self.get_graphql_params(request, data)
 | 
			
		||||
 | 
			
		||||
        execution_result = self.execute_graphql_request(
 | 
			
		||||
            request,
 | 
			
		||||
            data,
 | 
			
		||||
            query,
 | 
			
		||||
            variables,
 | 
			
		||||
            operation_name,
 | 
			
		||||
            show_graphiql
 | 
			
		||||
            request, data, query, variables, operation_name, show_graphiql
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        status_code = 200
 | 
			
		||||
| 
						 | 
				
			
			@ -166,17 +174,18 @@ class GraphQLView(View):
 | 
			
		|||
            response = {}
 | 
			
		||||
 | 
			
		||||
            if execution_result.errors:
 | 
			
		||||
                response['errors'] = [self.format_error(
 | 
			
		||||
                    e) for e in execution_result.errors]
 | 
			
		||||
                response["errors"] = [
 | 
			
		||||
                    self.format_error(e) for e in execution_result.errors
 | 
			
		||||
                ]
 | 
			
		||||
 | 
			
		||||
            if execution_result.invalid:
 | 
			
		||||
                status_code = 400
 | 
			
		||||
            else:
 | 
			
		||||
                response['data'] = execution_result.data
 | 
			
		||||
                response["data"] = execution_result.data
 | 
			
		||||
 | 
			
		||||
            if self.batch:
 | 
			
		||||
                response['id'] = id
 | 
			
		||||
                response['status'] = status_code
 | 
			
		||||
                response["id"] = id
 | 
			
		||||
                response["status"] = status_code
 | 
			
		||||
 | 
			
		||||
            result = self.json_encode(request, response, pretty=show_graphiql)
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -188,22 +197,21 @@ class GraphQLView(View):
 | 
			
		|||
        return render(request, self.graphiql_template, data)
 | 
			
		||||
 | 
			
		||||
    def json_encode(self, request, d, pretty=False):
 | 
			
		||||
        if not (self.pretty or pretty) and not request.GET.get('pretty'):
 | 
			
		||||
            return json.dumps(d, separators=(',', ':'))
 | 
			
		||||
        if not (self.pretty or pretty) and not request.GET.get("pretty"):
 | 
			
		||||
            return json.dumps(d, separators=(",", ":"))
 | 
			
		||||
 | 
			
		||||
        return json.dumps(d, sort_keys=True,
 | 
			
		||||
                          indent=2, separators=(',', ': '))
 | 
			
		||||
        return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
 | 
			
		||||
 | 
			
		||||
    def parse_body(self, request):
 | 
			
		||||
        content_type = self.get_content_type(request)
 | 
			
		||||
 | 
			
		||||
        if content_type == 'application/graphql':
 | 
			
		||||
            return {'query': request.body.decode()}
 | 
			
		||||
        if content_type == "application/graphql":
 | 
			
		||||
            return {"query": request.body.decode()}
 | 
			
		||||
 | 
			
		||||
        elif content_type == 'application/json':
 | 
			
		||||
        elif content_type == "application/json":
 | 
			
		||||
            # noinspection PyBroadException
 | 
			
		||||
            try:
 | 
			
		||||
                body = request.body.decode('utf-8')
 | 
			
		||||
                body = request.body.decode("utf-8")
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                raise HttpError(HttpResponseBadRequest(str(e)))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -211,33 +219,36 @@ class GraphQLView(View):
 | 
			
		|||
                request_json = json.loads(body)
 | 
			
		||||
                if self.batch:
 | 
			
		||||
                    assert isinstance(request_json, list), (
 | 
			
		||||
                        'Batch requests should receive a list, but received {}.'
 | 
			
		||||
                        "Batch requests should receive a list, but received {}."
 | 
			
		||||
                    ).format(repr(request_json))
 | 
			
		||||
                    assert len(request_json) > 0, (
 | 
			
		||||
                        'Received an empty list in the batch request.'
 | 
			
		||||
                    )
 | 
			
		||||
                    assert (
 | 
			
		||||
                        len(request_json) > 0
 | 
			
		||||
                    ), "Received an empty list in the batch request."
 | 
			
		||||
                else:
 | 
			
		||||
                    assert isinstance(request_json, dict), (
 | 
			
		||||
                        'The received data is not a valid JSON query.'
 | 
			
		||||
                    )
 | 
			
		||||
                    assert isinstance(
 | 
			
		||||
                        request_json, dict
 | 
			
		||||
                    ), "The received data is not a valid JSON query."
 | 
			
		||||
                return request_json
 | 
			
		||||
            except AssertionError as e:
 | 
			
		||||
                raise HttpError(HttpResponseBadRequest(str(e)))
 | 
			
		||||
            except (TypeError, ValueError):
 | 
			
		||||
                raise HttpError(HttpResponseBadRequest(
 | 
			
		||||
                    'POST body sent invalid JSON.'))
 | 
			
		||||
                raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
 | 
			
		||||
 | 
			
		||||
        elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']:
 | 
			
		||||
        elif content_type in [
 | 
			
		||||
            "application/x-www-form-urlencoded",
 | 
			
		||||
            "multipart/form-data",
 | 
			
		||||
        ]:
 | 
			
		||||
            return request.POST
 | 
			
		||||
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
    def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False):
 | 
			
		||||
    def execute_graphql_request(
 | 
			
		||||
        self, request, data, query, variables, operation_name, show_graphiql=False
 | 
			
		||||
    ):
 | 
			
		||||
        if not query:
 | 
			
		||||
            if show_graphiql:
 | 
			
		||||
                return None
 | 
			
		||||
            raise HttpError(HttpResponseBadRequest(
 | 
			
		||||
                'Must provide query string.'))
 | 
			
		||||
            raise HttpError(HttpResponseBadRequest("Must provide query string."))
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            backend = self.get_backend(request)
 | 
			
		||||
| 
						 | 
				
			
			@ -245,23 +256,27 @@ class GraphQLView(View):
 | 
			
		|||
        except Exception as e:
 | 
			
		||||
            return ExecutionResult(errors=[e], invalid=True)
 | 
			
		||||
 | 
			
		||||
        if request.method.lower() == 'get':
 | 
			
		||||
        if request.method.lower() == "get":
 | 
			
		||||
            operation_type = document.get_operation_type(operation_name)
 | 
			
		||||
            if operation_type and operation_type != 'query':
 | 
			
		||||
            if operation_type and operation_type != "query":
 | 
			
		||||
                if show_graphiql:
 | 
			
		||||
                    return None
 | 
			
		||||
 | 
			
		||||
                raise HttpError(HttpResponseNotAllowed(
 | 
			
		||||
                    ['POST'], 'Can only perform a {} operation from a POST request.'.format(
 | 
			
		||||
                        operation_type)
 | 
			
		||||
                ))
 | 
			
		||||
                raise HttpError(
 | 
			
		||||
                    HttpResponseNotAllowed(
 | 
			
		||||
                        ["POST"],
 | 
			
		||||
                        "Can only perform a {} operation from a POST request.".format(
 | 
			
		||||
                            operation_type
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            extra_options = {}
 | 
			
		||||
            if self.executor:
 | 
			
		||||
                # We only include it optionally since
 | 
			
		||||
                # executor is not a valid argument in all backends
 | 
			
		||||
                extra_options['executor'] = self.executor
 | 
			
		||||
                extra_options["executor"] = self.executor
 | 
			
		||||
 | 
			
		||||
            return document.execute(
 | 
			
		||||
                root=self.get_root_value(request),
 | 
			
		||||
| 
						 | 
				
			
			@ -276,7 +291,7 @@ class GraphQLView(View):
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def can_display_graphiql(cls, request, data):
 | 
			
		||||
        raw = 'raw' in request.GET or 'raw' in data
 | 
			
		||||
        raw = "raw" in request.GET or "raw" in data
 | 
			
		||||
        return not raw and cls.request_wants_html(request)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -285,26 +300,32 @@ class GraphQLView(View):
 | 
			
		|||
        accepted_length = len(accepted)
 | 
			
		||||
        # the list will be ordered in preferred first - so we have to make
 | 
			
		||||
        # sure the most preferred gets the highest number
 | 
			
		||||
        html_priority = accepted_length - accepted.index('text/html') if 'text/html' in accepted else 0
 | 
			
		||||
        json_priority = accepted_length - accepted.index('application/json') if 'application/json' in accepted else 0
 | 
			
		||||
        html_priority = (
 | 
			
		||||
            accepted_length - accepted.index("text/html")
 | 
			
		||||
            if "text/html" in accepted
 | 
			
		||||
            else 0
 | 
			
		||||
        )
 | 
			
		||||
        json_priority = (
 | 
			
		||||
            accepted_length - accepted.index("application/json")
 | 
			
		||||
            if "application/json" in accepted
 | 
			
		||||
            else 0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return html_priority > json_priority
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_graphql_params(request, data):
 | 
			
		||||
        query = request.GET.get('query') or data.get('query')
 | 
			
		||||
        variables = request.GET.get('variables') or data.get('variables')
 | 
			
		||||
        id = request.GET.get('id') or data.get('id')
 | 
			
		||||
        query = request.GET.get("query") or data.get("query")
 | 
			
		||||
        variables = request.GET.get("variables") or data.get("variables")
 | 
			
		||||
        id = request.GET.get("id") or data.get("id")
 | 
			
		||||
 | 
			
		||||
        if variables and isinstance(variables, six.text_type):
 | 
			
		||||
            try:
 | 
			
		||||
                variables = json.loads(variables)
 | 
			
		||||
            except Exception:
 | 
			
		||||
                raise HttpError(HttpResponseBadRequest(
 | 
			
		||||
                    'Variables are invalid JSON.'))
 | 
			
		||||
                raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
 | 
			
		||||
 | 
			
		||||
        operation_name = request.GET.get(
 | 
			
		||||
            'operationName') or data.get('operationName')
 | 
			
		||||
        operation_name = request.GET.get("operationName") or data.get("operationName")
 | 
			
		||||
        if operation_name == "null":
 | 
			
		||||
            operation_name = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -315,11 +336,10 @@ class GraphQLView(View):
 | 
			
		|||
        if isinstance(error, GraphQLError):
 | 
			
		||||
            return format_graphql_error(error)
 | 
			
		||||
 | 
			
		||||
        return {'message': six.text_type(error)}
 | 
			
		||||
        return {"message": six.text_type(error)}
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_content_type(request):
 | 
			
		||||
        meta = request.META
 | 
			
		||||
        content_type = meta.get(
 | 
			
		||||
            'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
 | 
			
		||||
        return content_type.split(';', 1)[0].lower()
 | 
			
		||||
        content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
 | 
			
		||||
        return content_type.split(";", 1)[0].lower()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user