diff --git a/graphene/utils/dataloader.py b/graphene/utils/dataloader.py new file mode 100644 index 00000000..143558aa --- /dev/null +++ b/graphene/utils/dataloader.py @@ -0,0 +1,281 @@ +from asyncio import ( + gather, + ensure_future, + get_event_loop, + iscoroutine, + iscoroutinefunction, +) +from collections import namedtuple +from collections.abc import Iterable +from functools import partial + +from typing import List # flake8: noqa + +Loader = namedtuple("Loader", "key,future") + + +def iscoroutinefunctionorpartial(fn): + return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn) + + +class DataLoader(object): + batch = True + max_batch_size = None # type: int + cache = True + + def __init__( + self, + batch_load_fn=None, + batch=None, + max_batch_size=None, + cache=None, + get_cache_key=None, + cache_map=None, + loop=None, + ): + + self._loop = loop + + if batch_load_fn is not None: + self.batch_load_fn = batch_load_fn + + assert iscoroutinefunctionorpartial( + self.batch_load_fn + ), "batch_load_fn must be coroutine. Received: {}".format(self.batch_load_fn) + + if not callable(self.batch_load_fn): + raise TypeError( # pragma: no cover + ( + "DataLoader must be have a batch_load_fn which accepts " + "Iterable and returns Future>, but got: {}." + ).format(batch_load_fn) + ) + + if batch is not None: + self.batch = batch # pragma: no cover + + if max_batch_size is not None: + self.max_batch_size = max_batch_size + + if cache is not None: + self.cache = cache # pragma: no cover + + self.get_cache_key = get_cache_key or (lambda x: x) + + self._cache = cache_map if cache_map is not None else {} + self._queue = [] # type: List[Loader] + + @property + def loop(self): + if not self._loop: + self._loop = get_event_loop() + + return self._loop + + def load(self, key=None): + """ + Loads a key, returning a `Future` for the value represented by that key. + """ + if key is None: + raise TypeError( # pragma: no cover + ( + "The loader.load() function must be called with a value, " + "but got: {}." + ).format(key) + ) + + cache_key = self.get_cache_key(key) + + # If caching and there is a cache-hit, return cached Future. + if self.cache: + cached_result = self._cache.get(cache_key) + if cached_result: + return cached_result + + # Otherwise, produce a new Future for this value. + future = self.loop.create_future() + # If caching, cache this Future. + if self.cache: + self._cache[cache_key] = future + + self.do_resolve_reject(key, future) + return future + + def do_resolve_reject(self, key, future): + # Enqueue this Future to be dispatched. + self._queue.append(Loader(key=key, future=future)) + # Determine if a dispatch of this queue should be scheduled. + # A single dispatch should be scheduled per queue at the time when the + # queue changes from "empty" to "full". + if len(self._queue) == 1: + if self.batch: + # If batching, schedule a task to dispatch the queue. + enqueue_post_future_job(self.loop, self) + else: + # Otherwise dispatch the (queue of one) immediately. + dispatch_queue(self) # pragma: no cover + + def load_many(self, keys): + """ + Loads multiple keys, returning a list of values + + >>> a, b = await my_loader.load_many([ 'a', 'b' ]) + + This is equivalent to the more verbose: + + >>> a, b = await gather( + >>> my_loader.load('a'), + >>> my_loader.load('b') + >>> ) + """ + if not isinstance(keys, Iterable): + raise TypeError( # pragma: no cover + ( + "The loader.load_many() function must be called with Iterable " + "but got: {}." + ).format(keys) + ) + + return gather(*[self.load(key) for key in keys]) + + def clear(self, key): + """ + Clears the value at `key` from the cache, if it exists. Returns itself for + method chaining. + """ + cache_key = self.get_cache_key(key) + self._cache.pop(cache_key, None) + return self + + def clear_all(self): + """ + Clears the entire cache. To be used when some event results in unknown + invalidations across this particular `DataLoader`. Returns itself for + method chaining. + """ + self._cache.clear() + return self + + def prime(self, key, value): + """ + Adds the provied key and value to the cache. If the key already exists, no + change is made. Returns itself for method chaining. + """ + cache_key = self.get_cache_key(key) + + # Only add the key if it does not already exist. + if cache_key not in self._cache: + # Cache a rejected future if the value is an Error, in order to match + # the behavior of load(key). + future = self.loop.create_future() + if isinstance(value, Exception): + future.set_exception(value) + else: + future.set_result(value) + + self._cache[cache_key] = future + + return self + + +def enqueue_post_future_job(loop, loader): + async def dispatch(): + dispatch_queue(loader) + + loop.call_soon(ensure_future, dispatch()) + + +def get_chunks(iterable_obj, chunk_size=1): + chunk_size = max(1, chunk_size) + return ( + iterable_obj[i : i + chunk_size] + for i in range(0, len(iterable_obj), chunk_size) + ) + + +def dispatch_queue(loader): + """ + Given the current state of a Loader instance, perform a batch load + from its current queue. + """ + # Take the current loader queue, replacing it with an empty queue. + queue = loader._queue + loader._queue = [] + + # If a max_batch_size was provided and the queue is longer, then segment the + # queue into multiple batches, otherwise treat the queue as a single batch. + max_batch_size = loader.max_batch_size + + if max_batch_size and max_batch_size < len(queue): + chunks = get_chunks(queue, max_batch_size) + for chunk in chunks: + ensure_future(dispatch_queue_batch(loader, chunk)) + else: + ensure_future(dispatch_queue_batch(loader, queue)) + + +async def dispatch_queue_batch(loader, queue): + # Collect all keys to be loaded in this dispatch + keys = [loaded.key for loaded in queue] + + # Call the provided batch_load_fn for this loader with the loader queue's keys. + batch_future = loader.batch_load_fn(keys) + + # Assert the expected response from batch_load_fn + if not batch_future or not iscoroutine(batch_future): + return failed_dispatch( # pragma: no cover + loader, + queue, + TypeError( + ( + "DataLoader must be constructed with a function which accepts " + "Iterable and returns Future>, but the function did " + "not return a Coroutine: {}." + ).format(batch_future) + ), + ) + + try: + values = await batch_future + if not isinstance(values, Iterable): + raise TypeError( # pragma: no cover + ( + "DataLoader must be constructed with a function which accepts " + "Iterable and returns Future>, but the function did " + "not return a Future of a Iterable: {}." + ).format(values) + ) + + values = list(values) + if len(values) != len(keys): + raise TypeError( # pragma: no cover + ( + "DataLoader must be constructed with a function which accepts " + "Iterable and returns Future>, but the function did " + "not return a Future of a Iterable with the same length as the Iterable " + "of keys." + "\n\nKeys:\n{}" + "\n\nValues:\n{}" + ).format(keys, values) + ) + + # Step through the values, resolving or rejecting each Future in the + # loaded queue. + for loaded, value in zip(queue, values): + if isinstance(value, Exception): + loaded.future.set_exception(value) + else: + loaded.future.set_result(value) + + except Exception as e: + return failed_dispatch(loader, queue, e) + + +def failed_dispatch(loader, queue, error): + """ + Do not cache individual loads if the entire batch dispatch fails, + but still reject each request so they do not hang. + """ + for loaded in queue: + loader.clear(loaded.key) + loaded.future.set_exception(error) diff --git a/graphene/utils/tests/test_dataloader.py b/graphene/utils/tests/test_dataloader.py new file mode 100644 index 00000000..257f6b4d --- /dev/null +++ b/graphene/utils/tests/test_dataloader.py @@ -0,0 +1,452 @@ +from asyncio import gather +from collections import namedtuple +from functools import partial +from unittest.mock import Mock + +from graphene.utils.dataloader import DataLoader +from pytest import mark, raises + +from graphene import ObjectType, String, Schema, Field, List + +CHARACTERS = { + "1": {"name": "Luke Skywalker", "sibling": "3"}, + "2": {"name": "Darth Vader", "sibling": None}, + "3": {"name": "Leia Organa", "sibling": "1"}, +} + +get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id]) + + +class CharacterType(ObjectType): + name = String() + sibling = Field(lambda: CharacterType) + + async def resolve_sibling(character, info): + if character["sibling"]: + return await info.context.character_loader.load(character["sibling"]) + return None + + +class Query(ObjectType): + skywalker_family = List(CharacterType) + + async def resolve_skywalker_family(_, info): + return await info.context.character_loader.load_many(["1", "2", "3"]) + + +mock_batch_load_fn = Mock( + side_effect=lambda character_ids: [get_character(id) for id in character_ids] +) + + +class CharacterLoader(DataLoader): + async def batch_load_fn(self, character_ids): + return mock_batch_load_fn(character_ids) + + +Context = namedtuple("Context", "character_loader") + + +@mark.asyncio +async def test_basic_dataloader(): + schema = Schema(query=Query) + + character_loader = CharacterLoader() + context = Context(character_loader=character_loader) + + query = """ + { + skywalkerFamily { + name + sibling { + name + } + } + } + """ + + result = await schema.execute_async(query, context=context) + + assert not result.errors + assert result.data == { + "skywalkerFamily": [ + {"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}}, + {"name": "Darth Vader", "sibling": None}, + {"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}}, + ] + } + + assert mock_batch_load_fn.call_count == 1 + assert get_character.call_count == 3 + + +def id_loader(**options): + load_calls = [] + + async def default_resolve(x): + return x + + resolve = options.pop("resolve", default_resolve) + + async def fn(keys): + load_calls.append(keys) + return await resolve(keys) + # return keys + + identity_loader = DataLoader(fn, **options) + return identity_loader, load_calls + + +@mark.asyncio +async def test_build_a_simple_data_loader(): + async def call_fn(keys): + return keys + + identity_loader = DataLoader(call_fn) + + promise1 = identity_loader.load(1) + + value1 = await promise1 + assert value1 == 1 + + +@mark.asyncio +async def test_can_build_a_data_loader_from_a_partial(): + value_map = {1: "one"} + + async def call_fn(context, keys): + return [context.get(key) for key in keys] + + partial_fn = partial(call_fn, value_map) + identity_loader = DataLoader(partial_fn) + + promise1 = identity_loader.load(1) + + value1 = await promise1 + assert value1 == "one" + + +@mark.asyncio +async def test_supports_loading_multiple_keys_in_one_call(): + async def call_fn(keys): + return keys + + identity_loader = DataLoader(call_fn) + + promise_all = identity_loader.load_many([1, 2]) + + values = await promise_all + assert values == [1, 2] + + promise_all = identity_loader.load_many([]) + + values = await promise_all + assert values == [] + + +@mark.asyncio +async def test_batches_multiple_requests(): + identity_loader, load_calls = id_loader() + + promise1 = identity_loader.load(1) + promise2 = identity_loader.load(2) + + p = gather(promise1, promise2) + + value1, value2 = await p + + assert value1 == 1 + assert value2 == 2 + + assert load_calls == [[1, 2]] + + +@mark.asyncio +async def test_batches_multiple_requests_with_max_batch_sizes(): + identity_loader, load_calls = id_loader(max_batch_size=2) + + promise1 = identity_loader.load(1) + promise2 = identity_loader.load(2) + promise3 = identity_loader.load(3) + + p = gather(promise1, promise2, promise3) + + value1, value2, value3 = await p + + assert value1 == 1 + assert value2 == 2 + assert value3 == 3 + + assert load_calls == [[1, 2], [3]] + + +@mark.asyncio +async def test_coalesces_identical_requests(): + identity_loader, load_calls = id_loader() + + promise1 = identity_loader.load(1) + promise2 = identity_loader.load(1) + + assert promise1 == promise2 + p = gather(promise1, promise2) + + value1, value2 = await p + + assert value1 == 1 + assert value2 == 1 + + assert load_calls == [[1]] + + +@mark.asyncio +async def test_caches_repeated_requests(): + identity_loader, load_calls = id_loader() + + a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) + + assert a == "A" + assert b == "B" + + assert load_calls == [["A", "B"]] + + a2, c = await gather(identity_loader.load("A"), identity_loader.load("C")) + + assert a2 == "A" + assert c == "C" + + assert load_calls == [["A", "B"], ["C"]] + + a3, b2, c2 = await gather( + identity_loader.load("A"), identity_loader.load("B"), identity_loader.load("C") + ) + + assert a3 == "A" + assert b2 == "B" + assert c2 == "C" + + assert load_calls == [["A", "B"], ["C"]] + + +@mark.asyncio +async def test_clears_single_value_in_loader(): + identity_loader, load_calls = id_loader() + + a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) + + assert a == "A" + assert b == "B" + + assert load_calls == [["A", "B"]] + + identity_loader.clear("A") + + a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B")) + + assert a2 == "A" + assert b2 == "B" + + assert load_calls == [["A", "B"], ["A"]] + + +@mark.asyncio +async def test_clears_all_values_in_loader(): + identity_loader, load_calls = id_loader() + + a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) + + assert a == "A" + assert b == "B" + + assert load_calls == [["A", "B"]] + + identity_loader.clear_all() + + a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B")) + + assert a2 == "A" + assert b2 == "B" + + assert load_calls == [["A", "B"], ["A", "B"]] + + +@mark.asyncio +async def test_allows_priming_the_cache(): + identity_loader, load_calls = id_loader() + + identity_loader.prime("A", "A") + + a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) + + assert a == "A" + assert b == "B" + + assert load_calls == [["B"]] + + +@mark.asyncio +async def test_does_not_prime_keys_that_already_exist(): + identity_loader, load_calls = id_loader() + + identity_loader.prime("A", "X") + + a1 = await identity_loader.load("A") + b1 = await identity_loader.load("B") + + assert a1 == "X" + assert b1 == "B" + + identity_loader.prime("A", "Y") + identity_loader.prime("B", "Y") + + a2 = await identity_loader.load("A") + b2 = await identity_loader.load("B") + + assert a2 == "X" + assert b2 == "B" + + assert load_calls == [["B"]] + + +# # Represents Errors +@mark.asyncio +async def test_resolves_to_error_to_indicate_failure(): + async def resolve(keys): + mapped_keys = [ + key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys + ] + return mapped_keys + + even_loader, load_calls = id_loader(resolve=resolve) + + with raises(Exception) as exc_info: + await even_loader.load(1) + + assert str(exc_info.value) == "Odd: 1" + + value2 = await even_loader.load(2) + assert value2 == 2 + assert load_calls == [[1], [2]] + + +@mark.asyncio +async def test_can_represent_failures_and_successes_simultaneously(): + async def resolve(keys): + mapped_keys = [ + key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys + ] + return mapped_keys + + even_loader, load_calls = id_loader(resolve=resolve) + + promise1 = even_loader.load(1) + promise2 = even_loader.load(2) + + with raises(Exception) as exc_info: + await promise1 + + assert str(exc_info.value) == "Odd: 1" + value2 = await promise2 + assert value2 == 2 + assert load_calls == [[1, 2]] + + +@mark.asyncio +async def test_caches_failed_fetches(): + async def resolve(keys): + mapped_keys = [Exception("Error: {}".format(key)) for key in keys] + return mapped_keys + + error_loader, load_calls = id_loader(resolve=resolve) + + with raises(Exception) as exc_info: + await error_loader.load(1) + + assert str(exc_info.value) == "Error: 1" + + with raises(Exception) as exc_info: + await error_loader.load(1) + + assert str(exc_info.value) == "Error: 1" + + assert load_calls == [[1]] + + +@mark.asyncio +async def test_caches_failed_fetches_2(): + identity_loader, load_calls = id_loader() + + identity_loader.prime(1, Exception("Error: 1")) + + with raises(Exception) as _: + await identity_loader.load(1) + + assert load_calls == [] + + +# It is resilient to job queue ordering +@mark.asyncio +async def test_batches_loads_occuring_within_promises(): + identity_loader, load_calls = id_loader() + + async def load_b_1(): + return await load_b_2() + + async def load_b_2(): + return await identity_loader.load("B") + + values = await gather(identity_loader.load("A"), load_b_1()) + + assert values == ["A", "B"] + + assert load_calls == [["A", "B"]] + + +@mark.asyncio +async def test_catches_error_if_loader_resolver_fails(): + exc = Exception("AOH!") + + def do_resolve(x): + raise exc + + a_loader, a_load_calls = id_loader(resolve=do_resolve) + + with raises(Exception) as exc_info: + await a_loader.load("A1") + + assert exc_info.value == exc + + +@mark.asyncio +async def test_can_call_a_loader_from_a_loader(): + deep_loader, deep_load_calls = id_loader() + a_loader, a_load_calls = id_loader( + resolve=lambda keys: deep_loader.load(tuple(keys)) + ) + b_loader, b_load_calls = id_loader( + resolve=lambda keys: deep_loader.load(tuple(keys)) + ) + + a1, b1, a2, b2 = await gather( + a_loader.load("A1"), + b_loader.load("B1"), + a_loader.load("A2"), + b_loader.load("B2"), + ) + + assert a1 == "A1" + assert b1 == "B1" + assert a2 == "A2" + assert b2 == "B2" + + assert a_load_calls == [["A1", "A2"]] + assert b_load_calls == [["B1", "B2"]] + assert deep_load_calls == [[("A1", "A2"), ("B1", "B2")]] + + +@mark.asyncio +async def test_dataloader_clear_with_missing_key_works(): + async def do_resolve(x): + return x + + a_loader, a_load_calls = id_loader(resolve=do_resolve) + assert a_loader.clear("A1") == a_loader diff --git a/setup.cfg b/setup.cfg index 2037bc1b..db1ff134 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,10 @@ exclude = setup.py,docs/*,*/examples/*,graphene/pyutils/*,tests max-line-length = 120 +# This is a specific ignore for Black+Flake8 +# source: https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1 +extend-ignore = E203 + [coverage:run] omit = graphene/pyutils/*,*/tests/*,graphene/types/scalars.py diff --git a/setup.py b/setup.py index b87f56cc..dce6aa6c 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,6 @@ tests_require = [ "snapshottest>=0.6,<1", "coveralls>=3.3,<4", "promise>=2.3,<3", - "aiodataloader<1", "mock>=4,<5", "pytz==2022.1", "iso8601>=1,<2", diff --git a/tests_asyncio/test_dataloader.py b/tests_asyncio/test_dataloader.py deleted file mode 100644 index fb8d1630..00000000 --- a/tests_asyncio/test_dataloader.py +++ /dev/null @@ -1,79 +0,0 @@ -from collections import namedtuple -from unittest.mock import Mock -from pytest import mark -from aiodataloader import DataLoader - -from graphene import ObjectType, String, Schema, Field, List - - -CHARACTERS = { - "1": {"name": "Luke Skywalker", "sibling": "3"}, - "2": {"name": "Darth Vader", "sibling": None}, - "3": {"name": "Leia Organa", "sibling": "1"}, -} - - -get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id]) - - -class CharacterType(ObjectType): - name = String() - sibling = Field(lambda: CharacterType) - - async def resolve_sibling(character, info): - if character["sibling"]: - return await info.context.character_loader.load(character["sibling"]) - return None - - -class Query(ObjectType): - skywalker_family = List(CharacterType) - - async def resolve_skywalker_family(_, info): - return await info.context.character_loader.load_many(["1", "2", "3"]) - - -mock_batch_load_fn = Mock( - side_effect=lambda character_ids: [get_character(id) for id in character_ids] -) - - -class CharacterLoader(DataLoader): - async def batch_load_fn(self, character_ids): - return mock_batch_load_fn(character_ids) - - -Context = namedtuple("Context", "character_loader") - - -@mark.asyncio -async def test_basic_dataloader(): - schema = Schema(query=Query) - - character_loader = CharacterLoader() - context = Context(character_loader=character_loader) - - query = """ - { - skywalkerFamily { - name - sibling { - name - } - } - } - """ - - result = await schema.execute_async(query, context=context) - - assert not result.errors - assert result.data == { - "skywalkerFamily": [ - {"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}}, - {"name": "Darth Vader", "sibling": None}, - {"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}}, - ] - } - - assert mock_batch_load_fn.call_count == 1 - assert get_character.call_count == 3