mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-25 11:03:58 +03:00
Vendor DataLoader
from aiodataloader
and move get_event_loop()
out of __init__
function. (#1459)
* Vendor DataLoader from aiodataloader and also move get_event_loop behavior from `__init__` to a property which only gets resolved when actually needed (this will solve PyTest-related to early get_event_loop() issues) * Added DataLoader's specific tests * plug `loop` parameter into `self._loop`, so that we still have the ability to pass in a custom event loop, if needed. Co-authored-by: Erik Wrede <erikwrede2@gmail.com>
This commit is contained in:
parent
20219fdc1b
commit
694c1db21e
281
graphene/utils/dataloader.py
Normal file
281
graphene/utils/dataloader.py
Normal file
|
@ -0,0 +1,281 @@
|
|||
from asyncio import (
|
||||
gather,
|
||||
ensure_future,
|
||||
get_event_loop,
|
||||
iscoroutine,
|
||||
iscoroutinefunction,
|
||||
)
|
||||
from collections import namedtuple
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
|
||||
from typing import List # flake8: noqa
|
||||
|
||||
Loader = namedtuple("Loader", "key,future")
|
||||
|
||||
|
||||
def iscoroutinefunctionorpartial(fn):
|
||||
return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)
|
||||
|
||||
|
||||
class DataLoader(object):
|
||||
batch = True
|
||||
max_batch_size = None # type: int
|
||||
cache = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_load_fn=None,
|
||||
batch=None,
|
||||
max_batch_size=None,
|
||||
cache=None,
|
||||
get_cache_key=None,
|
||||
cache_map=None,
|
||||
loop=None,
|
||||
):
|
||||
|
||||
self._loop = loop
|
||||
|
||||
if batch_load_fn is not None:
|
||||
self.batch_load_fn = batch_load_fn
|
||||
|
||||
assert iscoroutinefunctionorpartial(
|
||||
self.batch_load_fn
|
||||
), "batch_load_fn must be coroutine. Received: {}".format(self.batch_load_fn)
|
||||
|
||||
if not callable(self.batch_load_fn):
|
||||
raise TypeError( # pragma: no cover
|
||||
(
|
||||
"DataLoader must be have a batch_load_fn which accepts "
|
||||
"Iterable<key> and returns Future<Iterable<value>>, but got: {}."
|
||||
).format(batch_load_fn)
|
||||
)
|
||||
|
||||
if batch is not None:
|
||||
self.batch = batch # pragma: no cover
|
||||
|
||||
if max_batch_size is not None:
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
if cache is not None:
|
||||
self.cache = cache # pragma: no cover
|
||||
|
||||
self.get_cache_key = get_cache_key or (lambda x: x)
|
||||
|
||||
self._cache = cache_map if cache_map is not None else {}
|
||||
self._queue = [] # type: List[Loader]
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
if not self._loop:
|
||||
self._loop = get_event_loop()
|
||||
|
||||
return self._loop
|
||||
|
||||
def load(self, key=None):
|
||||
"""
|
||||
Loads a key, returning a `Future` for the value represented by that key.
|
||||
"""
|
||||
if key is None:
|
||||
raise TypeError( # pragma: no cover
|
||||
(
|
||||
"The loader.load() function must be called with a value, "
|
||||
"but got: {}."
|
||||
).format(key)
|
||||
)
|
||||
|
||||
cache_key = self.get_cache_key(key)
|
||||
|
||||
# If caching and there is a cache-hit, return cached Future.
|
||||
if self.cache:
|
||||
cached_result = self._cache.get(cache_key)
|
||||
if cached_result:
|
||||
return cached_result
|
||||
|
||||
# Otherwise, produce a new Future for this value.
|
||||
future = self.loop.create_future()
|
||||
# If caching, cache this Future.
|
||||
if self.cache:
|
||||
self._cache[cache_key] = future
|
||||
|
||||
self.do_resolve_reject(key, future)
|
||||
return future
|
||||
|
||||
def do_resolve_reject(self, key, future):
|
||||
# Enqueue this Future to be dispatched.
|
||||
self._queue.append(Loader(key=key, future=future))
|
||||
# Determine if a dispatch of this queue should be scheduled.
|
||||
# A single dispatch should be scheduled per queue at the time when the
|
||||
# queue changes from "empty" to "full".
|
||||
if len(self._queue) == 1:
|
||||
if self.batch:
|
||||
# If batching, schedule a task to dispatch the queue.
|
||||
enqueue_post_future_job(self.loop, self)
|
||||
else:
|
||||
# Otherwise dispatch the (queue of one) immediately.
|
||||
dispatch_queue(self) # pragma: no cover
|
||||
|
||||
def load_many(self, keys):
|
||||
"""
|
||||
Loads multiple keys, returning a list of values
|
||||
|
||||
>>> a, b = await my_loader.load_many([ 'a', 'b' ])
|
||||
|
||||
This is equivalent to the more verbose:
|
||||
|
||||
>>> a, b = await gather(
|
||||
>>> my_loader.load('a'),
|
||||
>>> my_loader.load('b')
|
||||
>>> )
|
||||
"""
|
||||
if not isinstance(keys, Iterable):
|
||||
raise TypeError( # pragma: no cover
|
||||
(
|
||||
"The loader.load_many() function must be called with Iterable<key> "
|
||||
"but got: {}."
|
||||
).format(keys)
|
||||
)
|
||||
|
||||
return gather(*[self.load(key) for key in keys])
|
||||
|
||||
def clear(self, key):
|
||||
"""
|
||||
Clears the value at `key` from the cache, if it exists. Returns itself for
|
||||
method chaining.
|
||||
"""
|
||||
cache_key = self.get_cache_key(key)
|
||||
self._cache.pop(cache_key, None)
|
||||
return self
|
||||
|
||||
def clear_all(self):
|
||||
"""
|
||||
Clears the entire cache. To be used when some event results in unknown
|
||||
invalidations across this particular `DataLoader`. Returns itself for
|
||||
method chaining.
|
||||
"""
|
||||
self._cache.clear()
|
||||
return self
|
||||
|
||||
def prime(self, key, value):
|
||||
"""
|
||||
Adds the provied key and value to the cache. If the key already exists, no
|
||||
change is made. Returns itself for method chaining.
|
||||
"""
|
||||
cache_key = self.get_cache_key(key)
|
||||
|
||||
# Only add the key if it does not already exist.
|
||||
if cache_key not in self._cache:
|
||||
# Cache a rejected future if the value is an Error, in order to match
|
||||
# the behavior of load(key).
|
||||
future = self.loop.create_future()
|
||||
if isinstance(value, Exception):
|
||||
future.set_exception(value)
|
||||
else:
|
||||
future.set_result(value)
|
||||
|
||||
self._cache[cache_key] = future
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def enqueue_post_future_job(loop, loader):
|
||||
async def dispatch():
|
||||
dispatch_queue(loader)
|
||||
|
||||
loop.call_soon(ensure_future, dispatch())
|
||||
|
||||
|
||||
def get_chunks(iterable_obj, chunk_size=1):
|
||||
chunk_size = max(1, chunk_size)
|
||||
return (
|
||||
iterable_obj[i : i + chunk_size]
|
||||
for i in range(0, len(iterable_obj), chunk_size)
|
||||
)
|
||||
|
||||
|
||||
def dispatch_queue(loader):
|
||||
"""
|
||||
Given the current state of a Loader instance, perform a batch load
|
||||
from its current queue.
|
||||
"""
|
||||
# Take the current loader queue, replacing it with an empty queue.
|
||||
queue = loader._queue
|
||||
loader._queue = []
|
||||
|
||||
# If a max_batch_size was provided and the queue is longer, then segment the
|
||||
# queue into multiple batches, otherwise treat the queue as a single batch.
|
||||
max_batch_size = loader.max_batch_size
|
||||
|
||||
if max_batch_size and max_batch_size < len(queue):
|
||||
chunks = get_chunks(queue, max_batch_size)
|
||||
for chunk in chunks:
|
||||
ensure_future(dispatch_queue_batch(loader, chunk))
|
||||
else:
|
||||
ensure_future(dispatch_queue_batch(loader, queue))
|
||||
|
||||
|
||||
async def dispatch_queue_batch(loader, queue):
|
||||
# Collect all keys to be loaded in this dispatch
|
||||
keys = [loaded.key for loaded in queue]
|
||||
|
||||
# Call the provided batch_load_fn for this loader with the loader queue's keys.
|
||||
batch_future = loader.batch_load_fn(keys)
|
||||
|
||||
# Assert the expected response from batch_load_fn
|
||||
if not batch_future or not iscoroutine(batch_future):
|
||||
return failed_dispatch( # pragma: no cover
|
||||
loader,
|
||||
queue,
|
||||
TypeError(
|
||||
(
|
||||
"DataLoader must be constructed with a function which accepts "
|
||||
"Iterable<key> and returns Future<Iterable<value>>, but the function did "
|
||||
"not return a Coroutine: {}."
|
||||
).format(batch_future)
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
values = await batch_future
|
||||
if not isinstance(values, Iterable):
|
||||
raise TypeError( # pragma: no cover
|
||||
(
|
||||
"DataLoader must be constructed with a function which accepts "
|
||||
"Iterable<key> and returns Future<Iterable<value>>, but the function did "
|
||||
"not return a Future of a Iterable: {}."
|
||||
).format(values)
|
||||
)
|
||||
|
||||
values = list(values)
|
||||
if len(values) != len(keys):
|
||||
raise TypeError( # pragma: no cover
|
||||
(
|
||||
"DataLoader must be constructed with a function which accepts "
|
||||
"Iterable<key> and returns Future<Iterable<value>>, but the function did "
|
||||
"not return a Future of a Iterable with the same length as the Iterable "
|
||||
"of keys."
|
||||
"\n\nKeys:\n{}"
|
||||
"\n\nValues:\n{}"
|
||||
).format(keys, values)
|
||||
)
|
||||
|
||||
# Step through the values, resolving or rejecting each Future in the
|
||||
# loaded queue.
|
||||
for loaded, value in zip(queue, values):
|
||||
if isinstance(value, Exception):
|
||||
loaded.future.set_exception(value)
|
||||
else:
|
||||
loaded.future.set_result(value)
|
||||
|
||||
except Exception as e:
|
||||
return failed_dispatch(loader, queue, e)
|
||||
|
||||
|
||||
def failed_dispatch(loader, queue, error):
|
||||
"""
|
||||
Do not cache individual loads if the entire batch dispatch fails,
|
||||
but still reject each request so they do not hang.
|
||||
"""
|
||||
for loaded in queue:
|
||||
loader.clear(loaded.key)
|
||||
loaded.future.set_exception(error)
|
452
graphene/utils/tests/test_dataloader.py
Normal file
452
graphene/utils/tests/test_dataloader.py
Normal file
|
@ -0,0 +1,452 @@
|
|||
from asyncio import gather
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from unittest.mock import Mock
|
||||
|
||||
from graphene.utils.dataloader import DataLoader
|
||||
from pytest import mark, raises
|
||||
|
||||
from graphene import ObjectType, String, Schema, Field, List
|
||||
|
||||
CHARACTERS = {
|
||||
"1": {"name": "Luke Skywalker", "sibling": "3"},
|
||||
"2": {"name": "Darth Vader", "sibling": None},
|
||||
"3": {"name": "Leia Organa", "sibling": "1"},
|
||||
}
|
||||
|
||||
get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id])
|
||||
|
||||
|
||||
class CharacterType(ObjectType):
|
||||
name = String()
|
||||
sibling = Field(lambda: CharacterType)
|
||||
|
||||
async def resolve_sibling(character, info):
|
||||
if character["sibling"]:
|
||||
return await info.context.character_loader.load(character["sibling"])
|
||||
return None
|
||||
|
||||
|
||||
class Query(ObjectType):
|
||||
skywalker_family = List(CharacterType)
|
||||
|
||||
async def resolve_skywalker_family(_, info):
|
||||
return await info.context.character_loader.load_many(["1", "2", "3"])
|
||||
|
||||
|
||||
mock_batch_load_fn = Mock(
|
||||
side_effect=lambda character_ids: [get_character(id) for id in character_ids]
|
||||
)
|
||||
|
||||
|
||||
class CharacterLoader(DataLoader):
|
||||
async def batch_load_fn(self, character_ids):
|
||||
return mock_batch_load_fn(character_ids)
|
||||
|
||||
|
||||
Context = namedtuple("Context", "character_loader")
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_basic_dataloader():
|
||||
schema = Schema(query=Query)
|
||||
|
||||
character_loader = CharacterLoader()
|
||||
context = Context(character_loader=character_loader)
|
||||
|
||||
query = """
|
||||
{
|
||||
skywalkerFamily {
|
||||
name
|
||||
sibling {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = await schema.execute_async(query, context=context)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"skywalkerFamily": [
|
||||
{"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}},
|
||||
{"name": "Darth Vader", "sibling": None},
|
||||
{"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}},
|
||||
]
|
||||
}
|
||||
|
||||
assert mock_batch_load_fn.call_count == 1
|
||||
assert get_character.call_count == 3
|
||||
|
||||
|
||||
def id_loader(**options):
|
||||
load_calls = []
|
||||
|
||||
async def default_resolve(x):
|
||||
return x
|
||||
|
||||
resolve = options.pop("resolve", default_resolve)
|
||||
|
||||
async def fn(keys):
|
||||
load_calls.append(keys)
|
||||
return await resolve(keys)
|
||||
# return keys
|
||||
|
||||
identity_loader = DataLoader(fn, **options)
|
||||
return identity_loader, load_calls
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_build_a_simple_data_loader():
|
||||
async def call_fn(keys):
|
||||
return keys
|
||||
|
||||
identity_loader = DataLoader(call_fn)
|
||||
|
||||
promise1 = identity_loader.load(1)
|
||||
|
||||
value1 = await promise1
|
||||
assert value1 == 1
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_can_build_a_data_loader_from_a_partial():
|
||||
value_map = {1: "one"}
|
||||
|
||||
async def call_fn(context, keys):
|
||||
return [context.get(key) for key in keys]
|
||||
|
||||
partial_fn = partial(call_fn, value_map)
|
||||
identity_loader = DataLoader(partial_fn)
|
||||
|
||||
promise1 = identity_loader.load(1)
|
||||
|
||||
value1 = await promise1
|
||||
assert value1 == "one"
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_supports_loading_multiple_keys_in_one_call():
|
||||
async def call_fn(keys):
|
||||
return keys
|
||||
|
||||
identity_loader = DataLoader(call_fn)
|
||||
|
||||
promise_all = identity_loader.load_many([1, 2])
|
||||
|
||||
values = await promise_all
|
||||
assert values == [1, 2]
|
||||
|
||||
promise_all = identity_loader.load_many([])
|
||||
|
||||
values = await promise_all
|
||||
assert values == []
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_batches_multiple_requests():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
promise1 = identity_loader.load(1)
|
||||
promise2 = identity_loader.load(2)
|
||||
|
||||
p = gather(promise1, promise2)
|
||||
|
||||
value1, value2 = await p
|
||||
|
||||
assert value1 == 1
|
||||
assert value2 == 2
|
||||
|
||||
assert load_calls == [[1, 2]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_batches_multiple_requests_with_max_batch_sizes():
|
||||
identity_loader, load_calls = id_loader(max_batch_size=2)
|
||||
|
||||
promise1 = identity_loader.load(1)
|
||||
promise2 = identity_loader.load(2)
|
||||
promise3 = identity_loader.load(3)
|
||||
|
||||
p = gather(promise1, promise2, promise3)
|
||||
|
||||
value1, value2, value3 = await p
|
||||
|
||||
assert value1 == 1
|
||||
assert value2 == 2
|
||||
assert value3 == 3
|
||||
|
||||
assert load_calls == [[1, 2], [3]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_coalesces_identical_requests():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
promise1 = identity_loader.load(1)
|
||||
promise2 = identity_loader.load(1)
|
||||
|
||||
assert promise1 == promise2
|
||||
p = gather(promise1, promise2)
|
||||
|
||||
value1, value2 = await p
|
||||
|
||||
assert value1 == 1
|
||||
assert value2 == 1
|
||||
|
||||
assert load_calls == [[1]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_caches_repeated_requests():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||
|
||||
assert a == "A"
|
||||
assert b == "B"
|
||||
|
||||
assert load_calls == [["A", "B"]]
|
||||
|
||||
a2, c = await gather(identity_loader.load("A"), identity_loader.load("C"))
|
||||
|
||||
assert a2 == "A"
|
||||
assert c == "C"
|
||||
|
||||
assert load_calls == [["A", "B"], ["C"]]
|
||||
|
||||
a3, b2, c2 = await gather(
|
||||
identity_loader.load("A"), identity_loader.load("B"), identity_loader.load("C")
|
||||
)
|
||||
|
||||
assert a3 == "A"
|
||||
assert b2 == "B"
|
||||
assert c2 == "C"
|
||||
|
||||
assert load_calls == [["A", "B"], ["C"]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_clears_single_value_in_loader():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||
|
||||
assert a == "A"
|
||||
assert b == "B"
|
||||
|
||||
assert load_calls == [["A", "B"]]
|
||||
|
||||
identity_loader.clear("A")
|
||||
|
||||
a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||
|
||||
assert a2 == "A"
|
||||
assert b2 == "B"
|
||||
|
||||
assert load_calls == [["A", "B"], ["A"]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_clears_all_values_in_loader():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||
|
||||
assert a == "A"
|
||||
assert b == "B"
|
||||
|
||||
assert load_calls == [["A", "B"]]
|
||||
|
||||
identity_loader.clear_all()
|
||||
|
||||
a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||
|
||||
assert a2 == "A"
|
||||
assert b2 == "B"
|
||||
|
||||
assert load_calls == [["A", "B"], ["A", "B"]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_allows_priming_the_cache():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
identity_loader.prime("A", "A")
|
||||
|
||||
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||
|
||||
assert a == "A"
|
||||
assert b == "B"
|
||||
|
||||
assert load_calls == [["B"]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_does_not_prime_keys_that_already_exist():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
identity_loader.prime("A", "X")
|
||||
|
||||
a1 = await identity_loader.load("A")
|
||||
b1 = await identity_loader.load("B")
|
||||
|
||||
assert a1 == "X"
|
||||
assert b1 == "B"
|
||||
|
||||
identity_loader.prime("A", "Y")
|
||||
identity_loader.prime("B", "Y")
|
||||
|
||||
a2 = await identity_loader.load("A")
|
||||
b2 = await identity_loader.load("B")
|
||||
|
||||
assert a2 == "X"
|
||||
assert b2 == "B"
|
||||
|
||||
assert load_calls == [["B"]]
|
||||
|
||||
|
||||
# # Represents Errors
|
||||
@mark.asyncio
|
||||
async def test_resolves_to_error_to_indicate_failure():
|
||||
async def resolve(keys):
|
||||
mapped_keys = [
|
||||
key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys
|
||||
]
|
||||
return mapped_keys
|
||||
|
||||
even_loader, load_calls = id_loader(resolve=resolve)
|
||||
|
||||
with raises(Exception) as exc_info:
|
||||
await even_loader.load(1)
|
||||
|
||||
assert str(exc_info.value) == "Odd: 1"
|
||||
|
||||
value2 = await even_loader.load(2)
|
||||
assert value2 == 2
|
||||
assert load_calls == [[1], [2]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_can_represent_failures_and_successes_simultaneously():
|
||||
async def resolve(keys):
|
||||
mapped_keys = [
|
||||
key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys
|
||||
]
|
||||
return mapped_keys
|
||||
|
||||
even_loader, load_calls = id_loader(resolve=resolve)
|
||||
|
||||
promise1 = even_loader.load(1)
|
||||
promise2 = even_loader.load(2)
|
||||
|
||||
with raises(Exception) as exc_info:
|
||||
await promise1
|
||||
|
||||
assert str(exc_info.value) == "Odd: 1"
|
||||
value2 = await promise2
|
||||
assert value2 == 2
|
||||
assert load_calls == [[1, 2]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_caches_failed_fetches():
|
||||
async def resolve(keys):
|
||||
mapped_keys = [Exception("Error: {}".format(key)) for key in keys]
|
||||
return mapped_keys
|
||||
|
||||
error_loader, load_calls = id_loader(resolve=resolve)
|
||||
|
||||
with raises(Exception) as exc_info:
|
||||
await error_loader.load(1)
|
||||
|
||||
assert str(exc_info.value) == "Error: 1"
|
||||
|
||||
with raises(Exception) as exc_info:
|
||||
await error_loader.load(1)
|
||||
|
||||
assert str(exc_info.value) == "Error: 1"
|
||||
|
||||
assert load_calls == [[1]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_caches_failed_fetches_2():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
identity_loader.prime(1, Exception("Error: 1"))
|
||||
|
||||
with raises(Exception) as _:
|
||||
await identity_loader.load(1)
|
||||
|
||||
assert load_calls == []
|
||||
|
||||
|
||||
# It is resilient to job queue ordering
|
||||
@mark.asyncio
|
||||
async def test_batches_loads_occuring_within_promises():
|
||||
identity_loader, load_calls = id_loader()
|
||||
|
||||
async def load_b_1():
|
||||
return await load_b_2()
|
||||
|
||||
async def load_b_2():
|
||||
return await identity_loader.load("B")
|
||||
|
||||
values = await gather(identity_loader.load("A"), load_b_1())
|
||||
|
||||
assert values == ["A", "B"]
|
||||
|
||||
assert load_calls == [["A", "B"]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_catches_error_if_loader_resolver_fails():
|
||||
exc = Exception("AOH!")
|
||||
|
||||
def do_resolve(x):
|
||||
raise exc
|
||||
|
||||
a_loader, a_load_calls = id_loader(resolve=do_resolve)
|
||||
|
||||
with raises(Exception) as exc_info:
|
||||
await a_loader.load("A1")
|
||||
|
||||
assert exc_info.value == exc
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_can_call_a_loader_from_a_loader():
|
||||
deep_loader, deep_load_calls = id_loader()
|
||||
a_loader, a_load_calls = id_loader(
|
||||
resolve=lambda keys: deep_loader.load(tuple(keys))
|
||||
)
|
||||
b_loader, b_load_calls = id_loader(
|
||||
resolve=lambda keys: deep_loader.load(tuple(keys))
|
||||
)
|
||||
|
||||
a1, b1, a2, b2 = await gather(
|
||||
a_loader.load("A1"),
|
||||
b_loader.load("B1"),
|
||||
a_loader.load("A2"),
|
||||
b_loader.load("B2"),
|
||||
)
|
||||
|
||||
assert a1 == "A1"
|
||||
assert b1 == "B1"
|
||||
assert a2 == "A2"
|
||||
assert b2 == "B2"
|
||||
|
||||
assert a_load_calls == [["A1", "A2"]]
|
||||
assert b_load_calls == [["B1", "B2"]]
|
||||
assert deep_load_calls == [[("A1", "A2"), ("B1", "B2")]]
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_dataloader_clear_with_missing_key_works():
|
||||
async def do_resolve(x):
|
||||
return x
|
||||
|
||||
a_loader, a_load_calls = id_loader(resolve=do_resolve)
|
||||
assert a_loader.clear("A1") == a_loader
|
|
@ -2,6 +2,10 @@
|
|||
exclude = setup.py,docs/*,*/examples/*,graphene/pyutils/*,tests
|
||||
max-line-length = 120
|
||||
|
||||
# This is a specific ignore for Black+Flake8
|
||||
# source: https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1
|
||||
extend-ignore = E203
|
||||
|
||||
[coverage:run]
|
||||
omit = graphene/pyutils/*,*/tests/*,graphene/types/scalars.py
|
||||
|
||||
|
|
1
setup.py
1
setup.py
|
@ -53,7 +53,6 @@ tests_require = [
|
|||
"snapshottest>=0.6,<1",
|
||||
"coveralls>=3.3,<4",
|
||||
"promise>=2.3,<3",
|
||||
"aiodataloader<1",
|
||||
"mock>=4,<5",
|
||||
"pytz==2022.1",
|
||||
"iso8601>=1,<2",
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
from collections import namedtuple
|
||||
from unittest.mock import Mock
|
||||
from pytest import mark
|
||||
from aiodataloader import DataLoader
|
||||
|
||||
from graphene import ObjectType, String, Schema, Field, List
|
||||
|
||||
|
||||
CHARACTERS = {
|
||||
"1": {"name": "Luke Skywalker", "sibling": "3"},
|
||||
"2": {"name": "Darth Vader", "sibling": None},
|
||||
"3": {"name": "Leia Organa", "sibling": "1"},
|
||||
}
|
||||
|
||||
|
||||
get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id])
|
||||
|
||||
|
||||
class CharacterType(ObjectType):
|
||||
name = String()
|
||||
sibling = Field(lambda: CharacterType)
|
||||
|
||||
async def resolve_sibling(character, info):
|
||||
if character["sibling"]:
|
||||
return await info.context.character_loader.load(character["sibling"])
|
||||
return None
|
||||
|
||||
|
||||
class Query(ObjectType):
|
||||
skywalker_family = List(CharacterType)
|
||||
|
||||
async def resolve_skywalker_family(_, info):
|
||||
return await info.context.character_loader.load_many(["1", "2", "3"])
|
||||
|
||||
|
||||
mock_batch_load_fn = Mock(
|
||||
side_effect=lambda character_ids: [get_character(id) for id in character_ids]
|
||||
)
|
||||
|
||||
|
||||
class CharacterLoader(DataLoader):
|
||||
async def batch_load_fn(self, character_ids):
|
||||
return mock_batch_load_fn(character_ids)
|
||||
|
||||
|
||||
Context = namedtuple("Context", "character_loader")
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_basic_dataloader():
|
||||
schema = Schema(query=Query)
|
||||
|
||||
character_loader = CharacterLoader()
|
||||
context = Context(character_loader=character_loader)
|
||||
|
||||
query = """
|
||||
{
|
||||
skywalkerFamily {
|
||||
name
|
||||
sibling {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = await schema.execute_async(query, context=context)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"skywalkerFamily": [
|
||||
{"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}},
|
||||
{"name": "Darth Vader", "sibling": None},
|
||||
{"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}},
|
||||
]
|
||||
}
|
||||
|
||||
assert mock_batch_load_fn.call_count == 1
|
||||
assert get_character.call_count == 3
|
Loading…
Reference in New Issue
Block a user