mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-22 09:36:44 +03:00
Vendor DataLoader
from aiodataloader
and move get_event_loop()
out of __init__
function. (#1459)
* Vendor DataLoader from aiodataloader and also move get_event_loop behavior from `__init__` to a property which only gets resolved when actually needed (this will solve PyTest-related to early get_event_loop() issues) * Added DataLoader's specific tests * plug `loop` parameter into `self._loop`, so that we still have the ability to pass in a custom event loop, if needed. Co-authored-by: Erik Wrede <erikwrede2@gmail.com>
This commit is contained in:
parent
20219fdc1b
commit
694c1db21e
281
graphene/utils/dataloader.py
Normal file
281
graphene/utils/dataloader.py
Normal file
|
@ -0,0 +1,281 @@
|
||||||
|
from asyncio import (
|
||||||
|
gather,
|
||||||
|
ensure_future,
|
||||||
|
get_event_loop,
|
||||||
|
iscoroutine,
|
||||||
|
iscoroutinefunction,
|
||||||
|
)
|
||||||
|
from collections import namedtuple
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from typing import List # flake8: noqa
|
||||||
|
|
||||||
|
Loader = namedtuple("Loader", "key,future")
|
||||||
|
|
||||||
|
|
||||||
|
def iscoroutinefunctionorpartial(fn):
|
||||||
|
return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoader(object):
|
||||||
|
batch = True
|
||||||
|
max_batch_size = None # type: int
|
||||||
|
cache = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_load_fn=None,
|
||||||
|
batch=None,
|
||||||
|
max_batch_size=None,
|
||||||
|
cache=None,
|
||||||
|
get_cache_key=None,
|
||||||
|
cache_map=None,
|
||||||
|
loop=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
self._loop = loop
|
||||||
|
|
||||||
|
if batch_load_fn is not None:
|
||||||
|
self.batch_load_fn = batch_load_fn
|
||||||
|
|
||||||
|
assert iscoroutinefunctionorpartial(
|
||||||
|
self.batch_load_fn
|
||||||
|
), "batch_load_fn must be coroutine. Received: {}".format(self.batch_load_fn)
|
||||||
|
|
||||||
|
if not callable(self.batch_load_fn):
|
||||||
|
raise TypeError( # pragma: no cover
|
||||||
|
(
|
||||||
|
"DataLoader must be have a batch_load_fn which accepts "
|
||||||
|
"Iterable<key> and returns Future<Iterable<value>>, but got: {}."
|
||||||
|
).format(batch_load_fn)
|
||||||
|
)
|
||||||
|
|
||||||
|
if batch is not None:
|
||||||
|
self.batch = batch # pragma: no cover
|
||||||
|
|
||||||
|
if max_batch_size is not None:
|
||||||
|
self.max_batch_size = max_batch_size
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
self.cache = cache # pragma: no cover
|
||||||
|
|
||||||
|
self.get_cache_key = get_cache_key or (lambda x: x)
|
||||||
|
|
||||||
|
self._cache = cache_map if cache_map is not None else {}
|
||||||
|
self._queue = [] # type: List[Loader]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self):
|
||||||
|
if not self._loop:
|
||||||
|
self._loop = get_event_loop()
|
||||||
|
|
||||||
|
return self._loop
|
||||||
|
|
||||||
|
def load(self, key=None):
|
||||||
|
"""
|
||||||
|
Loads a key, returning a `Future` for the value represented by that key.
|
||||||
|
"""
|
||||||
|
if key is None:
|
||||||
|
raise TypeError( # pragma: no cover
|
||||||
|
(
|
||||||
|
"The loader.load() function must be called with a value, "
|
||||||
|
"but got: {}."
|
||||||
|
).format(key)
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_key = self.get_cache_key(key)
|
||||||
|
|
||||||
|
# If caching and there is a cache-hit, return cached Future.
|
||||||
|
if self.cache:
|
||||||
|
cached_result = self._cache.get(cache_key)
|
||||||
|
if cached_result:
|
||||||
|
return cached_result
|
||||||
|
|
||||||
|
# Otherwise, produce a new Future for this value.
|
||||||
|
future = self.loop.create_future()
|
||||||
|
# If caching, cache this Future.
|
||||||
|
if self.cache:
|
||||||
|
self._cache[cache_key] = future
|
||||||
|
|
||||||
|
self.do_resolve_reject(key, future)
|
||||||
|
return future
|
||||||
|
|
||||||
|
def do_resolve_reject(self, key, future):
|
||||||
|
# Enqueue this Future to be dispatched.
|
||||||
|
self._queue.append(Loader(key=key, future=future))
|
||||||
|
# Determine if a dispatch of this queue should be scheduled.
|
||||||
|
# A single dispatch should be scheduled per queue at the time when the
|
||||||
|
# queue changes from "empty" to "full".
|
||||||
|
if len(self._queue) == 1:
|
||||||
|
if self.batch:
|
||||||
|
# If batching, schedule a task to dispatch the queue.
|
||||||
|
enqueue_post_future_job(self.loop, self)
|
||||||
|
else:
|
||||||
|
# Otherwise dispatch the (queue of one) immediately.
|
||||||
|
dispatch_queue(self) # pragma: no cover
|
||||||
|
|
||||||
|
def load_many(self, keys):
|
||||||
|
"""
|
||||||
|
Loads multiple keys, returning a list of values
|
||||||
|
|
||||||
|
>>> a, b = await my_loader.load_many([ 'a', 'b' ])
|
||||||
|
|
||||||
|
This is equivalent to the more verbose:
|
||||||
|
|
||||||
|
>>> a, b = await gather(
|
||||||
|
>>> my_loader.load('a'),
|
||||||
|
>>> my_loader.load('b')
|
||||||
|
>>> )
|
||||||
|
"""
|
||||||
|
if not isinstance(keys, Iterable):
|
||||||
|
raise TypeError( # pragma: no cover
|
||||||
|
(
|
||||||
|
"The loader.load_many() function must be called with Iterable<key> "
|
||||||
|
"but got: {}."
|
||||||
|
).format(keys)
|
||||||
|
)
|
||||||
|
|
||||||
|
return gather(*[self.load(key) for key in keys])
|
||||||
|
|
||||||
|
def clear(self, key):
|
||||||
|
"""
|
||||||
|
Clears the value at `key` from the cache, if it exists. Returns itself for
|
||||||
|
method chaining.
|
||||||
|
"""
|
||||||
|
cache_key = self.get_cache_key(key)
|
||||||
|
self._cache.pop(cache_key, None)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def clear_all(self):
|
||||||
|
"""
|
||||||
|
Clears the entire cache. To be used when some event results in unknown
|
||||||
|
invalidations across this particular `DataLoader`. Returns itself for
|
||||||
|
method chaining.
|
||||||
|
"""
|
||||||
|
self._cache.clear()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def prime(self, key, value):
|
||||||
|
"""
|
||||||
|
Adds the provied key and value to the cache. If the key already exists, no
|
||||||
|
change is made. Returns itself for method chaining.
|
||||||
|
"""
|
||||||
|
cache_key = self.get_cache_key(key)
|
||||||
|
|
||||||
|
# Only add the key if it does not already exist.
|
||||||
|
if cache_key not in self._cache:
|
||||||
|
# Cache a rejected future if the value is an Error, in order to match
|
||||||
|
# the behavior of load(key).
|
||||||
|
future = self.loop.create_future()
|
||||||
|
if isinstance(value, Exception):
|
||||||
|
future.set_exception(value)
|
||||||
|
else:
|
||||||
|
future.set_result(value)
|
||||||
|
|
||||||
|
self._cache[cache_key] = future
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def enqueue_post_future_job(loop, loader):
|
||||||
|
async def dispatch():
|
||||||
|
dispatch_queue(loader)
|
||||||
|
|
||||||
|
loop.call_soon(ensure_future, dispatch())
|
||||||
|
|
||||||
|
|
||||||
|
def get_chunks(iterable_obj, chunk_size=1):
|
||||||
|
chunk_size = max(1, chunk_size)
|
||||||
|
return (
|
||||||
|
iterable_obj[i : i + chunk_size]
|
||||||
|
for i in range(0, len(iterable_obj), chunk_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_queue(loader):
|
||||||
|
"""
|
||||||
|
Given the current state of a Loader instance, perform a batch load
|
||||||
|
from its current queue.
|
||||||
|
"""
|
||||||
|
# Take the current loader queue, replacing it with an empty queue.
|
||||||
|
queue = loader._queue
|
||||||
|
loader._queue = []
|
||||||
|
|
||||||
|
# If a max_batch_size was provided and the queue is longer, then segment the
|
||||||
|
# queue into multiple batches, otherwise treat the queue as a single batch.
|
||||||
|
max_batch_size = loader.max_batch_size
|
||||||
|
|
||||||
|
if max_batch_size and max_batch_size < len(queue):
|
||||||
|
chunks = get_chunks(queue, max_batch_size)
|
||||||
|
for chunk in chunks:
|
||||||
|
ensure_future(dispatch_queue_batch(loader, chunk))
|
||||||
|
else:
|
||||||
|
ensure_future(dispatch_queue_batch(loader, queue))
|
||||||
|
|
||||||
|
|
||||||
|
async def dispatch_queue_batch(loader, queue):
|
||||||
|
# Collect all keys to be loaded in this dispatch
|
||||||
|
keys = [loaded.key for loaded in queue]
|
||||||
|
|
||||||
|
# Call the provided batch_load_fn for this loader with the loader queue's keys.
|
||||||
|
batch_future = loader.batch_load_fn(keys)
|
||||||
|
|
||||||
|
# Assert the expected response from batch_load_fn
|
||||||
|
if not batch_future or not iscoroutine(batch_future):
|
||||||
|
return failed_dispatch( # pragma: no cover
|
||||||
|
loader,
|
||||||
|
queue,
|
||||||
|
TypeError(
|
||||||
|
(
|
||||||
|
"DataLoader must be constructed with a function which accepts "
|
||||||
|
"Iterable<key> and returns Future<Iterable<value>>, but the function did "
|
||||||
|
"not return a Coroutine: {}."
|
||||||
|
).format(batch_future)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
values = await batch_future
|
||||||
|
if not isinstance(values, Iterable):
|
||||||
|
raise TypeError( # pragma: no cover
|
||||||
|
(
|
||||||
|
"DataLoader must be constructed with a function which accepts "
|
||||||
|
"Iterable<key> and returns Future<Iterable<value>>, but the function did "
|
||||||
|
"not return a Future of a Iterable: {}."
|
||||||
|
).format(values)
|
||||||
|
)
|
||||||
|
|
||||||
|
values = list(values)
|
||||||
|
if len(values) != len(keys):
|
||||||
|
raise TypeError( # pragma: no cover
|
||||||
|
(
|
||||||
|
"DataLoader must be constructed with a function which accepts "
|
||||||
|
"Iterable<key> and returns Future<Iterable<value>>, but the function did "
|
||||||
|
"not return a Future of a Iterable with the same length as the Iterable "
|
||||||
|
"of keys."
|
||||||
|
"\n\nKeys:\n{}"
|
||||||
|
"\n\nValues:\n{}"
|
||||||
|
).format(keys, values)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step through the values, resolving or rejecting each Future in the
|
||||||
|
# loaded queue.
|
||||||
|
for loaded, value in zip(queue, values):
|
||||||
|
if isinstance(value, Exception):
|
||||||
|
loaded.future.set_exception(value)
|
||||||
|
else:
|
||||||
|
loaded.future.set_result(value)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return failed_dispatch(loader, queue, e)
|
||||||
|
|
||||||
|
|
||||||
|
def failed_dispatch(loader, queue, error):
|
||||||
|
"""
|
||||||
|
Do not cache individual loads if the entire batch dispatch fails,
|
||||||
|
but still reject each request so they do not hang.
|
||||||
|
"""
|
||||||
|
for loaded in queue:
|
||||||
|
loader.clear(loaded.key)
|
||||||
|
loaded.future.set_exception(error)
|
452
graphene/utils/tests/test_dataloader.py
Normal file
452
graphene/utils/tests/test_dataloader.py
Normal file
|
@ -0,0 +1,452 @@
|
||||||
|
from asyncio import gather
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import partial
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from graphene.utils.dataloader import DataLoader
|
||||||
|
from pytest import mark, raises
|
||||||
|
|
||||||
|
from graphene import ObjectType, String, Schema, Field, List
|
||||||
|
|
||||||
|
CHARACTERS = {
|
||||||
|
"1": {"name": "Luke Skywalker", "sibling": "3"},
|
||||||
|
"2": {"name": "Darth Vader", "sibling": None},
|
||||||
|
"3": {"name": "Leia Organa", "sibling": "1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id])
|
||||||
|
|
||||||
|
|
||||||
|
class CharacterType(ObjectType):
|
||||||
|
name = String()
|
||||||
|
sibling = Field(lambda: CharacterType)
|
||||||
|
|
||||||
|
async def resolve_sibling(character, info):
|
||||||
|
if character["sibling"]:
|
||||||
|
return await info.context.character_loader.load(character["sibling"])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
skywalker_family = List(CharacterType)
|
||||||
|
|
||||||
|
async def resolve_skywalker_family(_, info):
|
||||||
|
return await info.context.character_loader.load_many(["1", "2", "3"])
|
||||||
|
|
||||||
|
|
||||||
|
mock_batch_load_fn = Mock(
|
||||||
|
side_effect=lambda character_ids: [get_character(id) for id in character_ids]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CharacterLoader(DataLoader):
|
||||||
|
async def batch_load_fn(self, character_ids):
|
||||||
|
return mock_batch_load_fn(character_ids)
|
||||||
|
|
||||||
|
|
||||||
|
Context = namedtuple("Context", "character_loader")
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_basic_dataloader():
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
character_loader = CharacterLoader()
|
||||||
|
context = Context(character_loader=character_loader)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
{
|
||||||
|
skywalkerFamily {
|
||||||
|
name
|
||||||
|
sibling {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await schema.execute_async(query, context=context)
|
||||||
|
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {
|
||||||
|
"skywalkerFamily": [
|
||||||
|
{"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}},
|
||||||
|
{"name": "Darth Vader", "sibling": None},
|
||||||
|
{"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
assert mock_batch_load_fn.call_count == 1
|
||||||
|
assert get_character.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def id_loader(**options):
|
||||||
|
load_calls = []
|
||||||
|
|
||||||
|
async def default_resolve(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
resolve = options.pop("resolve", default_resolve)
|
||||||
|
|
||||||
|
async def fn(keys):
|
||||||
|
load_calls.append(keys)
|
||||||
|
return await resolve(keys)
|
||||||
|
# return keys
|
||||||
|
|
||||||
|
identity_loader = DataLoader(fn, **options)
|
||||||
|
return identity_loader, load_calls
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_build_a_simple_data_loader():
|
||||||
|
async def call_fn(keys):
|
||||||
|
return keys
|
||||||
|
|
||||||
|
identity_loader = DataLoader(call_fn)
|
||||||
|
|
||||||
|
promise1 = identity_loader.load(1)
|
||||||
|
|
||||||
|
value1 = await promise1
|
||||||
|
assert value1 == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_can_build_a_data_loader_from_a_partial():
|
||||||
|
value_map = {1: "one"}
|
||||||
|
|
||||||
|
async def call_fn(context, keys):
|
||||||
|
return [context.get(key) for key in keys]
|
||||||
|
|
||||||
|
partial_fn = partial(call_fn, value_map)
|
||||||
|
identity_loader = DataLoader(partial_fn)
|
||||||
|
|
||||||
|
promise1 = identity_loader.load(1)
|
||||||
|
|
||||||
|
value1 = await promise1
|
||||||
|
assert value1 == "one"
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_supports_loading_multiple_keys_in_one_call():
|
||||||
|
async def call_fn(keys):
|
||||||
|
return keys
|
||||||
|
|
||||||
|
identity_loader = DataLoader(call_fn)
|
||||||
|
|
||||||
|
promise_all = identity_loader.load_many([1, 2])
|
||||||
|
|
||||||
|
values = await promise_all
|
||||||
|
assert values == [1, 2]
|
||||||
|
|
||||||
|
promise_all = identity_loader.load_many([])
|
||||||
|
|
||||||
|
values = await promise_all
|
||||||
|
assert values == []
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_batches_multiple_requests():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
promise1 = identity_loader.load(1)
|
||||||
|
promise2 = identity_loader.load(2)
|
||||||
|
|
||||||
|
p = gather(promise1, promise2)
|
||||||
|
|
||||||
|
value1, value2 = await p
|
||||||
|
|
||||||
|
assert value1 == 1
|
||||||
|
assert value2 == 2
|
||||||
|
|
||||||
|
assert load_calls == [[1, 2]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_batches_multiple_requests_with_max_batch_sizes():
|
||||||
|
identity_loader, load_calls = id_loader(max_batch_size=2)
|
||||||
|
|
||||||
|
promise1 = identity_loader.load(1)
|
||||||
|
promise2 = identity_loader.load(2)
|
||||||
|
promise3 = identity_loader.load(3)
|
||||||
|
|
||||||
|
p = gather(promise1, promise2, promise3)
|
||||||
|
|
||||||
|
value1, value2, value3 = await p
|
||||||
|
|
||||||
|
assert value1 == 1
|
||||||
|
assert value2 == 2
|
||||||
|
assert value3 == 3
|
||||||
|
|
||||||
|
assert load_calls == [[1, 2], [3]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_coalesces_identical_requests():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
promise1 = identity_loader.load(1)
|
||||||
|
promise2 = identity_loader.load(1)
|
||||||
|
|
||||||
|
assert promise1 == promise2
|
||||||
|
p = gather(promise1, promise2)
|
||||||
|
|
||||||
|
value1, value2 = await p
|
||||||
|
|
||||||
|
assert value1 == 1
|
||||||
|
assert value2 == 1
|
||||||
|
|
||||||
|
assert load_calls == [[1]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_caches_repeated_requests():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||||
|
|
||||||
|
assert a == "A"
|
||||||
|
assert b == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"]]
|
||||||
|
|
||||||
|
a2, c = await gather(identity_loader.load("A"), identity_loader.load("C"))
|
||||||
|
|
||||||
|
assert a2 == "A"
|
||||||
|
assert c == "C"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"], ["C"]]
|
||||||
|
|
||||||
|
a3, b2, c2 = await gather(
|
||||||
|
identity_loader.load("A"), identity_loader.load("B"), identity_loader.load("C")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert a3 == "A"
|
||||||
|
assert b2 == "B"
|
||||||
|
assert c2 == "C"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"], ["C"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_clears_single_value_in_loader():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||||
|
|
||||||
|
assert a == "A"
|
||||||
|
assert b == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"]]
|
||||||
|
|
||||||
|
identity_loader.clear("A")
|
||||||
|
|
||||||
|
a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||||
|
|
||||||
|
assert a2 == "A"
|
||||||
|
assert b2 == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"], ["A"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_clears_all_values_in_loader():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||||
|
|
||||||
|
assert a == "A"
|
||||||
|
assert b == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"]]
|
||||||
|
|
||||||
|
identity_loader.clear_all()
|
||||||
|
|
||||||
|
a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||||
|
|
||||||
|
assert a2 == "A"
|
||||||
|
assert b2 == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"], ["A", "B"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_allows_priming_the_cache():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
identity_loader.prime("A", "A")
|
||||||
|
|
||||||
|
a, b = await gather(identity_loader.load("A"), identity_loader.load("B"))
|
||||||
|
|
||||||
|
assert a == "A"
|
||||||
|
assert b == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["B"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_does_not_prime_keys_that_already_exist():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
identity_loader.prime("A", "X")
|
||||||
|
|
||||||
|
a1 = await identity_loader.load("A")
|
||||||
|
b1 = await identity_loader.load("B")
|
||||||
|
|
||||||
|
assert a1 == "X"
|
||||||
|
assert b1 == "B"
|
||||||
|
|
||||||
|
identity_loader.prime("A", "Y")
|
||||||
|
identity_loader.prime("B", "Y")
|
||||||
|
|
||||||
|
a2 = await identity_loader.load("A")
|
||||||
|
b2 = await identity_loader.load("B")
|
||||||
|
|
||||||
|
assert a2 == "X"
|
||||||
|
assert b2 == "B"
|
||||||
|
|
||||||
|
assert load_calls == [["B"]]
|
||||||
|
|
||||||
|
|
||||||
|
# # Represents Errors
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_resolves_to_error_to_indicate_failure():
|
||||||
|
async def resolve(keys):
|
||||||
|
mapped_keys = [
|
||||||
|
key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys
|
||||||
|
]
|
||||||
|
return mapped_keys
|
||||||
|
|
||||||
|
even_loader, load_calls = id_loader(resolve=resolve)
|
||||||
|
|
||||||
|
with raises(Exception) as exc_info:
|
||||||
|
await even_loader.load(1)
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Odd: 1"
|
||||||
|
|
||||||
|
value2 = await even_loader.load(2)
|
||||||
|
assert value2 == 2
|
||||||
|
assert load_calls == [[1], [2]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_can_represent_failures_and_successes_simultaneously():
|
||||||
|
async def resolve(keys):
|
||||||
|
mapped_keys = [
|
||||||
|
key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys
|
||||||
|
]
|
||||||
|
return mapped_keys
|
||||||
|
|
||||||
|
even_loader, load_calls = id_loader(resolve=resolve)
|
||||||
|
|
||||||
|
promise1 = even_loader.load(1)
|
||||||
|
promise2 = even_loader.load(2)
|
||||||
|
|
||||||
|
with raises(Exception) as exc_info:
|
||||||
|
await promise1
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Odd: 1"
|
||||||
|
value2 = await promise2
|
||||||
|
assert value2 == 2
|
||||||
|
assert load_calls == [[1, 2]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_caches_failed_fetches():
|
||||||
|
async def resolve(keys):
|
||||||
|
mapped_keys = [Exception("Error: {}".format(key)) for key in keys]
|
||||||
|
return mapped_keys
|
||||||
|
|
||||||
|
error_loader, load_calls = id_loader(resolve=resolve)
|
||||||
|
|
||||||
|
with raises(Exception) as exc_info:
|
||||||
|
await error_loader.load(1)
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Error: 1"
|
||||||
|
|
||||||
|
with raises(Exception) as exc_info:
|
||||||
|
await error_loader.load(1)
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Error: 1"
|
||||||
|
|
||||||
|
assert load_calls == [[1]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_caches_failed_fetches_2():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
identity_loader.prime(1, Exception("Error: 1"))
|
||||||
|
|
||||||
|
with raises(Exception) as _:
|
||||||
|
await identity_loader.load(1)
|
||||||
|
|
||||||
|
assert load_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
# It is resilient to job queue ordering
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_batches_loads_occuring_within_promises():
|
||||||
|
identity_loader, load_calls = id_loader()
|
||||||
|
|
||||||
|
async def load_b_1():
|
||||||
|
return await load_b_2()
|
||||||
|
|
||||||
|
async def load_b_2():
|
||||||
|
return await identity_loader.load("B")
|
||||||
|
|
||||||
|
values = await gather(identity_loader.load("A"), load_b_1())
|
||||||
|
|
||||||
|
assert values == ["A", "B"]
|
||||||
|
|
||||||
|
assert load_calls == [["A", "B"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_catches_error_if_loader_resolver_fails():
|
||||||
|
exc = Exception("AOH!")
|
||||||
|
|
||||||
|
def do_resolve(x):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
a_loader, a_load_calls = id_loader(resolve=do_resolve)
|
||||||
|
|
||||||
|
with raises(Exception) as exc_info:
|
||||||
|
await a_loader.load("A1")
|
||||||
|
|
||||||
|
assert exc_info.value == exc
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_can_call_a_loader_from_a_loader():
|
||||||
|
deep_loader, deep_load_calls = id_loader()
|
||||||
|
a_loader, a_load_calls = id_loader(
|
||||||
|
resolve=lambda keys: deep_loader.load(tuple(keys))
|
||||||
|
)
|
||||||
|
b_loader, b_load_calls = id_loader(
|
||||||
|
resolve=lambda keys: deep_loader.load(tuple(keys))
|
||||||
|
)
|
||||||
|
|
||||||
|
a1, b1, a2, b2 = await gather(
|
||||||
|
a_loader.load("A1"),
|
||||||
|
b_loader.load("B1"),
|
||||||
|
a_loader.load("A2"),
|
||||||
|
b_loader.load("B2"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert a1 == "A1"
|
||||||
|
assert b1 == "B1"
|
||||||
|
assert a2 == "A2"
|
||||||
|
assert b2 == "B2"
|
||||||
|
|
||||||
|
assert a_load_calls == [["A1", "A2"]]
|
||||||
|
assert b_load_calls == [["B1", "B2"]]
|
||||||
|
assert deep_load_calls == [[("A1", "A2"), ("B1", "B2")]]
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_dataloader_clear_with_missing_key_works():
|
||||||
|
async def do_resolve(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
a_loader, a_load_calls = id_loader(resolve=do_resolve)
|
||||||
|
assert a_loader.clear("A1") == a_loader
|
|
@ -2,6 +2,10 @@
|
||||||
exclude = setup.py,docs/*,*/examples/*,graphene/pyutils/*,tests
|
exclude = setup.py,docs/*,*/examples/*,graphene/pyutils/*,tests
|
||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
|
|
||||||
|
# This is a specific ignore for Black+Flake8
|
||||||
|
# source: https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1
|
||||||
|
extend-ignore = E203
|
||||||
|
|
||||||
[coverage:run]
|
[coverage:run]
|
||||||
omit = graphene/pyutils/*,*/tests/*,graphene/types/scalars.py
|
omit = graphene/pyutils/*,*/tests/*,graphene/types/scalars.py
|
||||||
|
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -53,7 +53,6 @@ tests_require = [
|
||||||
"snapshottest>=0.6,<1",
|
"snapshottest>=0.6,<1",
|
||||||
"coveralls>=3.3,<4",
|
"coveralls>=3.3,<4",
|
||||||
"promise>=2.3,<3",
|
"promise>=2.3,<3",
|
||||||
"aiodataloader<1",
|
|
||||||
"mock>=4,<5",
|
"mock>=4,<5",
|
||||||
"pytz==2022.1",
|
"pytz==2022.1",
|
||||||
"iso8601>=1,<2",
|
"iso8601>=1,<2",
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
from collections import namedtuple
|
|
||||||
from unittest.mock import Mock
|
|
||||||
from pytest import mark
|
|
||||||
from aiodataloader import DataLoader
|
|
||||||
|
|
||||||
from graphene import ObjectType, String, Schema, Field, List
|
|
||||||
|
|
||||||
|
|
||||||
CHARACTERS = {
|
|
||||||
"1": {"name": "Luke Skywalker", "sibling": "3"},
|
|
||||||
"2": {"name": "Darth Vader", "sibling": None},
|
|
||||||
"3": {"name": "Leia Organa", "sibling": "1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id])
|
|
||||||
|
|
||||||
|
|
||||||
class CharacterType(ObjectType):
|
|
||||||
name = String()
|
|
||||||
sibling = Field(lambda: CharacterType)
|
|
||||||
|
|
||||||
async def resolve_sibling(character, info):
|
|
||||||
if character["sibling"]:
|
|
||||||
return await info.context.character_loader.load(character["sibling"])
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class Query(ObjectType):
|
|
||||||
skywalker_family = List(CharacterType)
|
|
||||||
|
|
||||||
async def resolve_skywalker_family(_, info):
|
|
||||||
return await info.context.character_loader.load_many(["1", "2", "3"])
|
|
||||||
|
|
||||||
|
|
||||||
mock_batch_load_fn = Mock(
|
|
||||||
side_effect=lambda character_ids: [get_character(id) for id in character_ids]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CharacterLoader(DataLoader):
|
|
||||||
async def batch_load_fn(self, character_ids):
|
|
||||||
return mock_batch_load_fn(character_ids)
|
|
||||||
|
|
||||||
|
|
||||||
Context = namedtuple("Context", "character_loader")
|
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
|
||||||
async def test_basic_dataloader():
|
|
||||||
schema = Schema(query=Query)
|
|
||||||
|
|
||||||
character_loader = CharacterLoader()
|
|
||||||
context = Context(character_loader=character_loader)
|
|
||||||
|
|
||||||
query = """
|
|
||||||
{
|
|
||||||
skywalkerFamily {
|
|
||||||
name
|
|
||||||
sibling {
|
|
||||||
name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = await schema.execute_async(query, context=context)
|
|
||||||
|
|
||||||
assert not result.errors
|
|
||||||
assert result.data == {
|
|
||||||
"skywalkerFamily": [
|
|
||||||
{"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}},
|
|
||||||
{"name": "Darth Vader", "sibling": None},
|
|
||||||
{"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
assert mock_batch_load_fn.call_count == 1
|
|
||||||
assert get_character.call_count == 3
|
|
Loading…
Reference in New Issue
Block a user