daphne/channels/backends/redis_py.py
2015-11-09 13:01:02 +01:00

192 lines
6.7 KiB
Python

import binascii
import json
import math
import random
import time
import uuid
import redis
from django.utils import six
from .base import BaseChannelBackend
class RedisChannelBackend(BaseChannelBackend):
    """
    Redis-backed channel backend.

    Each message is stored under its own expiring key, and the key is
    appended to a per-channel Redis list. Several Redis hosts may be given:
    normal channels are spread across them at random, while response
    channels (names starting with "!") as well as group and lock keys are
    pinned to a single host chosen by hashing their name, so that senders
    and receivers agree on where to look.
    """

    def __init__(self, routing, expiry=60, hosts=None, prefix="django-channels:"):
        super(RedisChannelBackend, self).__init__(routing=routing, expiry=expiry)
        # Make sure they provided some hosts, or provide a default
        if not hosts:
            hosts = [("localhost", 6379)]
        for host, port in hosts:
            assert isinstance(host, six.string_types)
            assert int(port)
        self.hosts = hosts
        self.prefix = prefix
        # Precalculate some values for ring selection
        self.ring_size = len(self.hosts)
        self.ring_divisor = int(math.ceil(4096 / float(self.ring_size)))
        # One Redis client per host index, created lazily by connection().
        # Each redis.Redis owns a connection pool, so reusing clients avoids
        # opening a new TCP connection for every operation.
        self._clients = {}

    @staticmethod
    def _channel_name(channel):
        """
        Returns the channel name as text, decoding UTF-8 bytes if needed.
        """
        if isinstance(channel, six.binary_type):
            return channel.decode("utf-8")
        return channel

    def _group_key(self, group):
        """
        Returns the UTF-8 encoded Redis key of the sorted set for 'group'.
        """
        return ("%s:group:%s" % (self.prefix, group)).encode("utf8")

    def consistent_hash(self, value):
        """
        Maps the value to a ring node index in the range [0, ring_size).

        The value (UTF-8 encoded if it is text) is hashed with CRC32. The
        top 12 bits of the hash (0-4095) are then bucketed into ring_size
        roughly equal slices.
        """
        if isinstance(value, six.text_type):
            value = value.encode("utf8")
        # Mask so the result is unsigned on every Python version.
        bigval = binascii.crc32(value) & 0xffffffff
        # 32-bit value // 2**20 keeps the top 12 bits: 0..4095
        return (bigval // 0x100000) // self.ring_divisor

    def random_index(self):
        """
        Returns a uniformly random host index.
        """
        return random.randint(0, len(self.hosts) - 1)

    def connection(self, index):
        """
        Returns a Redis client for the host at position 'index'.

        Pass an explicit index (e.g. from consistent_hash) to target a
        specific host, or None to pick a random host. Clients are cached
        per index and reused (redis-py documents its clients as safe to
        share between threads).

        Raises ValueError if the index is outside [0, ring_size).
        """
        # If index is explicitly None, pick a random server
        if index is None:
            index = self.random_index()
        # Catch bad indexes
        if not (0 <= index < self.ring_size):
            raise ValueError("There are only %s hosts - you asked for %s!" % (self.ring_size, index))
        client = self._clients.get(index)
        if client is None:
            host, port = self.hosts[index]
            client = redis.Redis(host=host, port=port)
            self._clients[index] = client
        return client

    def send(self, channel, message):
        """
        Sends a JSON-serializable message onto the named channel.

        The serialized message is written to its own key, which expires
        after expiry + 10 seconds. That key is then appended to the
        channel's list, whose expiry is refreshed to the same value.
        """
        channel = self._channel_name(channel)
        # Write out message into expiring key (avoids big items in list)
        # TODO: Use extended set, drop support for older redis?
        key = self.prefix + uuid.uuid4().hex
        # Response channels must hash the channel *name*, not the random
        # per-message key: receive_many locates them by hashing the name.
        if channel.startswith("!"):
            connection = self.connection(self.consistent_hash(channel))
        else:
            connection = self.connection(None)
        connection.set(key, json.dumps(message))
        connection.expire(key, self.expiry + 10)
        # Add key to list
        connection.rpush(self.prefix + channel, key)
        # Set list to expire when message does (any later messages will bump this)
        connection.expire(self.prefix + channel, self.expiry + 10)
        # TODO: Prune expired messages from same list (in case nobody consumes)

    def receive_many(self, channels):
        """
        Waits for a message on any of the given channels.

        Returns (channel, message) for the first message found. Returns
        (None, None) when a BLPOP (1 second timeout) yields nothing. If a
        popped message's body key has already expired, it retries.

        Raises ValueError if channels is empty.
        """
        if not channels:
            raise ValueError("Cannot receive on empty channel list!")
        # Work out what servers to listen on for the given channels
        indexes = {}
        random_index = self.random_index()
        for channel in channels:
            channel = self._channel_name(channel)
            if channel.startswith("!"):
                index = self.consistent_hash(channel)
            else:
                index = random_index
            indexes.setdefault(index, []).append(channel)
        # Get a message from one of our channels
        while True:
            # Select a random connection to use
            # TODO: Would we be better trying to do this truly async?
            index = random.choice(list(indexes))
            connection = self.connection(index)
            server_channels = indexes[index]
            # Shuffle channels to avoid the first ones starving others of workers
            random.shuffle(server_channels)
            # Pop off any waiting message
            result = connection.blpop([self.prefix + name for name in server_channels], timeout=1)
            if not result:
                return None, None
            content = connection.get(result[1])
            if content is None:
                # The message body expired before we got to it; try again.
                continue
            # Decode before stripping the prefix so the prefix length is
            # counted in characters, matching how it was prepended.
            list_name = result[0].decode("utf-8")
            return list_name[len(self.prefix):], json.loads(content.decode("utf-8"))

    def group_add(self, group, channel, expiry=None):
        """
        Adds the channel to the named group for at least 'expiry'
        seconds (expiry defaults to message expiry if not provided).
        """
        # The member's score is the absolute time its membership lapses.
        # NOTE(review): keyword member=score form matches the redis-py 2.x
        # Redis.zadd signature; confirm against the pinned redis-py version.
        self.connection(self.consistent_hash(group)).zadd(
            self._group_key(group),
            **{channel: time.time() + (expiry or self.expiry)}
        )

    def group_discard(self, group, channel):
        """
        Removes the channel from the named group if it is in the group;
        does nothing otherwise (does not error)
        """
        self.connection(self.consistent_hash(group)).zrem(
            self._group_key(group),
            channel,
        )

    def group_channels(self, group):
        """
        Returns a list of all channels currently in the group.

        Members whose expiry time is more than 10 seconds in the past are
        removed before the list is read.
        """
        key = self._group_key(group)
        connection = self.connection(self.consistent_hash(group))
        # Discard old channels
        connection.zremrangebyscore(key, 0, int(time.time()) - 10)
        # Return current lot
        return [member.decode("utf8") for member in connection.zrange(key, 0, -1)]

    # TODO: send_group efficient implementation using Lua

    def lock_channel(self, channel, expiry=None):
        """
        Attempts to get a lock on the named channel. Returns True if lock
        obtained, False if lock not obtained.

        The lock key expires after 'expiry' seconds (defaults to message
        expiry), so a holder that dies cannot keep the channel locked
        forever.
        """
        key = "%s:lock:%s" % (self.prefix, channel)
        connection = self.connection(self.consistent_hash(channel))
        if not connection.setnx(key, "1"):
            return False
        # SETNX followed by EXPIRE (rather than SET ... NX EX) keeps this
        # working on Redis servers older than 2.6.12.
        connection.expire(key, expiry or self.expiry)
        return True

    def unlock_channel(self, channel):
        """
        Unlocks the named channel. Always succeeds.
        """
        key = "%s:lock:%s" % (self.prefix, channel)
        self.connection(self.consistent_hash(channel)).delete(key)

    def __str__(self):
        return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)

    def flush(self):
        """
        Runs FLUSHDB on every configured host. This deletes every key in
        each host's current database, not only keys under self.prefix.
        """
        for i in range(self.ring_size):
            self.connection(i).flushdb()