Switch to settings-based backend list, start ORM backend

This commit is contained in:
Andrew Godwin 2015-06-10 09:40:57 -07:00
parent 5a7af1e3af
commit c9eb683ed8
8 changed files with 127 additions and 22 deletions

View File

@ -1,12 +1,18 @@
from .channel import Channel # Load backends
DEFAULT_CHANNEL_BACKEND = "default"
# Load a backend from .backends import BackendManager
from .backends.memory import InMemoryChannelBackend from django.conf import settings
DEFAULT_CHANNEL_LAYER = "default" channel_backends = BackendManager(
channel_layers = { getattr(settings, "CHANNEL_BACKENDS", {
DEFAULT_CHANNEL_LAYER: InMemoryChannelBackend(), DEFAULT_CHANNEL_BACKEND: {
} "BACKEND": "channels.backends.memory.InMemoryChannelBackend",
}
})
)
# Ensure monkeypatching # Ensure monkeypatching
from .hacks import monkeypatch_django from .hacks import monkeypatch_django
monkeypatch_django() monkeypatch_django()
# Promote channel to top-level (down here to avoid circular import errs)
from .channel import Channel

View File

@ -3,7 +3,7 @@ import functools
from django.core.handlers.base import BaseHandler from django.core.handlers.base import BaseHandler
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from channels import Channel, channel_layers, DEFAULT_CHANNEL_LAYER from channels import Channel, channel_backends, DEFAULT_CHANNEL_BACKEND
class UrlConsumer(object): class UrlConsumer(object):
@ -35,7 +35,7 @@ def view_producer(channel_name):
return producing_view return producing_view
def view_consumer(channel_name, alias=None): def view_consumer(channel_name, alias=DEFAULT_CHANNEL_BACKEND):
""" """
Decorates a normal Django view to be a channel consumer. Decorates a normal Django view to be a channel consumer.
Does not run any middleware Does not run any middleware
@ -47,7 +47,7 @@ def view_consumer(channel_name, alias=None):
response = func(request) response = func(request)
Channel(request.response_channel).send(**response.channel_encode()) Channel(request.response_channel).send(**response.channel_encode())
# Get the channel layer and register # Get the channel layer and register
channel_layer = channel_layers[alias or DEFAULT_CHANNEL_LAYER] channel_layer = channel_backends[DEFAULT_CHANNEL_BACKEND]
channel_layer.registry.add_consumer(consumer, [channel_name]) channel_layer.registry.add_consumer(consumer, [channel_name])
return func return func
return inner return inner

View File

@ -0,0 +1,27 @@
from django.utils.module_loading import import_string
class InvalidChannelBackendError(ValueError):
pass
class BackendManager(object):
"""
Takes a settings dictionary of backends and initialises them.
"""
def __init__(self, backend_configs):
self.backends = {}
for name, config in backend_configs.items():
# Load the backend class
try:
backend_class = import_string(config['BACKEND'])
except KeyError:
raise InvalidChannelBackendError("No BACKEND specified for %s" % name)
except ImportError:
raise InvalidChannelBackendError("Cannot import BACKEND %s specified for %s" % (config['BACKEND'], name))
# Initialise and pass config
self.backends[name] = backend_class(**{k.lower(): v for k, v in config.items() if k != "BACKEND"})
def __getitem__(self, key):
return self.backends[key]

View File

@ -1,6 +1,5 @@
import time import time
import string import json
import random
from collections import deque from collections import deque
from .base import BaseChannelBackend from .base import BaseChannelBackend
@ -13,6 +12,8 @@ class InMemoryChannelBackend(BaseChannelBackend):
""" """
def send(self, channel, message): def send(self, channel, message):
# Try JSON encoding it to make sure it would, but store the native version
json.dumps(message)
# Add to the deque, making it if needs be # Add to the deque, making it if needs be
queues.setdefault(channel, deque()).append(message) queues.setdefault(channel, deque()).append(message)

68
channels/backends/orm.py Normal file
View File

@ -0,0 +1,68 @@
import time
import datetime
from django.apps.registry import Apps
from django.db import models, connections, DEFAULT_DB_ALIAS
from .base import BaseChannelBackend
queues = {}
class ORMChannelBackend(BaseChannelBackend):
"""
ORM-backed channel environment. For development use only; it will span
multiple processes fine, but it's going to be pretty bad at throughput.
"""
def __init__(self, expiry, db_alias=DEFAULT_DB_ALIAS):
super(ORMChannelBackend, self).__init__(expiry)
self.connection = connections[db_alias]
self.model = self.make_model()
self.ensure_schema()
def make_model(self):
"""
Initialises a new model to store messages; not done as part of a
models.py as we don't want to make it for most installs.
"""
class Message(models.Model):
# We assume an autoincrementing PK for message order
channel = models.CharField(max_length=200, db_index=True)
content = models.TextField()
expiry = models.DateTimeField(db_index=True)
class Meta:
apps = Apps()
app_label = "channels"
db_table = "django_channels"
return Message
def ensure_schema(self):
"""
Ensures the table exists and has the correct schema.
"""
# If the table's there, that's fine - we've never changed its schema
# in the codebase.
if self.model._meta.db_table in self.connection.introspection.table_names(self.connection.cursor()):
return
# Make the table
with self.connection.schema_editor() as editor:
editor.create_model(self.model)
def send(self, channel, message):
self.model.objects.create(
channel = channel,
message = json.dumps(message),
expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.expiry)
)
def receive_many(self, channels):
while True:
# Delete all expired messages (add 10 second grace period for clock sync)
self.model.objects.filter(expiry__lt=datetime.datetime.utcnow() - datetime.timedelta(seconds=10)).delete()
# Get a message from one of our channels
message = self.model.objects.filter(channel__in=channels).order_by("id").first()
if message:
return message.channel, json.loads(message.content)
# If all empty, sleep for a little bit
time.sleep(0.2)

