diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 46dbba98..b4639ef9 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -90,17 +90,39 @@ class ConnectionMeta(ObjectTypeMeta): class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): - pass + + @classmethod + def Field(cls, *args, **kwargs): # noqa: N802 + return ConnectionField(cls, *args, **kwargs) + + @classmethod + def connection_resolver(cls, resolved, args, context, info): + if isinstance(resolved, cls): + return resolved + + assert isinstance(resolved, Iterable), ( + 'Resolved value from the connection field have to be iterable or instance of {}. ' + 'Received "{}"' + ).format(cls, resolved) + connection = connection_from_list( + resolved, + args, + connection_type=cls, + edge_type=cls.Edge, + pageinfo_type=PageInfo + ) + connection.iterable = resolved + return connection -class IterableConnectionField(Field): +class ConnectionField(Field): def __init__(self, type, *args, **kwargs): kwargs.setdefault('before', String()) kwargs.setdefault('after', String()) kwargs.setdefault('first', Int()) kwargs.setdefault('last', Int()) - super(IterableConnectionField, self).__init__( + super(ConnectionField, self).__init__( type, *args, **kwargs @@ -108,7 +130,7 @@ class IterableConnectionField(Field): @property def type(self): - type = super(IterableConnectionField, self).type + type = super(ConnectionField, self).type if is_node(type): connection_type = type.Connection else: @@ -118,37 +140,16 @@ class IterableConnectionField(Field): ).format(str(self), connection_type) return connection_type - @classmethod - def resolve_connection(cls, connection_type, args, resolved): - if isinstance(resolved, connection_type): - return resolved - - assert isinstance(resolved, Iterable), ( - 'Resolved value from the connection field have to be iterable or instance of {}. ' - 'Received "{}"' - ).format(connection_type, resolved) - connection = connection_from_list( - resolved, - args, - connection_type=connection_type, - edge_type=connection_type.Edge, - pageinfo_type=PageInfo - ) - connection.iterable = resolved - return connection - @classmethod def connection_resolver(cls, resolver, connection_type, root, args, context, info): resolved = resolver(root, args, context, info) - on_resolve = partial(cls.resolve_connection, connection_type, args) + on_resolve = partial(connection_type.connection_resolver, args=args, context=context, info=info) if isinstance(resolved, Promise): return resolved.then(on_resolve) return on_resolve(resolved) def get_resolver(self, parent_resolver): - resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) + resolver = super(ConnectionField, self).get_resolver(parent_resolver) return partial(self.connection_resolver, resolver, self.type) - -ConnectionField = IterableConnectionField