Fixed #6: Linearise decorator and better user session stuff

This commit is contained in:
Andrew Godwin 2015-09-10 11:52:49 -05:00
parent eed6e5e607
commit 638bf260f8
9 changed files with 243 additions and 73 deletions

66
channels/auth.py Normal file
View File

@ -0,0 +1,66 @@
import functools
from django.contrib import auth
from .decorators import channel_session, http_session
def transfer_user(from_session, to_session):
"""
Transfers user from HTTP session to channel session.
"""
to_session[auth.BACKEND_SESSION_KEY] = from_session[auth.BACKEND_SESSION_KEY]
to_session[auth.SESSION_KEY] = from_session[auth.SESSION_KEY]
to_session[auth.HASH_SESSION_KEY] = from_session[auth.HASH_SESSION_KEY]
def channel_session_user(func):
"""
Presents a message.user attribute obtained from a user ID in the channel
session, rather than in the http_session. Turns on channel session implicitly.
"""
@channel_session
@functools.wraps(func)
def inner(message, *args, **kwargs):
# If we didn't get a session, then we don't get a user
if not hasattr(message, "channel_session"):
raise ValueError("Did not see a channel session to get auth from")
if message.channel_session is None:
message.user = None
# Otherwise, be a bit naughty and make a fake Request with just
# a "session" attribute (later on, perhaps refactor contrib.auth to
# pass around session rather than request)
else:
fake_request = type("FakeRequest", (object, ), {"session": message.channel_session})
message.user = auth.get_user(fake_request)
# Run the consumer
return func(message, *args, **kwargs)
return inner
def http_session_user(func):
"""
Wraps a HTTP or WebSocket consumer (or any consumer of messages
that provides a "COOKIES" attribute) to provide both a "session"
attribute and a "user" attibute, like AuthMiddleware does.
This runs http_session() to get a session to hook auth off of.
If the user does not have a session cookie set, both "session"
and "user" will be None.
"""
@http_session
@functools.wraps(func)
def inner(message, *args, **kwargs):
# If we didn't get a session, then we don't get a user
if not hasattr(message, "http_session"):
raise ValueError("Did not see a http session to get auth from")
if message.http_session is None:
message.user = None
# Otherwise, be a bit naughty and make a fake Request with just
# a "session" attribute (later on, perhaps refactor contrib.auth to
# pass around session rather than request)
else:
fake_request = type("FakeRequest", (object, ), {"session": message.http_session})
message.user = auth.get_user(fake_request)
# Run the consumer
return func(message, *args, **kwargs)
return inner

View File

@ -93,3 +93,16 @@ class BaseChannelBackend(object):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def lock_channel(self, channel):
"""
Attempts to get a lock on the named channel. Returns True if lock
obtained, False if lock not obtained.
"""
raise NotImplementedError()
def unlock_channel(self, channel):
"""
Unlocks the named channel. Always succeeds.
"""
raise NotImplementedError()

View File

@ -3,7 +3,7 @@ import json
import datetime import datetime
from django.apps.registry import Apps from django.apps.registry import Apps
from django.db import models, connections, DEFAULT_DB_ALIAS from django.db import models, connections, DEFAULT_DB_ALIAS, IntegrityError
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.timezone import now from django.utils.timezone import now
@ -71,6 +71,26 @@ class DatabaseChannelBackend(BaseChannelBackend):
editor.create_model(Group) editor.create_model(Group)
return Group return Group
@cached_property
def lock_model(self):
"""
Initialises a new model to store groups; not done as part of a
models.py as we don't want to make it for most installs.
"""
# Make the model class
class Lock(models.Model):
channel = models.CharField(max_length=200, unique=True)
expiry = models.DateTimeField(db_index=True)
class Meta:
apps = Apps()
app_label = "channels"
db_table = "django_channel_locks"
# Ensure its table exists
if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
with self.connection.schema_editor() as editor:
editor.create_model(Lock)
return Lock
def send(self, channel, message): def send(self, channel, message):
self.channel_model.objects.create( self.channel_model.objects.create(
channel = channel, channel = channel,
@ -97,6 +117,7 @@ class DatabaseChannelBackend(BaseChannelBackend):
# Include a 10-second grace period because that solves some clock sync # Include a 10-second grace period because that solves some clock sync
self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
def group_add(self, group, channel, expiry=None): def group_add(self, group, channel, expiry=None):
""" """
@ -123,5 +144,27 @@ class DatabaseChannelBackend(BaseChannelBackend):
self._clean_expired() self._clean_expired()
return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True)) return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True))
def lock_channel(self, channel, expiry=None):
"""
Attempts to get a lock on the named channel. Returns True if lock
obtained, False if lock not obtained.
"""
# We rely on the UNIQUE constraint for only-one-thread-wins on locks
try:
self.lock_model.objects.create(
channel = channel,
expiry = now() + datetime.timedelta(seconds=expiry or self.expiry),
)
except IntegrityError:
return False
else:
return True
def unlock_channel(self, channel):
"""
Unlocks the named channel. Always succeeds.
"""
self.lock_model.objects.filter(channel=channel).delete()
def __str__(self): def __str__(self):
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias) return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)

