mirror of
https://github.com/django/daphne.git
synced 2025-06-07 14:43:14 +03:00
191 lines
6.7 KiB
Python
191 lines
6.7 KiB
Python
import time
|
|
import json
|
|
import datetime
|
|
import math
|
|
import redis
|
|
import random
|
|
import binascii
|
|
import uuid
|
|
|
|
from django.utils import six
|
|
|
|
from .base import BaseChannelBackend
|
|
|
|
|
|
class RedisChannelBackend(BaseChannelBackend):
    """
    Redis-backed channel backend.

    Each message is stored as JSON under its own expiring key, and each
    channel is a Redis list holding those keys. Several Redis servers may be
    configured: response channels (names starting with "!"), groups and locks
    are pinned to one server by hashing their name, while normal channels are
    spread over the servers at random.
    """

    def __init__(self, routing, expiry=60, hosts=None, prefix="django-channels:"):
        """
        Sets up the backend.

        routing and expiry are passed straight to the base class. hosts is a
        list of (host, port) pairs and defaults to a single server on
        localhost:6379. prefix is prepended to every Redis key this backend
        builds.
        """
        super(RedisChannelBackend, self).__init__(routing=routing, expiry=expiry)
        # Make sure they provided some hosts, or provide a default
        if not hosts:
            hosts = [("localhost", 6379)]
        for host, port in hosts:
            assert isinstance(host, six.string_types)
            assert int(port)
        self.hosts = hosts
        self.prefix = prefix
        # Precalculate some values for ring selection: the 4096 hash slots
        # are split into ring_size contiguous ranges of ring_divisor slots.
        self.ring_size = len(self.hosts)
        self.ring_divisor = int(math.ceil(4096 / float(self.ring_size)))

    def consistent_hash(self, value):
        """
        Maps the value to a slot between 0 and 4095 using CRC32, then
        down to the index of one of the configured hosts.

        Text values are encoded as UTF-8 before hashing.
        """
        if isinstance(value, six.text_type):
            value = value.encode("utf8")
        # Mask to an unsigned 32-bit value so the result does not depend on
        # the sign crc32 may return.
        bigval = binascii.crc32(value) & 0xffffffff
        # Keep the top 12 bits (0-4095), then bucket them by host.
        return (bigval // 0x100000) // self.ring_divisor

    def random_index(self):
        """
        Returns a random valid index into the host list.
        """
        return random.randint(0, len(self.hosts) - 1)

    def connection(self, index):
        """
        Returns a Redis client for the host at the given index.

        Pass an index (for example the result of consistent_hash()) to talk
        to a specific server; pass None to use a random server instead.
        A new client object is created on every call.

        Raises ValueError if the index is outside the host list.
        """
        # If index is explicitly None, pick a random server
        if index is None:
            index = self.random_index()
        # Catch bad indexes
        if not (0 <= index < self.ring_size):
            raise ValueError("There are only %s hosts - you asked for %s!" % (self.ring_size, index))
        host, port = self.hosts[index]
        return redis.Redis(host=host, port=port)

    def send(self, channel, message):
        """
        Sends the message (a JSON-serialisable value) onto the channel.

        The message body is written to its own key and that key is pushed
        onto the channel's list; both expire 10 seconds after the
        configured message expiry.
        """
        # if channel is no str (=> bytes) convert it
        if not isinstance(channel, str):
            channel = channel.decode('utf-8')
        # Pick a connection to the right server - consistent for response
        # channels, random for normal channels. The hash must be taken over
        # the channel name so it matches the server receive_many() reads.
        if channel.startswith("!"):
            index = self.consistent_hash(channel)
            connection = self.connection(index)
        else:
            connection = self.connection(None)
        # Write out message into expiring key (avoids big items in list)
        # TODO: Use extended set, drop support for older redis?
        key = self.prefix + uuid.uuid4().hex
        list_key = self.prefix + channel
        lifetime = self.expiry + 10
        connection.set(key, json.dumps(message))
        connection.expire(key, lifetime)
        # Add key to list
        connection.rpush(list_key, key)
        # Set list to expire when message does (any later messages will bump this)
        connection.expire(list_key, lifetime)
        # TODO: Prune expired messages from same list (in case nobody consumes)

    def receive_many(self, channels):
        """
        Waits up to one second for a message on any of the given channels.

        Returns a (channel, message) pair, or (None, None) if nothing
        arrived before the timeout. Raises ValueError if channels is empty.
        """
        if not channels:
            raise ValueError("Cannot receive on empty channel list!")
        # Work out what servers to listen on for the given channels:
        # response channels live on their hashed server, all normal channels
        # are polled on one randomly chosen server for this call.
        indexes = {}
        random_index = self.random_index()
        for channel in channels:
            if channel.startswith("!"):
                indexes.setdefault(self.consistent_hash(channel), []).append(channel)
            else:
                indexes.setdefault(random_index, []).append(channel)
        # Get a message from one of our channels
        while True:
            # Select a random connection to use
            # TODO: Would we be better trying to do this truly async?
            index = random.choice(list(indexes.keys()))
            connection = self.connection(index)
            server_channels = indexes[index]
            # Shuffle channels to avoid the first ones starving others of workers
            random.shuffle(server_channels)
            # Pop off any waiting message
            result = connection.blpop(
                [self.prefix + name for name in server_channels],
                timeout=1,
            )
            if not result:
                return None, None
            # result is (list key, message key); fetch the message body
            content = connection.get(result[1])
            if content is None:
                # The message key already expired; try again
                continue
            return result[0][len(self.prefix):].decode("utf-8"), json.loads(content.decode("utf-8"))

    def group_add(self, group, channel, expiry=None):
        """
        Adds the channel to the named group for at least 'expiry'
        seconds (expiry defaults to message expiry if not provided).

        Membership is stored in a sorted set whose score is the time at
        which the membership lapses.
        """
        key = "%s:group:%s" % (self.prefix, group)
        key = key.encode("utf8")
        # NOTE(review): this keyword form of zadd looks like the older
        # redis-py calling convention - confirm against the installed version.
        self.connection(self.consistent_hash(group)).zadd(
            key,
            **{channel: time.time() + (expiry or self.expiry)}
        )

    def group_discard(self, group, channel):
        """
        Removes the channel from the named group if it is in the group;
        does nothing otherwise (does not error)
        """
        key = "%s:group:%s" % (self.prefix, group)
        key = key.encode("utf8")
        self.connection(self.consistent_hash(group)).zrem(
            key,
            channel,
        )

    def group_channels(self, group):
        """
        Returns a list of all channels currently in the group.

        Memberships whose expiry time is more than 10 seconds in the past
        are removed from Redis before the list is read.
        """
        key = "%s:group:%s" % (self.prefix, group)
        key = key.encode("utf8")
        connection = self.connection(self.consistent_hash(group))
        # Discard old channels
        connection.zremrangebyscore(key, 0, int(time.time()) - 10)
        # Return current lot
        return [x.decode("utf8") for x in connection.zrange(
            key,
            0,
            -1,
        )]

    # TODO: send_group efficient implementation using Lua

    def lock_channel(self, channel, expiry=None):
        """
        Attempts to get a lock on the named channel. Returns True if lock
        obtained, False if lock not obtained.

        NOTE(review): the expiry argument is accepted but not applied, so a
        lock stays in place until unlock_channel() is called - confirm
        whether that is intended.
        """
        key = "%s:lock:%s" % (self.prefix, channel)
        return bool(self.connection(self.consistent_hash(channel)).setnx(key, "1"))

    def unlock_channel(self, channel):
        """
        Unlocks the named channel. Always succeeds.
        """
        key = "%s:lock:%s" % (self.prefix, channel)
        self.connection(self.consistent_hash(channel)).delete(key)

    def __str__(self):
        return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)

    def flush(self):
        """
        Calls flushdb on every configured host, clearing the whole selected
        Redis database on each one (not only keys carrying this prefix).
        """
        for i in range(self.ring_size):
            self.connection(i).flushdb()
|