mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-10-30 23:47:55 +03:00 
			
		
		
		
	Vendor DataLoader from aiodataloader and move get_event_loop() out of __init__ function. (#1459)
				
					
				
			* Vendor DataLoader from aiodataloader and also move get_event_loop behavior from `__init__` to a property which only gets resolved when actually needed (this will solve PyTest-related to early get_event_loop() issues) * Added DataLoader's specific tests * plug `loop` parameter into `self._loop`, so that we still have the ability to pass in a custom event loop, if needed. Co-authored-by: Erik Wrede <erikwrede2@gmail.com>
This commit is contained in:
		
							parent
							
								
									20219fdc1b
								
							
						
					
					
						commit
						694c1db21e
					
				
							
								
								
									
										281
									
								
								graphene/utils/dataloader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										281
									
								
								graphene/utils/dataloader.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,281 @@ | ||||||
|  | from asyncio import ( | ||||||
|  |     gather, | ||||||
|  |     ensure_future, | ||||||
|  |     get_event_loop, | ||||||
|  |     iscoroutine, | ||||||
|  |     iscoroutinefunction, | ||||||
|  | ) | ||||||
|  | from collections import namedtuple | ||||||
|  | from collections.abc import Iterable | ||||||
|  | from functools import partial | ||||||
|  | 
 | ||||||
|  | from typing import List  # flake8: noqa | ||||||
|  | 
 | ||||||
|  | Loader = namedtuple("Loader", "key,future") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def iscoroutinefunctionorpartial(fn): | ||||||
|  |     return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DataLoader(object): | ||||||
|  |     batch = True | ||||||
|  |     max_batch_size = None  # type: int | ||||||
|  |     cache = True | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         batch_load_fn=None, | ||||||
|  |         batch=None, | ||||||
|  |         max_batch_size=None, | ||||||
|  |         cache=None, | ||||||
|  |         get_cache_key=None, | ||||||
|  |         cache_map=None, | ||||||
|  |         loop=None, | ||||||
|  |     ): | ||||||
|  | 
 | ||||||
|  |         self._loop = loop | ||||||
|  | 
 | ||||||
|  |         if batch_load_fn is not None: | ||||||
|  |             self.batch_load_fn = batch_load_fn | ||||||
|  | 
 | ||||||
|  |         assert iscoroutinefunctionorpartial( | ||||||
|  |             self.batch_load_fn | ||||||
|  |         ), "batch_load_fn must be coroutine. Received: {}".format(self.batch_load_fn) | ||||||
|  | 
 | ||||||
