mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 09:57:53 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			233 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from functools import partial
 | 
						|
 | 
						|
import six
 | 
						|
from django.db.models.query import QuerySet
 | 
						|
from graphql_relay.connection.arrayconnection import (
 | 
						|
    connection_from_list_slice,
 | 
						|
    get_offset_with_default,
 | 
						|
)
 | 
						|
from promise import Promise
 | 
						|
 | 
						|
from graphene import NonNull
 | 
						|
from graphene.relay import ConnectionField, PageInfo
 | 
						|
from graphene.types import Field, List
 | 
						|
 | 
						|
from .settings import graphene_settings
 | 
						|
from .utils import maybe_queryset
 | 
						|
 | 
						|
 | 
						|
class DjangoListField(Field):
    """A ``List`` field whose items are a ``DjangoObjectType``.

    The field is always exposed as ``List(NonNull(_type))``. A ``NonNull``
    passed in by the caller is unwrapped first so it is not doubled.
    """

    def __init__(self, _type, *args, **kwargs):
        # Imported here rather than at module level; presumably to avoid a
        # circular import with .types (NOTE(review): confirm).
        from .types import DjangoObjectType

        if isinstance(_type, NonNull):
            _type = _type.of_type

        # Django never yields None items from a queryset, so the item type is
        # always wrapped in NonNull.
        super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs)

        assert issubclass(
            self._underlying_type, DjangoObjectType
        ), "DjangoListField only accepts DjangoObjectType types"

    @property
    def _underlying_type(self):
        """Return the innermost type, stripping every ``of_type`` wrapper."""
        _type = self._type
        while hasattr(_type, "of_type"):
            _type = _type.of_type
        return _type

    @property
    def model(self):
        """The ``_meta.model`` of the underlying item type."""
        return self._underlying_type._meta.model

    def get_default_queryset(self):
        """Return ``get_queryset()`` of the model's ``_default_manager``."""
        return self.model._default_manager.get_queryset()

    @staticmethod
    def list_resolver(
        django_object_type, resolver, default_queryset, root, info, **args
    ):
        """Resolve the list, falling back to ``default_queryset``.

        :param django_object_type: item type; its ``get_queryset(queryset, info)``
            hook is applied to any QuerySet result.
        :param resolver: parent resolver, called as ``resolver(root, info, **args)``.
        :param default_queryset: used when ``resolver`` yields ``None``.
        :returns: the resolved value. QuerySets are first passed through
            ``django_object_type.get_queryset``.
        """
        queryset = maybe_queryset(resolver(root, info, **args))
        if queryset is None:
            queryset = default_queryset
            if isinstance(queryset, QuerySet):
                # ``default_queryset`` is built once in ``get_resolver`` and is
                # shared by every call. A QuerySet caches its rows once it is
                # evaluated, so hand out a fresh clone each time. Otherwise a
                # ``get_queryset`` that returns its input unchanged would keep
                # serving the rows fetched by the first request.
                queryset = queryset.all()

        if isinstance(queryset, QuerySet):
            # Pass queryset to the DjangoObjectType get_queryset method
            queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))

        return queryset

    def get_resolver(self, parent_resolver):
        """Bind item type, parent resolver and default queryset into ``list_resolver``."""
        _type = self.type
        if isinstance(_type, NonNull):
            _type = _type.of_type
        # ``self.type`` is List(NonNull(<item type>)); peel off both wrappers.
        django_object_type = _type.of_type.of_type
        return partial(
            self.list_resolver,
            django_object_type,
            parent_resolver,
            self.get_default_queryset(),
        )
 | 
						|
 | 
						|
 | 
						|