View File

@ -3,6 +3,8 @@ import string
from django.utils import six from django.utils import six
from channels import channel_backends, DEFAULT_CHANNEL_BACKEND
class Channel(object): class Channel(object):
""" """
@ -16,13 +18,12 @@ class Channel(object):
"default" one by default. "default" one by default.
""" """
def __init__(self, name, alias=None): def __init__(self, name, alias=DEFAULT_CHANNEL_BACKEND):
""" """
Create an instance for the channel named "name" Create an instance for the channel named "name"
""" """
from channels import channel_layers, DEFAULT_CHANNEL_LAYER
self.name = name self.name = name
self.channel_layer = channel_layers[alias or DEFAULT_CHANNEL_LAYER] self.channel_layer = channel_backends[alias]
def send(self, **kwargs): def send(self, **kwargs):
""" """
@ -50,16 +51,15 @@ class Channel(object):
return view_producer(self.name) return view_producer(self.name)
@classmethod @classmethod
def consumer(self, channels, alias=None): def consumer(self, channels, alias=DEFAULT_CHANNEL_BACKEND):
""" """
Decorator that registers a function as a consumer. Decorator that registers a function as a consumer.
""" """
from channels import channel_layers, DEFAULT_CHANNEL_LAYER
# Upconvert if you just pass in a string # Upconvert if you just pass in a string
if isinstance(channels, six.string_types): if isinstance(channels, six.string_types):
channels = [channels] channels = [channels]
# Get the channel # Get the channel
channel_layer = channel_layers[alias or DEFAULT_CHANNEL_LAYER] channel_layer = channel_backends[alias]
# Return a function that'll register whatever it wraps # Return a function that'll register whatever it wraps
def inner(func): def inner(func):
channel_layer.registry.add_consumer(func, channels) channel_layer.registry.add_consumer(func, channels)

View File

@ -8,6 +8,9 @@ Note: All consumers also receive the channel name as the keyword argument
"channel", so there is no need for separate type information to let "channel", so there is no need for separate type information to let
multi-channel consumers distinguish. multi-channel consumers distinguish.
The length limit on channel names will be 200 characters.
HTTP Request HTTP Request
------------ ------------

View File

@ -3,7 +3,7 @@ import threading
from django.core.management.commands.runserver import Command as RunserverCommand from django.core.management.commands.runserver import Command as RunserverCommand
from django.core.handlers.wsgi import WSGIHandler from django.core.handlers.wsgi import WSGIHandler
from django.http import HttpResponse from django.http import HttpResponse
from channels import Channel, channel_layers, DEFAULT_CHANNEL_LAYER from channels import Channel, channel_backends, DEFAULT_CHANNEL_BACKEND
from channels.worker import Worker from channels.worker import Worker
from channels.utils import auto_import_consumers from channels.utils import auto_import_consumers
from channels.adapters import UrlConsumer from channels.adapters import UrlConsumer
@ -22,7 +22,7 @@ class Command(RunserverCommand):
# Force disable reloader for now # Force disable reloader for now
options['use_reloader'] = False options['use_reloader'] = False
# Check a handler is registered for http reqs # Check a handler is registered for http reqs
channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER] channel_layer = channel_backends[DEFAULT_CHANNEL_BACKEND]
auto_import_consumers() auto_import_consumers()
if not channel_layer.registry.consumer_for_channel("django.wsgi.request"): if not channel_layer.registry.consumer_for_channel("django.wsgi.request"):
# Register the default one # Register the default one
@ -43,7 +43,7 @@ class WSGIInterfaceHandler(WSGIHandler):
def get_response(self, request): def get_response(self, request):
request.response_channel = Channel.new_name("django.wsgi.response") request.response_channel = Channel.new_name("django.wsgi.response")
Channel("django.wsgi.request").send(**request.channel_encode()) Channel("django.wsgi.request").send(**request.channel_encode())
channel, message = channel_layers[DEFAULT_CHANNEL_LAYER].receive_many([request.response_channel]) channel, message = channel_backends[DEFAULT_CHANNEL_BACKEND].receive_many([request.response_channel])
return HttpResponse.channel_decode(message) return HttpResponse.channel_decode(message)