|  |         if not callable(self.batch_load_fn): | ||||||
|  |             raise TypeError(  # pragma: no cover | ||||||
|  |                 ( | ||||||
|  |                     "DataLoader must be have a batch_load_fn which accepts " | ||||||
|  |                     "Iterable<key> and returns Future<Iterable<value>>, but got: {}." | ||||||
|  |                 ).format(batch_load_fn) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         if batch is not None: | ||||||
|  |             self.batch = batch  # pragma: no cover | ||||||
|  | 
 | ||||||
|  |         if max_batch_size is not None: | ||||||
|  |             self.max_batch_size = max_batch_size | ||||||
|  | 
 | ||||||
|  |         if cache is not None: | ||||||
|  |             self.cache = cache  # pragma: no cover | ||||||
|  | 
 | ||||||
|  |         self.get_cache_key = get_cache_key or (lambda x: x) | ||||||
|  | 
 | ||||||
|  |         self._cache = cache_map if cache_map is not None else {} | ||||||
|  |         self._queue = []  # type: List[Loader] | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def loop(self): | ||||||
|  |         if not self._loop: | ||||||
|  |             self._loop = get_event_loop() | ||||||
|  | 
 | ||||||
|  |         return self._loop | ||||||
|  | 
 | ||||||
|  |     def load(self, key=None): | ||||||
|  |         """ | ||||||
|  |         Loads a key, returning a `Future` for the value represented by that key. | ||||||
|  |         """ | ||||||
|  |         if key is None: | ||||||
|  |             raise TypeError(  # pragma: no cover | ||||||
|  |                 ( | ||||||
|  |                     "The loader.load() function must be called with a value, " | ||||||
|  |                     "but got: {}." | ||||||
|  |                 ).format(key) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         cache_key = self.get_cache_key(key) | ||||||
|  | 
 | ||||||
|  |         # If caching and there is a cache-hit, return cached Future. | ||||||
|  |         if self.cache: | ||||||
|  |             cached_result = self._cache.get(cache_key) | ||||||
|  |             if cached_result: | ||||||
|  |                 return cached_result | ||||||
|  | 
 | ||||||
|  |         # Otherwise, produce a new Future for this value. | ||||||
|  |         future = self.loop.create_future() | ||||||
|  |         # If caching, cache this Future. | ||||||
|  |         if self.cache: | ||||||
|  |             self._cache[cache_key] = future | ||||||
|  | 
 | ||||||
|  |         self.do_resolve_reject(key, future) | ||||||
|  |         return future | ||||||
|  | 
 | ||||||
|  |     def do_resolve_reject(self, key, future): | ||||||
|  |         # Enqueue this Future to be dispatched. | ||||||
|  |         self._queue.append(Loader(key=key, future=future)) | ||||||
|  |         # Determine if a dispatch of this queue should be scheduled. | ||||||
|  |         # A single dispatch should be scheduled per queue at the time when the | ||||||
|  |         # queue changes from "empty" to "full". | ||||||
|  |         if len(self._queue) == 1: | ||||||
|  |             if self.batch: | ||||||
|  |                 # If batching, schedule a task to dispatch the queue. | ||||||
|  |                 enqueue_post_future_job(self.loop, self) | ||||||
|  |             else: | ||||||
|  |                 # Otherwise dispatch the (queue of one) immediately. | ||||||
|  |                 dispatch_queue(self)  # pragma: no cover | ||||||
|  | 
 | ||||||
|  |     def load_many(self, keys): | ||||||
|  |         """ | ||||||
|  |         Loads multiple keys, returning a list of values | ||||||
|  | 
 | ||||||
|  |         >>> a, b = await my_loader.load_many([ 'a', 'b' ]) | ||||||
|  | 
 | ||||||
|  |         This is equivalent to the more verbose: | ||||||
|  | 
 | ||||||
|  |         >>> a, b = await gather( | ||||||
|  |         >>>    my_loader.load('a'), | ||||||
|  |         >>>    my_loader.load('b') | ||||||
|  |         >>> ) | ||||||
|  |         """ | ||||||
|  |         if not isinstance(keys, Iterable): | ||||||
|  |             raise TypeError(  # pragma: no cover | ||||||
|  |                 ( | ||||||
|  |                     "The loader.load_many() function must be called with Iterable<key> " | ||||||
|  |                     "but got: {}." | ||||||
|  |                 ).format(keys) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         return gather(*[self.load(key) for key in keys]) | ||||||
|  | 
 | ||||||
|  |     def clear(self, key): | ||||||
|  |         """ | ||||||
|  |         Clears the value at `key` from the cache, if it exists. Returns itself for | ||||||
|  |         method chaining. | ||||||
|  |         """ | ||||||
|  |         cache_key = self.get_cache_key(key) | ||||||
|  |         self._cache.pop(cache_key, None) | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     def clear_all(self): | ||||||
|  |         """ | ||||||
|  |         Clears the entire cache. To be used when some event results in unknown | ||||||
|  |         invalidations across this particular `DataLoader`. Returns itself for | ||||||
|  |         method chaining. | ||||||
|  |         """ | ||||||
|  |         self._cache.clear() | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     def prime(self, key, value): | ||||||
|  |         """ | ||||||
|  |         Adds the provied key and value to the cache. If the key already exists, no | ||||||
|  |         change is made. Returns itself for method chaining. | ||||||
|  |         """ | ||||||
|  |         cache_key = self.get_cache_key(key) | ||||||
|  | 
 | ||||||
|  |         # Only add the key if it does not already exist. | ||||||
|  |         if cache_key not in self._cache: | ||||||
|  |             # Cache a rejected future if the value is an Error, in order to match | ||||||
|  |             # the behavior of load(key). | ||||||
|  |             future = self.loop.create_future() | ||||||
|  |             if isinstance(value, Exception): | ||||||
|  |                 future.set_exception(value) | ||||||
|  |             else: | ||||||
|  |                 future.set_result(value) | ||||||
|  | 
 | ||||||
|  |             self._cache[cache_key] = future | ||||||
|  | 
 | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def enqueue_post_future_job(loop, loader): | ||||||
|  |     async def dispatch(): | ||||||
|  |         dispatch_queue(loader) | ||||||
|  | 
 | ||||||
|  |     loop.call_soon(ensure_future, dispatch()) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_chunks(iterable_obj, chunk_size=1): | ||||||
|  |     chunk_size = max(1, chunk_size) | ||||||
|  |     return ( | ||||||
|  |         iterable_obj[i : i + chunk_size] | ||||||
|  |         for i in range(0, len(iterable_obj), chunk_size) | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def dispatch_queue(loader): | ||||||
|  |     """ | ||||||
|  |     Given the current state of a Loader instance, perform a batch load | ||||||
|  |     from its current queue. | ||||||
|  |     """ | ||||||
|  |     # Take the current loader queue, replacing it with an empty queue. | ||||||
|  |     queue = loader._queue | ||||||
|  |     loader._queue = [] | ||||||
|  | 
 | ||||||
|  |     # If a max_batch_size was provided and the queue is longer, then segment the | ||||||
|  |     # queue into multiple batches, otherwise treat the queue as a single batch. | ||||||
|  |     max_batch_size = loader.max_batch_size | ||||||
|  | 
 | ||||||
|  |     if max_batch_size and max_batch_size < len(queue): | ||||||
|  |         chunks = get_chunks(queue, max_batch_size) | ||||||
|  |         for chunk in chunks: | ||||||
|  |             ensure_future(dispatch_queue_batch(loader, chunk)) | ||||||
|  |     else: | ||||||
|  |         ensure_future(dispatch_queue_batch(loader, queue)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def dispatch_queue_batch(loader, queue): | ||||||
|  |     # Collect all keys to be loaded in this dispatch | ||||||
|  |     keys = [loaded.key for loaded in queue] | ||||||
|  | 
 | ||||||
|  |     # Call the provided batch_load_fn for this loader with the loader queue's keys. | ||||||
|  |     batch_future = loader.batch_load_fn(keys) | ||||||
|  | 
 | ||||||
|  |     # Assert the expected response from batch_load_fn | ||||||
|  |     if not batch_future or not iscoroutine(batch_future): | ||||||
|  |         return failed_dispatch(  # pragma: no cover | ||||||
|  |             loader, | ||||||
|  |             queue, | ||||||
|  |             TypeError( | ||||||
|  |                 ( | ||||||
|  |                     "DataLoader must be constructed with a function which accepts " | ||||||
|  |                     "Iterable<key> and returns Future<Iterable<value>>, but the function did " | ||||||
|  |                     "not return a Coroutine: {}." | ||||||
|  |                 ).format(batch_future) | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         values = await batch_future | ||||||
|  |         if not isinstance(values, Iterable): | ||||||
|  |             raise TypeError(  # pragma: no cover | ||||||
|  |                 ( | ||||||
|  |                     "DataLoader must be constructed with a function which accepts " | ||||||
|  |                     "Iterable<key> and returns Future<Iterable<value>>, but the function did " | ||||||
|  |                     "not return a Future of a Iterable: {}." | ||||||
|  |                 ).format(values) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         values = list(values) | ||||||
|  |         if len(values) != len(keys): | ||||||
|  |             raise TypeError(  # pragma: no cover | ||||||
|  |                 ( | ||||||
|  |                     "DataLoader must be constructed with a function which accepts " | ||||||
|  |                     "Iterable<key> and returns Future<Iterable<value>>, but the function did " | ||||||
|  |                     "not return a Future of a Iterable with the same length as the Iterable " | ||||||
|  |                     "of keys." | ||||||
|  |                     "\n\nKeys:\n{}" | ||||||
|  |                     "\n\nValues:\n{}" | ||||||
|  |                 ).format(keys, values) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         # Step through the values, resolving or rejecting each Future in the | ||||||
|  |         # loaded queue. | ||||||
|  |         for loaded, value in zip(queue, values): | ||||||
|  |             if isinstance(value, Exception): | ||||||
|  |                 loaded.future.set_exception(value) | ||||||
|  |             else: | ||||||
|  |                 loaded.future.set_result(value) | ||||||
|  | 
 | ||||||
|  |     except Exception as e: | ||||||
|  |         return failed_dispatch(loader, queue, e) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def failed_dispatch(loader, queue, error): | ||||||
|  |     """ | ||||||
|  |     Do not cache individual loads if the entire batch dispatch fails, | ||||||
|  |     but still reject each request so they do not hang. | ||||||
|  |     """ | ||||||
|  |     for loaded in queue: | ||||||
|  |         loader.clear(loaded.key) | ||||||
|  |         loaded.future.set_exception(error) | ||||||
							
								
								
									
										452
									
								
								graphene/utils/tests/test_dataloader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										452
									
								
								graphene/utils/tests/test_dataloader.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,452 @@ | ||||||
|  | from asyncio import gather | ||||||
|  | from collections import namedtuple | ||||||
|  | from functools import partial | ||||||
|  | from unittest.mock import Mock | ||||||
|  | 
 | ||||||
|  | from graphene.utils.dataloader import DataLoader | ||||||
|  | from pytest import mark, raises | ||||||
|  | 
 | ||||||
|  | from graphene import ObjectType, String, Schema, Field, List | ||||||
|  | 
 | ||||||
|  | CHARACTERS = { | ||||||
|  |     "1": {"name": "Luke Skywalker", "sibling": "3"}, | ||||||
|  |     "2": {"name": "Darth Vader", "sibling": None}, | ||||||
|  |     "3": {"name": "Leia Organa", "sibling": "1"}, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CharacterType(ObjectType): | ||||||
|  |     name = String() | ||||||
|  |     sibling = Field(lambda: CharacterType) | ||||||
|  | 
 | ||||||
|  |     async def resolve_sibling(character, info): | ||||||
|  |         if character["sibling"]: | ||||||
|  |             return await info.context.character_loader.load(character["sibling"]) | ||||||
|  |         return None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Query(ObjectType): | ||||||
|  |     skywalker_family = List(CharacterType) | ||||||
|  | 
 | ||||||
|  |     async def resolve_skywalker_family(_, info): | ||||||
|  |         return await info.context.character_loader.load_many(["1", "2", "3"]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | mock_batch_load_fn = Mock( | ||||||
|  |     side_effect=lambda character_ids: [get_character(id) for id in character_ids] | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CharacterLoader(DataLoader): | ||||||
|  |     async def batch_load_fn(self, character_ids): | ||||||
|  |         return mock_batch_load_fn(character_ids) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Context = namedtuple("Context", "character_loader") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_basic_dataloader(): | ||||||
|  |     schema = Schema(query=Query) | ||||||
|  | 
 | ||||||
|  |     character_loader = CharacterLoader() | ||||||
|  |     context = Context(character_loader=character_loader) | ||||||
|  | 
 | ||||||
|  |     query = """ | ||||||
|  |         { | ||||||
|  |             skywalkerFamily { | ||||||
|  |                 name | ||||||
|  |                 sibling { | ||||||
|  |                     name | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     result = await schema.execute_async(query, context=context) | ||||||
|  | 
 | ||||||
|  |     assert not result.errors | ||||||
|  |     assert result.data == { | ||||||
|  |         "skywalkerFamily": [ | ||||||
|  |             {"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}}, | ||||||
|  |             {"name": "Darth Vader", "sibling": None}, | ||||||
|  |             {"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}}, | ||||||
|  |         ] | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     assert mock_batch_load_fn.call_count == 1 | ||||||
|  |     assert get_character.call_count == 3 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def id_loader(**options): | ||||||
|  |     load_calls = [] | ||||||
|  | 
 | ||||||
|  |     async def default_resolve(x): | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     resolve = options.pop("resolve", default_resolve) | ||||||
|  | 
 | ||||||
|  |     async def fn(keys): | ||||||
|  |         load_calls.append(keys) | ||||||
|  |         return await resolve(keys) | ||||||
|  |         # return keys | ||||||
|  | 
 | ||||||
|  |     identity_loader = DataLoader(fn, **options) | ||||||
|  |     return identity_loader, load_calls | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_build_a_simple_data_loader(): | ||||||
|  |     async def call_fn(keys): | ||||||
|  |         return keys | ||||||
|  | 
 | ||||||
|  |     identity_loader = DataLoader(call_fn) | ||||||
|  | 
 | ||||||
|  |     promise1 = identity_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     value1 = await promise1 | ||||||
|  |     assert value1 == 1 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_can_build_a_data_loader_from_a_partial(): | ||||||
|  |     value_map = {1: "one"} | ||||||
|  | 
 | ||||||
|  |     async def call_fn(context, keys): | ||||||
|  |         return [context.get(key) for key in keys] | ||||||
|  | 
 | ||||||
|  |     partial_fn = partial(call_fn, value_map) | ||||||
|  |     identity_loader = DataLoader(partial_fn) | ||||||
|  | 
 | ||||||
|  |     promise1 = identity_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     value1 = await promise1 | ||||||
|  |     assert value1 == "one" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_supports_loading_multiple_keys_in_one_call(): | ||||||
|  |     async def call_fn(keys): | ||||||
|  |         return keys | ||||||
|  | 
 | ||||||
|  |     identity_loader = DataLoader(call_fn) | ||||||
|  | 
 | ||||||
|  |     promise_all = identity_loader.load_many([1, 2]) | ||||||
|  | 
 | ||||||
|  |     values = await promise_all | ||||||
|  |     assert values == [1, 2] | ||||||
|  | 
 | ||||||
|  |     promise_all = identity_loader.load_many([]) | ||||||
|  | 
 | ||||||
|  |     values = await promise_all | ||||||
|  |     assert values == [] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_batches_multiple_requests(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     promise1 = identity_loader.load(1) | ||||||
|  |     promise2 = identity_loader.load(2) | ||||||
|  | 
 | ||||||
|  |     p = gather(promise1, promise2) | ||||||
|  | 
 | ||||||
|  |     value1, value2 = await p | ||||||
|  | 
 | ||||||
|  |     assert value1 == 1 | ||||||
|  |     assert value2 == 2 | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [[1, 2]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_batches_multiple_requests_with_max_batch_sizes(): | ||||||
|  |     identity_loader, load_calls = id_loader(max_batch_size=2) | ||||||
|  | 
 | ||||||
|  |     promise1 = identity_loader.load(1) | ||||||
|  |     promise2 = identity_loader.load(2) | ||||||
|  |     promise3 = identity_loader.load(3) | ||||||
|  | 
 | ||||||
|  |     p = gather(promise1, promise2, promise3) | ||||||
|  | 
 | ||||||
|  |     value1, value2, value3 = await p | ||||||
|  | 
 | ||||||
|  |     assert value1 == 1 | ||||||
|  |     assert value2 == 2 | ||||||
|  |     assert value3 == 3 | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [[1, 2], [3]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_coalesces_identical_requests(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     promise1 = identity_loader.load(1) | ||||||
|  |     promise2 = identity_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     assert promise1 == promise2 | ||||||
|  |     p = gather(promise1, promise2) | ||||||
|  | 
 | ||||||
|  |     value1, value2 = await p | ||||||
|  | 
 | ||||||
|  |     assert value1 == 1 | ||||||
|  |     assert value2 == 1 | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [[1]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_caches_repeated_requests(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) | ||||||
|  | 
 | ||||||
|  |     assert a == "A" | ||||||
|  |     assert b == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"]] | ||||||
|  | 
 | ||||||
|  |     a2, c = await gather(identity_loader.load("A"), identity_loader.load("C")) | ||||||
|  | 
 | ||||||
|  |     assert a2 == "A" | ||||||
|  |     assert c == "C" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"], ["C"]] | ||||||
|  | 
 | ||||||
|  |     a3, b2, c2 = await gather( | ||||||
|  |         identity_loader.load("A"), identity_loader.load("B"), identity_loader.load("C") | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     assert a3 == "A" | ||||||
|  |     assert b2 == "B" | ||||||
|  |     assert c2 == "C" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"], ["C"]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_clears_single_value_in_loader(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) | ||||||
|  | 
 | ||||||
|  |     assert a == "A" | ||||||
|  |     assert b == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"]] | ||||||
|  | 
 | ||||||
|  |     identity_loader.clear("A") | ||||||
|  | 
 | ||||||
|  |     a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B")) | ||||||
|  | 
 | ||||||
|  |     assert a2 == "A" | ||||||
|  |     assert b2 == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"], ["A"]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_clears_all_values_in_loader(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) | ||||||
|  | 
 | ||||||
|  |     assert a == "A" | ||||||
|  |     assert b == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"]] | ||||||
|  | 
 | ||||||
|  |     identity_loader.clear_all() | ||||||
|  | 
 | ||||||
|  |     a2, b2 = await gather(identity_loader.load("A"), identity_loader.load("B")) | ||||||
|  | 
 | ||||||
|  |     assert a2 == "A" | ||||||
|  |     assert b2 == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"], ["A", "B"]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_allows_priming_the_cache(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     identity_loader.prime("A", "A") | ||||||
|  | 
 | ||||||
|  |     a, b = await gather(identity_loader.load("A"), identity_loader.load("B")) | ||||||
|  | 
 | ||||||
|  |     assert a == "A" | ||||||
|  |     assert b == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["B"]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_does_not_prime_keys_that_already_exist(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     identity_loader.prime("A", "X") | ||||||
|  | 
 | ||||||
|  |     a1 = await identity_loader.load("A") | ||||||
|  |     b1 = await identity_loader.load("B") | ||||||
|  | 
 | ||||||
|  |     assert a1 == "X" | ||||||
|  |     assert b1 == "B" | ||||||
|  | 
 | ||||||
|  |     identity_loader.prime("A", "Y") | ||||||
|  |     identity_loader.prime("B", "Y") | ||||||
|  | 
 | ||||||
|  |     a2 = await identity_loader.load("A") | ||||||
|  |     b2 = await identity_loader.load("B") | ||||||
|  | 
 | ||||||
|  |     assert a2 == "X" | ||||||
|  |     assert b2 == "B" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["B"]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # Represents Errors | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_resolves_to_error_to_indicate_failure(): | ||||||
|  |     async def resolve(keys): | ||||||
|  |         mapped_keys = [ | ||||||
|  |             key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys | ||||||
|  |         ] | ||||||
|  |         return mapped_keys | ||||||
|  | 
 | ||||||
|  |     even_loader, load_calls = id_loader(resolve=resolve) | ||||||
|  | 
 | ||||||
|  |     with raises(Exception) as exc_info: | ||||||
|  |         await even_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     assert str(exc_info.value) == "Odd: 1" | ||||||
|  | 
 | ||||||
|  |     value2 = await even_loader.load(2) | ||||||
|  |     assert value2 == 2 | ||||||
|  |     assert load_calls == [[1], [2]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_can_represent_failures_and_successes_simultaneously(): | ||||||
|  |     async def resolve(keys): | ||||||
|  |         mapped_keys = [ | ||||||
|  |             key if key % 2 == 0 else Exception("Odd: {}".format(key)) for key in keys | ||||||
|  |         ] | ||||||
|  |         return mapped_keys | ||||||
|  | 
 | ||||||
|  |     even_loader, load_calls = id_loader(resolve=resolve) | ||||||
|  | 
 | ||||||
|  |     promise1 = even_loader.load(1) | ||||||
|  |     promise2 = even_loader.load(2) | ||||||
|  | 
 | ||||||
|  |     with raises(Exception) as exc_info: | ||||||
|  |         await promise1 | ||||||
|  | 
 | ||||||
|  |     assert str(exc_info.value) == "Odd: 1" | ||||||
|  |     value2 = await promise2 | ||||||
|  |     assert value2 == 2 | ||||||
|  |     assert load_calls == [[1, 2]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_caches_failed_fetches(): | ||||||
|  |     async def resolve(keys): | ||||||
|  |         mapped_keys = [Exception("Error: {}".format(key)) for key in keys] | ||||||
|  |         return mapped_keys | ||||||
|  | 
 | ||||||
|  |     error_loader, load_calls = id_loader(resolve=resolve) | ||||||
|  | 
 | ||||||
|  |     with raises(Exception) as exc_info: | ||||||
|  |         await error_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     assert str(exc_info.value) == "Error: 1" | ||||||
|  | 
 | ||||||
|  |     with raises(Exception) as exc_info: | ||||||
|  |         await error_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     assert str(exc_info.value) == "Error: 1" | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [[1]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_caches_failed_fetches_2(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     identity_loader.prime(1, Exception("Error: 1")) | ||||||
|  | 
 | ||||||
|  |     with raises(Exception) as _: | ||||||
|  |         await identity_loader.load(1) | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # It is resilient to job queue ordering | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_batches_loads_occuring_within_promises(): | ||||||
|  |     identity_loader, load_calls = id_loader() | ||||||
|  | 
 | ||||||
|  |     async def load_b_1(): | ||||||
|  |         return await load_b_2() | ||||||
|  | 
 | ||||||
|  |     async def load_b_2(): | ||||||
|  |         return await identity_loader.load("B") | ||||||
|  | 
 | ||||||
|  |     values = await gather(identity_loader.load("A"), load_b_1()) | ||||||
|  | 
 | ||||||
|  |     assert values == ["A", "B"] | ||||||
|  | 
 | ||||||
|  |     assert load_calls == [["A", "B"]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_catches_error_if_loader_resolver_fails(): | ||||||
|  |     exc = Exception("AOH!") | ||||||
|  | 
 | ||||||
|  |     def do_resolve(x): | ||||||
|  |         raise exc | ||||||
|  | 
 | ||||||
|  |     a_loader, a_load_calls = id_loader(resolve=do_resolve) | ||||||
|  | 
 | ||||||
|  |     with raises(Exception) as exc_info: | ||||||
|  |         await a_loader.load("A1") | ||||||
|  | 
 | ||||||
|  |     assert exc_info.value == exc | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_can_call_a_loader_from_a_loader(): | ||||||
|  |     deep_loader, deep_load_calls = id_loader() | ||||||
|  |     a_loader, a_load_calls = id_loader( | ||||||
|  |         resolve=lambda keys: deep_loader.load(tuple(keys)) | ||||||
|  |     ) | ||||||
|  |     b_loader, b_load_calls = id_loader( | ||||||
|  |         resolve=lambda keys: deep_loader.load(tuple(keys)) | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     a1, b1, a2, b2 = await gather( | ||||||
|  |         a_loader.load("A1"), | ||||||
|  |         b_loader.load("B1"), | ||||||
|  |         a_loader.load("A2"), | ||||||
|  |         b_loader.load("B2"), | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     assert a1 == "A1" | ||||||
|  |     assert b1 == "B1" | ||||||
|  |     assert a2 == "A2" | ||||||
|  |     assert b2 == "B2" | ||||||
|  | 
 | ||||||
|  |     assert a_load_calls == [["A1", "A2"]] | ||||||
|  |     assert b_load_calls == [["B1", "B2"]] | ||||||
|  |     assert deep_load_calls == [[("A1", "A2"), ("B1", "B2")]] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mark.asyncio | ||||||
|  | async def test_dataloader_clear_with_missing_key_works(): | ||||||
|  |     async def do_resolve(x): | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     a_loader, a_load_calls = id_loader(resolve=do_resolve) | ||||||
|  |     assert a_loader.clear("A1") == a_loader | ||||||
|  | @ -2,6 +2,10 @@ | ||||||
| exclude = setup.py,docs/*,*/examples/*,graphene/pyutils/*,tests | exclude = setup.py,docs/*,*/examples/*,graphene/pyutils/*,tests | ||||||
| max-line-length = 120 | max-line-length = 120 | ||||||
| 
 | 
 | ||||||
|  | # This is a specific ignore for Black+Flake8 | ||||||
|  | # source: https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1 | ||||||
|  | extend-ignore = E203 | ||||||
|  | 
 | ||||||
| [coverage:run] | [coverage:run] | ||||||
| omit = graphene/pyutils/*,*/tests/*,graphene/types/scalars.py | omit = graphene/pyutils/*,*/tests/*,graphene/types/scalars.py | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							|  | @ -53,7 +53,6 @@ tests_require = [ | ||||||
|     "snapshottest>=0.6,<1", |     "snapshottest>=0.6,<1", | ||||||
|     "coveralls>=3.3,<4", |     "coveralls>=3.3,<4", | ||||||
|     "promise>=2.3,<3", |     "promise>=2.3,<3", | ||||||
|     "aiodataloader<1", |  | ||||||
|     "mock>=4,<5", |     "mock>=4,<5", | ||||||
|     "pytz==2022.1", |     "pytz==2022.1", | ||||||
|     "iso8601>=1,<2", |     "iso8601>=1,<2", | ||||||
|  |  | ||||||
|  | @ -1,79 +0,0 @@ | ||||||
| from collections import namedtuple |  | ||||||
| from unittest.mock import Mock |  | ||||||
| from pytest import mark |  | ||||||
| from aiodataloader import DataLoader |  | ||||||
| 
 |  | ||||||
| from graphene import ObjectType, String, Schema, Field, List |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| CHARACTERS = { |  | ||||||
|     "1": {"name": "Luke Skywalker", "sibling": "3"}, |  | ||||||
|     "2": {"name": "Darth Vader", "sibling": None}, |  | ||||||
|     "3": {"name": "Leia Organa", "sibling": "1"}, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id]) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class CharacterType(ObjectType): |  | ||||||
|     name = String() |  | ||||||
|     sibling = Field(lambda: CharacterType) |  | ||||||
| 
 |  | ||||||
|     async def resolve_sibling(character, info): |  | ||||||
|         if character["sibling"]: |  | ||||||
|             return await info.context.character_loader.load(character["sibling"]) |  | ||||||
|         return None |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class Query(ObjectType): |  | ||||||
|     skywalker_family = List(CharacterType) |  | ||||||
| 
 |  | ||||||
|     async def resolve_skywalker_family(_, info): |  | ||||||
|         return await info.context.character_loader.load_many(["1", "2", "3"]) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| mock_batch_load_fn = Mock( |  | ||||||
|     side_effect=lambda character_ids: [get_character(id) for id in character_ids] |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class CharacterLoader(DataLoader): |  | ||||||
|     async def batch_load_fn(self, character_ids): |  | ||||||
|         return mock_batch_load_fn(character_ids) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| Context = namedtuple("Context", "character_loader") |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @mark.asyncio |  | ||||||
| async def test_basic_dataloader(): |  | ||||||
|     schema = Schema(query=Query) |  | ||||||
| 
 |  | ||||||
|     character_loader = CharacterLoader() |  | ||||||
|     context = Context(character_loader=character_loader) |  | ||||||
| 
 |  | ||||||
|     query = """ |  | ||||||
|         { |  | ||||||
|             skywalkerFamily { |  | ||||||
|                 name |  | ||||||
|                 sibling { |  | ||||||
|                     name |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     result = await schema.execute_async(query, context=context) |  | ||||||
| 
 |  | ||||||
|     assert not result.errors |  | ||||||
|     assert result.data == { |  | ||||||
|         "skywalkerFamily": [ |  | ||||||
|             {"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}}, |  | ||||||
|             {"name": "Darth Vader", "sibling": None}, |  | ||||||
|             {"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}}, |  | ||||||
|         ] |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     assert mock_batch_load_fn.call_count == 1 |  | ||||||
|     assert get_character.call_count == 3 |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user