View File

@ -5,6 +5,7 @@ from .base import BaseChannelBackend
queues = {} queues = {}
groups = {} groups = {}
locks = set()
class InMemoryChannelBackend(BaseChannelBackend): class InMemoryChannelBackend(BaseChannelBackend):
""" """
@ -72,3 +73,22 @@ class InMemoryChannelBackend(BaseChannelBackend):
""" """
self._clean_expired() self._clean_expired()
return groups.get(group, {}).keys() return groups.get(group, {}).keys()
def lock_channel(self, channel):
"""
Attempts to get a lock on the named channel. Returns True if lock
obtained, False if lock not obtained.
"""
# Probably not perfect for race conditions, but close enough considering
# it shouldn't be used.
if channel not in locks:
locks.add(channel)
return True
else:
return False
def unlock_channel(self, channel):
"""
Unlocks the named channel. Always succeeds.
"""
locks.discard(channel)

View File

@ -101,5 +101,20 @@ class RedisChannelBackend(BaseChannelBackend):
# TODO: send_group efficient implementation using Lua # TODO: send_group efficient implementation using Lua
def lock_channel(self, channel, expiry=None):
"""
Attempts to get a lock on the named channel. Returns True if lock
obtained, False if lock not obtained.
"""
key = "%s:lock:%s" % (self.prefix, channel)
return bool(self.connection.setnx(key, "1"))
def unlock_channel(self, channel):
"""
Unlocks the named channel. Always succeeds.
"""
key = "%s:lock:%s" % (self.prefix, channel)
self.connection.delete(key)
def __str__(self): def __str__(self):
return "%s(host=%s, port=%s)" % (self.__class__.__name__, self.host, self.port) return "%s(host=%s, port=%s)" % (self.__class__.__name__, self.host, self.port)

View File