class DjangoConnectionField(ConnectionField):
    """Relay connection field over a ``DjangoObjectType`` node type.

    Extra keyword arguments, popped before ``ConnectionField.__init__``:

    :param on: name of a model attribute (manager) to read from instead of
        ``_default_manager``.
    :param max_limit: cap on the page size; defaults to
        ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    :param enforce_first_or_last: when truthy, a ``first`` or ``last``
        argument is required; defaults to
        ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, *args, **kwargs):
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the node type's ``_meta.connection``, keeping an outer ``NonNull``."""
        from .types import DjangoObjectType

        # ``super(ConnectionField, self)`` bypasses ``ConnectionField.type`` and
        # reads the declared type from the next class in the MRO. The field
        # is declared with the node type; its connection class is looked up
        # below.
        _type = super(ConnectionField, self).type
        non_null = False
        if isinstance(_type, NonNull):
            _type = _type.of_type
            non_null = True

        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )

        connection_type = _type._meta.connection
        if non_null:
            return NonNull(connection_type)
        return connection_type

    @property
    def connection_type(self):
        """The connection class, with any ``NonNull`` wrapper removed."""
        _type = self.type
        if isinstance(_type, NonNull):
            return _type.of_type
        return _type

    @property
    def node_type(self):
        """The ``_meta.node`` of the connection class."""
        return self.connection_type._meta.node

    @property
    def model(self):
        """The ``_meta.model`` of the node type."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return the model attribute named by ``on``, else ``_default_manager``."""
        if self.on:
            return getattr(self.model, self.on)
        return self.model._default_manager

    @classmethod
    def resolve_queryset(cls, connection, queryset, info, args):
        """Pass the resolved iterable through the node type's ``get_queryset``."""
        # queryset is the resolved iterable from ObjectType
        return connection._meta.node.get_queryset(queryset, info)

    @classmethod
    def resolve_connection(cls, connection, args, iterable, max_limit=None):
        """Slice ``iterable`` into an instance of ``connection``.

        :param connection: connection class to build (its ``Edge`` is used too).
        :param args: pagination arguments (``first``/``last``/``before``/
            ``after``). A default ``first`` may be written into it in place.
        :param iterable: a QuerySet (counted with ``.count()``), or anything
            ``maybe_queryset`` accepts, or a sized sequence (``len()``).
        :param max_limit: optional cap on the page size, or ``None``.
        :returns: the connection, with ``iterable`` and ``length`` attributes
            set to the full iterable and its length.
        """
        iterable = maybe_queryset(iterable)

        # Count QuerySets in the database rather than loading every row.
        if isinstance(iterable, QuerySet):
            list_length = iterable.count()
        else:
            list_length = len(iterable)
        list_slice_length = (
            min(max_limit, list_length) if max_limit is not None else list_length
        )

        # If after is higher than list_length, connection_from_list_slice
        # would try to do a negative slicing which makes django throw an
        # AssertionError
        after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)

        if max_limit is not None and args.get("first") is None:
            last = args.get("last")
            if last is not None:
                # The slice below holds at most ``max_limit`` items starting at
                # ``after``. ``last`` on its own would therefore pick items
                # from the *front* of the list. Start the slice at the window's
                # tail instead: ``last`` items before ``before`` (or the end),
                # never earlier than ``after`` nor past the end of the list.
                before = get_offset_with_default(args.get("before"), list_length)
                window_end = min(before, list_length)
                after = min(max(after, window_end - last), list_length)
            else:
                args["first"] = max_limit

        connection = connection_from_list_slice(
            iterable[after:],
            args,
            slice_start=after,
            list_length=list_length,
            list_slice_length=list_slice_length,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = list_length
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        queryset_resolver,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Resolve the field: validate paging args, fetch, filter, then slice.

        The arguments before ``root`` are bound by ``get_resolver``. ``root``,
        ``info`` and ``**args`` come from GraphQL execution.

        :raises AssertionError: when ``first``/``last`` is missing but
            required, or exceeds ``max_limit``.
        :returns: the connection, or a Promise of it when the filtered
            iterable is thenable.
        """
        first = args.get("first")
        last = args.get("last")

        if enforce_first_or_last:
            assert first or last, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            if first:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                # Has an effect only when asserts are stripped (python -O).
                args["first"] = min(first, max_limit)

            if last:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                # Has an effect only when asserts are stripped (python -O).
                args["last"] = min(last, max_limit)

        # eventually leads to DjangoObjectType's get_queryset (accepts queryset)
        # or a resolve_foo (does not accept queryset)
        iterable = resolver(root, info, **args)
        if iterable is None:
            iterable = default_manager
        # thus the iterable gets refiltered by resolve_queryset
        # but iterable might be promise
        iterable = queryset_resolver(connection, iterable, info, args)
        on_resolve = partial(
            cls.resolve_connection, connection, args, max_limit=max_limit
        )

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind this field's configuration into ``connection_resolver``."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.connection_type,
            self.get_manager(),
            self.get_queryset_resolver(),
            self.max_limit,
            self.enforce_first_or_last,
        )

    def get_queryset_resolver(self):
        """Return the hook used to filter resolved iterables (``resolve_queryset``)."""
        return self.resolve_queryset
 |