Mirror of https://github.com/graphql-python/graphene.git (synced 2025-07-18 03:52:24 +03:00)
Vendor DataLoader from aiodataloader, and move the get_event_loop() behavior from __init__
to a property that is only resolved when actually needed (this avoids pytest-related early get_event_loop() issues).
parent 20219fdc1b
commit 2467efc989
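As a quick illustration of the lazy-loop change (a hypothetical usage sketch, not part of the diff): because __init__ no longer calls get_event_loop(), a DataLoader can be created at module import time, before any event loop exists; the loop is only looked up on the first load().

# Hypothetical sketch (not part of this commit): constructing a DataLoader
# is now loop-free; get_event_loop() runs only when `loader.loop` is first
# touched inside load().
from asyncio import run
from graphene.utils.dataloader import DataLoader

async def batch_load_fn(keys):
    # Stand-in batch function; a real one would fetch all keys in one query.
    return [f"value:{key}" for key in keys]

loader = DataLoader(batch_load_fn)  # safe at import time: no event loop needed

async def main():
    print(await loader.load("a"))  # the event loop is resolved lazily here

run(main())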
graphene/utils/dataloader.py (new file, 282 lines)
@@ -0,0 +1,282 @@
from asyncio import (
    gather,
    ensure_future,
    get_event_loop,
    iscoroutine,
    iscoroutinefunction,
)
from collections import namedtuple
from collections.abc import Iterable
from functools import partial

from typing import List  # flake8: noqa

Loader = namedtuple("Loader", "key,future")


def iscoroutinefunctionorpartial(fn):
    return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)


class DataLoader(object):
    batch = True
    max_batch_size = None  # type: int
    cache = True

    def __init__(
        self,
        batch_load_fn=None,
        batch=None,
        max_batch_size=None,
        cache=None,
        get_cache_key=None,
        cache_map=None,
        loop=None,
    ):
        # Defer the event-loop lookup: the `loop` property below populates
        # _loop on first use, so constructing a loader never calls
        # get_event_loop() eagerly. An explicitly passed loop is honored.
        self._loop = loop

        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        assert iscoroutinefunctionorpartial(
            self.batch_load_fn
        ), "batch_load_fn must be a coroutine. Received: {}".format(self.batch_load_fn)

        if not callable(self.batch_load_fn):
            raise TypeError(
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache

        self.get_cache_key = get_cache_key or (lambda x: x)

        self._cache = cache_map if cache_map is not None else {}
        self._queue = []  # type: List[Loader]

    @property
    def loop(self):
        if not self._loop:
            self._loop = get_event_loop()

        return self._loop

    def load(self, key=None):
        """
        Loads a key, returning a `Future` for the value represented by that key.
        """
        if key is None:
            raise TypeError(
                (
                    "The loader.load() function must be called with a value, "
                    "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache hit, return the cached Future.
        if self.cache:
            cached_result = self._cache.get(cache_key)
            if cached_result:
                return cached_result

        # Otherwise, produce a new Future for this value.
        future = self.loop.create_future()
        # If caching, cache this Future.
        if self.cache:
            self._cache[cache_key] = future

        self.do_resolve_reject(key, future)
        return future

    def do_resolve_reject(self, key, future):
        # Enqueue this Future to be dispatched.
        self._queue.append(Loader(key=key, future=future))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from "empty" to "full".
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_future_job(self.loop, self)
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)

    def load_many(self, keys):
        """
        Loads multiple keys, returning a list of values.

        >>> a, b = await my_loader.load_many(['a', 'b'])

        This is equivalent to the more verbose:

        >>> a, b = await gather(
        >>>     my_loader.load('a'),
        >>>     my_loader.load('b')
        >>> )
        """
        if not isinstance(keys, Iterable):
            raise TypeError(
                (
                    "The loader.load_many() function must be called with Iterable<key> "
                    "but got: {}."
                ).format(keys)
            )

        return gather(*[self.load(key) for key in keys])

    def clear(self, key):
        """
        Clears the value at `key` from the cache, if it exists. Returns itself for
        method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._cache.pop(cache_key, None)
        return self

    def clear_all(self):
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`. Returns itself for
        method chaining.
        """
        self._cache.clear()
        return self

    def prime(self, key, value):
        """
        Adds the provided key and value to the cache. If the key already exists, no
        change is made. Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._cache:
            # Cache a rejected future if the value is an Error, in order to match
            # the behavior of load(key).
            future = self.loop.create_future()
            if isinstance(value, Exception):
                future.set_exception(value)
            else:
                future.set_result(value)

            self._cache[cache_key] = future

        return self


def enqueue_post_future_job(loop, loader):
    async def dispatch():
        dispatch_queue(loader)

    loop.call_soon(ensure_future, dispatch())


def get_chunks(iterable_obj, chunk_size=1):
    chunk_size = max(1, chunk_size)
    return (
        iterable_obj[i : i + chunk_size]
        for i in range(0, len(iterable_obj), chunk_size)
    )


def dispatch_queue(loader):
    """
    Given the current state of a Loader instance, perform a batch load
    from its current queue.
    """
    # Take the current loader queue, replacing it with an empty queue.
    queue = loader._queue
    loader._queue = []

    # If a max_batch_size was provided and the queue is longer, then segment the
    # queue into multiple batches; otherwise treat the queue as a single batch.
    max_batch_size = loader.max_batch_size

    if max_batch_size and max_batch_size < len(queue):
        chunks = get_chunks(queue, max_batch_size)
        for chunk in chunks:
            ensure_future(dispatch_queue_batch(loader, chunk))
    else:
        ensure_future(dispatch_queue_batch(loader, queue))


async def dispatch_queue_batch(loader, queue):
    # Collect all keys to be loaded in this dispatch.
    keys = [loaded.key for loaded in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    batch_future = loader.batch_load_fn(keys)

    # Assert the expected response from batch_load_fn.
    if not batch_future or not iscoroutine(batch_future):
        return failed_dispatch(
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Coroutine: {}."
                ).format(batch_future)
            ),
        )

    try:
        values = await batch_future
        if not isinstance(values, Iterable):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable: {}."
                ).format(values)
            )

        values = list(values)
        if len(values) != len(keys):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable with the same length as the Iterable "
                    "of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Future in the
        # loaded queue.
        for loaded, value in zip(queue, values):
            if isinstance(value, Exception):
                loaded.future.set_exception(value)
            else:
                loaded.future.set_result(value)

    except Exception as e:
        return failed_dispatch(loader, queue, e)


def failed_dispatch(loader, queue, error):
    """
    Do not cache individual loads if the entire batch dispatch fails,
    but still reject each request so they do not hang.
    """
    for loaded in queue:
        loader.clear(loaded.key)
        loaded.future.set_exception(error)
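To make the batching and caching semantics above concrete, here is a short illustrative sketch (the batch function and keys are invented; it is not part of the commit): loads issued in the same tick coalesce into a single batch_load_fn call, and a repeated key is served from the per-loader cache instead of triggering a second batch.

# Illustrative sketch (not part of the commit): two load() calls made in the
# same tick are dispatched as one batch, and repeated keys hit the cache.
from asyncio import gather, run
from graphene.utils.dataloader import DataLoader

calls = []

async def batch_load_fn(keys):
    calls.append(list(keys))  # record each batch for demonstration
    return [key * 2 for key in keys]

async def main():
    loader = DataLoader(batch_load_fn)
    a, b = await gather(loader.load(1), loader.load(2))  # one batch: [1, 2]
    c = await loader.load(1)  # cache hit: no second batch for key 1
    assert (a, b, c) == (2, 4, 2)
    assert calls == [[1, 2]]

run(main())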
setup.py (1 deletion)
@@ -53,7 +53,6 @@ tests_require = [
     "snapshottest>=0.6,<1",
     "coveralls>=3.3,<4",
     "promise>=2.3,<3",
-    "aiodataloader<1",
     "mock>=4,<5",
     "pytz==2022.1",
     "iso8601>=1,<2",
(test module; filename not captured in this scrape)
@@ -1,7 +1,8 @@
 from collections import namedtuple
 from unittest.mock import Mock

+from graphene.utils.dataloader import DataLoader
 from pytest import mark
-from aiodataloader import DataLoader

 from graphene import ObjectType, String, Schema, Field, List
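Given the imports in the test hunk above, here is a hedged sketch of the resolver-level pattern such tests exercise (the schema, field, and loader names are invented for illustration; this is not code from the commit):

# Hypothetical sketch (not from this commit): using the vendored DataLoader
# inside a graphene resolver so sibling fields can share one batch.
from asyncio import run
from graphene import ObjectType, Schema, String
from graphene.utils.dataloader import DataLoader

class UserLoader(DataLoader):
    async def batch_load_fn(self, keys):
        # One round trip for all requested keys; a real loader would query a DB.
        return [f"User {key}" for key in keys]

user_loader = UserLoader()

class Query(ObjectType):
    user = String(key=String(required=True))

    async def resolve_user(root, info, key):
        return await user_loader.load(key)

schema = Schema(query=Query)

async def main():
    result = await schema.execute_async('{ a: user(key: "a") b: user(key: "b") }')
    assert result.data == {"a": "User a", "b": "User b"}

run(main())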