@ -3,16 +3,77 @@ import hashlib
from importlib import import_module from importlib import import_module
from django.conf import settings from django.conf import settings
from django.utils import six
from django.contrib import auth
from channels import channel_backends, DEFAULT_CHANNEL_BACKEND
def linearize(func):
"""
Makes sure the contained consumer does not run at the same time other
consumers are running on messages with the same reply_channel.
Required if you don't want weird things like a second consumer starting
up before the first has exited and saved its session. Doesn't guarantee
ordering, just linearity.
"""
@functools.wraps(func)
def inner(message, *args, **kwargs):
# Make sure there's a reply channel
if not message.reply_channel:
raise ValueError("No reply_channel sent to consumer; @no_overlap can only be used on messages containing it.")
# Get the lock, or re-queue
locked = message.channel_backend.lock_channel(message.reply_channel)
if not locked:
raise message.Requeue()
# OK, keep going
try:
return func(message, *args, **kwargs)
finally:
message.channel_backend.unlock_channel(message.reply_channel)
return inner
def channel_session(func):
"""
Provides a session-like object called "channel_session" to consumers
as a message attribute that will auto-persist across consumers with
the same incoming "reply_channel" value.
Use this to persist data across the lifetime of a connection.
"""
@functools.wraps(func)
def inner(message, *args, **kwargs):
# Make sure there's a reply_channel
if not message.reply_channel:
raise ValueError("No reply_channel sent to consumer; @channel_session can only be used on messages containing it.")
# Make sure there's NOT a channel_session already
if hasattr(message, "channel_session"):
raise ValueError("channel_session decorator wrapped inside another channel_session decorator")
# Turn the reply_channel into a valid session key length thing.
# We take the last 24 bytes verbatim, as these are the random section,
# and then hash the remaining ones onto the start, and add a prefix
reply_name = message.reply_channel.name
session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:]
# Make a session storage
session_engine = import_module(settings.SESSION_ENGINE)
session = session_engine.SessionStore(session_key=session_key)
# If the session does not already exist, save to force our
# session key to be valid.
if not session.exists(session.session_key):
session.save(must_create=True)
message.channel_session = session
# Run the consumer
try:
return func(message, *args, **kwargs)
finally:
# Persist session if needed
if session.modified:
session.save()
return inner
def http_session(func): def http_session(func):
""" """
Wraps a HTTP or WebSocket consumer (or any consumer of messages Wraps a HTTP or WebSocket connect consumer (or any consumer of messages
that provides a "COOKIES" or "GET" attribute) to provide a "session" that provides a "cooikies" or "get" attribute) to provide a "http_session"
attribute that behaves like request.session; that is, it's hung off of attribute that behaves like request.session; that is, it's hung off of
a per-user session key that is saved in a cookie or passed as the a per-user session key that is saved in a cookie or passed as the
"session_key" GET parameter. "session_key" GET parameter.
@ -21,13 +82,16 @@ def http_session(func):
don't have one - that's what SessionMiddleware is for, this is a simpler don't have one - that's what SessionMiddleware is for, this is a simpler
read-only version for more low-level code. read-only version for more low-level code.
If a user does not have a session we can inflate, the "session" attribute will If a message does not have a session we can inflate, the "session" attribute
be None, rather than an empty session you can write to. will be None, rather than an empty session you can write to.
""" """
@functools.wraps(func) @functools.wraps(func)
def inner(message, *args, **kwargs): def inner(message, *args, **kwargs):
if "cookies" not in message.content and "get" not in message.content: if "cookies" not in message.content and "get" not in message.content:
raise ValueError("No cookies or get sent to consumer; this decorator can only be used on messages containing at least one.") raise ValueError("No cookies or get sent to consumer - cannot initialise http_session")
# Make sure there's NOT a http_session already
if hasattr(message, "http_session"):
raise ValueError("http_session decorator wrapped inside another http_session decorator")
# Make sure there's a session key # Make sure there's a session key
session_key = None session_key = None
if "get" in message.content: if "get" in message.content:
@ -43,7 +107,7 @@ def http_session(func):
session = session_engine.SessionStore(session_key=session_key) session = session_engine.SessionStore(session_key=session_key)
else: else:
session = None session = None
message.session = session message.http_session = session
# Run the consumer # Run the consumer
result = func(message, *args, **kwargs) result = func(message, *args, **kwargs)
# Persist session if needed (won't be saved if error happens) # Persist session if needed (won't be saved if error happens)
@ -51,65 +115,3 @@ def http_session(func):
session.save() session.save()
return result return result
return inner return inner
def http_django_auth(func):
"""
Wraps a HTTP or WebSocket consumer (or any consumer of messages
that provides a "COOKIES" attribute) to provide both a "session"
attribute and a "user" attibute, like AuthMiddleware does.
This runs http_session() to get a session to hook auth off of.
If the user does not have a session cookie set, both "session"
and "user" will be None.
"""
@http_session
@functools.wraps(func)
def inner(message, *args, **kwargs):
# If we didn't get a session, then we don't get a user
if not hasattr(message, "session"):
raise ValueError("Did not see a session to get auth from")
if message.session is None:
message.user = None
# Otherwise, be a bit naughty and make a fake Request with just
# a "session" attribute (later on, perhaps refactor contrib.auth to
# pass around session rather than request)
else:
fake_request = type("FakeRequest", (object, ), {"session": message.session})
message.user = auth.get_user(fake_request)
# Run the consumer
return func(message, *args, **kwargs)
return inner
def channel_session(func):
"""
Provides a session-like object called "channel_session" to consumers
as a message attribute that will auto-persist across consumers with
the same incoming "reply_channel" value.
"""
@functools.wraps(func)
def inner(message, *args, **kwargs):
# Make sure there's a reply_channel in kwargs
if not message.reply_channel:
raise ValueError("No reply_channel sent to consumer; this decorator can only be used on messages containing it.")
# Turn the reply_channel into a valid session key length thing.
# We take the last 24 bytes verbatim, as these are the random section,
# and then hash the remaining ones onto the start, and add a prefix
# TODO: See if there's a better way of doing this
reply_name = message.reply_channel.name
session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:]
# Make a session storage
session_engine = import_module(settings.SESSION_ENGINE)
session = session_engine.SessionStore(session_key=session_key)
# If the session does not already exist, save to force our session key to be valid
if not session.exists(session.session_key):
session.save()
message.channel_session = session
# Run the consumer
result = func(message, *args, **kwargs)
# Persist session if needed (won't be saved if error happens)
if session.modified:
session.save()
return result
return inner

View File

@ -18,7 +18,7 @@ class InterfaceProtocol(WebSocketServerProtocol):
self.channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND] self.channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND]
self.request_info = { self.request_info = {
"path": request.path, "path": request.path,
"GET": request.params, "get": request.params,
} }
def onOpen(self): def onOpen(self):

View File

@ -10,6 +10,13 @@ class Message(object):
to use to reply to this message's end user, if that makes sense. to use to reply to this message's end user, if that makes sense.
""" """
class Requeue(Exception):
"""
Raise this while processing a message to requeue it back onto the
channel. Useful if you're manually ensuring partial ordering, etc.
"""
pass
def __init__(self, content, channel, channel_backend, reply_channel=None): def __init__(self, content, channel, channel_backend, reply_channel=None):
self.content = content self.content = content
self.channel = channel self.channel = channel

View File

@ -1,5 +1,6 @@
import traceback import traceback
from .message import Message from .message import Message
from .utils import name_that_thing
class Worker(object): class Worker(object):
@ -31,5 +32,8 @@ class Worker(object):
self.callback(channel, message) self.callback(channel, message)
try: try:
consumer(message) consumer(message)
except Message.Requeue:
self.channel_backend.send(channel, content)
except: except:
print "Error processing message with consumer %s:" % name_that_thing(consumer)
traceback.print_exc() traceback.print_exc()