diff --git a/channels/auth.py b/channels/auth.py new file mode 100644 index 0000000..261010c --- /dev/null +++ b/channels/auth.py @@ -0,0 +1,66 @@ +import functools + +from django.contrib import auth +from .decorators import channel_session, http_session + + +def transfer_user(from_session, to_session): + """ + Transfers user from HTTP session to channel session. + """ + to_session[auth.BACKEND_SESSION_KEY] = from_session[auth.BACKEND_SESSION_KEY] + to_session[auth.SESSION_KEY] = from_session[auth.SESSION_KEY] + to_session[auth.HASH_SESSION_KEY] = from_session[auth.HASH_SESSION_KEY] + + +def channel_session_user(func): + """ + Presents a message.user attribute obtained from a user ID in the channel + session, rather than in the http_session. Turns on channel session implicitly. + """ + @channel_session + @functools.wraps(func) + def inner(message, *args, **kwargs): + # If we didn't get a session, then we don't get a user + if not hasattr(message, "channel_session"): + raise ValueError("Did not see a channel session to get auth from") + if message.channel_session is None: + message.user = None + # Otherwise, be a bit naughty and make a fake Request with just + # a "session" attribute (later on, perhaps refactor contrib.auth to + # pass around session rather than request) + else: + fake_request = type("FakeRequest", (object, ), {"session": message.channel_session}) + message.user = auth.get_user(fake_request) + # Run the consumer + return func(message, *args, **kwargs) + return inner + + +def http_session_user(func): + """ + Wraps a HTTP or WebSocket consumer (or any consumer of messages + that provides a "COOKIES" attribute) to provide both a "session" + attribute and a "user" attibute, like AuthMiddleware does. + + This runs http_session() to get a session to hook auth off of. + If the user does not have a session cookie set, both "session" + and "user" will be None. + """ + @http_session + @functools.wraps(func) + def inner(message, *args, **kwargs): + # If we didn't get a session, then we don't get a user + if not hasattr(message, "http_session"): + raise ValueError("Did not see a http session to get auth from") + if message.http_session is None: + message.user = None + # Otherwise, be a bit naughty and make a fake Request with just + # a "session" attribute (later on, perhaps refactor contrib.auth to + # pass around session rather than request) + else: + fake_request = type("FakeRequest", (object, ), {"session": message.http_session}) + message.user = auth.get_user(fake_request) + # Run the consumer + return func(message, *args, **kwargs) + return inner diff --git a/channels/backends/base.py b/channels/backends/base.py index 84e874e..9baaf6c 100644 --- a/channels/backends/base.py +++ b/channels/backends/base.py @@ -93,3 +93,16 @@ class BaseChannelBackend(object): def __str__(self): return self.__class__.__name__ + + def lock_channel(self, channel): + """ + Attempts to get a lock on the named channel. Returns True if lock + obtained, False if lock not obtained. + """ + raise NotImplementedError() + + def unlock_channel(self, channel): + """ + Unlocks the named channel. Always succeeds. + """ + raise NotImplementedError() diff --git a/channels/backends/database.py b/channels/backends/database.py index 504f3ad..fdc866a 100644 --- a/channels/backends/database.py +++ b/channels/backends/database.py @@ -3,7 +3,7 @@ import json import datetime from django.apps.registry import Apps -from django.db import models, connections, DEFAULT_DB_ALIAS +from django.db import models, connections, DEFAULT_DB_ALIAS, IntegrityError from django.utils.functional import cached_property from django.utils.timezone import now @@ -71,6 +71,26 @@ class DatabaseChannelBackend(BaseChannelBackend): editor.create_model(Group) return Group + @cached_property + def lock_model(self): + """ + Initialises a new model to store groups; not done as part of a + models.py as we don't want to make it for most installs. + """ + # Make the model class + class Lock(models.Model): + channel = models.CharField(max_length=200, unique=True) + expiry = models.DateTimeField(db_index=True) + class Meta: + apps = Apps() + app_label = "channels" + db_table = "django_channel_locks" + # Ensure its table exists + if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()): + with self.connection.schema_editor() as editor: + editor.create_model(Lock) + return Lock + def send(self, channel, message): self.channel_model.objects.create( channel = channel, @@ -97,6 +117,7 @@ class DatabaseChannelBackend(BaseChannelBackend): # Include a 10-second grace period because that solves some clock sync self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() + self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() def group_add(self, group, channel, expiry=None): """ @@ -123,5 +144,27 @@ class DatabaseChannelBackend(BaseChannelBackend): self._clean_expired() return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True)) + def lock_channel(self, channel, expiry=None): + """ + Attempts to get a lock on the named channel. Returns True if lock + obtained, False if lock not obtained. + """ + # We rely on the UNIQUE constraint for only-one-thread-wins on locks + try: + self.lock_model.objects.create( + channel = channel, + expiry = now() + datetime.timedelta(seconds=expiry or self.expiry), + ) + except IntegrityError: + return False + else: + return True + + def unlock_channel(self, channel): + """ + Unlocks the named channel. Always succeeds. + """ + self.lock_model.objects.filter(channel=channel).delete() + def __str__(self): return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias) diff --git a/channels/backends/memory.py b/channels/backends/memory.py index c398653..1d78363 100644 --- a/channels/backends/memory.py +++ b/channels/backends/memory.py @@ -5,6 +5,7 @@ from .base import BaseChannelBackend queues = {} groups = {} +locks = set() class InMemoryChannelBackend(BaseChannelBackend): """ @@ -72,3 +73,22 @@ class InMemoryChannelBackend(BaseChannelBackend): """ self._clean_expired() return groups.get(group, {}).keys() + + def lock_channel(self, channel): + """ + Attempts to get a lock on the named channel. Returns True if lock + obtained, False if lock not obtained. + """ + # Probably not perfect for race conditions, but close enough considering + # it shouldn't be used. + if channel not in locks: + locks.add(channel) + return True + else: + return False + + def unlock_channel(self, channel): + """ + Unlocks the named channel. Always succeeds. + """ + locks.discard(channel) diff --git a/channels/backends/redis_py.py b/channels/backends/redis_py.py index 5c8671c..6af662f 100644 --- a/channels/backends/redis_py.py +++ b/channels/backends/redis_py.py @@ -101,5 +101,20 @@ class RedisChannelBackend(BaseChannelBackend): # TODO: send_group efficient implementation using Lua + def lock_channel(self, channel, expiry=None): + """ + Attempts to get a lock on the named channel. Returns True if lock + obtained, False if lock not obtained. + """ + key = "%s:lock:%s" % (self.prefix, channel) + return bool(self.connection.setnx(key, "1")) + + def unlock_channel(self, channel): + """ + Unlocks the named channel. Always succeeds. + """ + key = "%s:lock:%s" % (self.prefix, channel) + self.connection.delete(key) + def __str__(self): return "%s(host=%s, port=%s)" % (self.__class__.__name__, self.host, self.port) diff --git a/channels/decorators.py b/channels/decorators.py index 03a4ecc..280b493 100644 --- a/channels/decorators.py +++ b/channels/decorators.py @@ -3,16 +3,77 @@ import hashlib from importlib import import_module from django.conf import settings -from django.utils import six -from django.contrib import auth -from channels import channel_backends, DEFAULT_CHANNEL_BACKEND + +def linearize(func): + """ + Makes sure the contained consumer does not run at the same time other + consumers are running on messages with the same reply_channel. + + Required if you don't want weird things like a second consumer starting + up before the first has exited and saved its session. Doesn't guarantee + ordering, just linearity. + """ + @functools.wraps(func) + def inner(message, *args, **kwargs): + # Make sure there's a reply channel + if not message.reply_channel: + raise ValueError("No reply_channel sent to consumer; @no_overlap can only be used on messages containing it.") + # Get the lock, or re-queue + locked = message.channel_backend.lock_channel(message.reply_channel) + if not locked: + raise message.Requeue() + # OK, keep going + try: + return func(message, *args, **kwargs) + finally: + message.channel_backend.unlock_channel(message.reply_channel) + return inner + + +def channel_session(func): + """ + Provides a session-like object called "channel_session" to consumers + as a message attribute that will auto-persist across consumers with + the same incoming "reply_channel" value. + + Use this to persist data across the lifetime of a connection. + """ + @functools.wraps(func) + def inner(message, *args, **kwargs): + # Make sure there's a reply_channel + if not message.reply_channel: + raise ValueError("No reply_channel sent to consumer; @channel_session can only be used on messages containing it.") + # Make sure there's NOT a channel_session already + if hasattr(message, "channel_session"): + raise ValueError("channel_session decorator wrapped inside another channel_session decorator") + # Turn the reply_channel into a valid session key length thing. + # We take the last 24 bytes verbatim, as these are the random section, + # and then hash the remaining ones onto the start, and add a prefix + reply_name = message.reply_channel.name + session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:] + # Make a session storage + session_engine = import_module(settings.SESSION_ENGINE) + session = session_engine.SessionStore(session_key=session_key) + # If the session does not already exist, save to force our + # session key to be valid. + if not session.exists(session.session_key): + session.save(must_create=True) + message.channel_session = session + # Run the consumer + try: + return func(message, *args, **kwargs) + finally: + # Persist session if needed + if session.modified: + session.save() + return inner def http_session(func): """ - Wraps a HTTP or WebSocket consumer (or any consumer of messages - that provides a "COOKIES" or "GET" attribute) to provide a "session" + Wraps a HTTP or WebSocket connect consumer (or any consumer of messages + that provides a "cooikies" or "get" attribute) to provide a "http_session" attribute that behaves like request.session; that is, it's hung off of a per-user session key that is saved in a cookie or passed as the "session_key" GET parameter. @@ -21,13 +82,16 @@ def http_session(func): don't have one - that's what SessionMiddleware is for, this is a simpler read-only version for more low-level code. - If a user does not have a session we can inflate, the "session" attribute will - be None, rather than an empty session you can write to. + If a message does not have a session we can inflate, the "session" attribute + will be None, rather than an empty session you can write to. """ @functools.wraps(func) def inner(message, *args, **kwargs): if "cookies" not in message.content and "get" not in message.content: - raise ValueError("No cookies or get sent to consumer; this decorator can only be used on messages containing at least one.") + raise ValueError("No cookies or get sent to consumer - cannot initialise http_session") + # Make sure there's NOT a http_session already + if hasattr(message, "http_session"): + raise ValueError("http_session decorator wrapped inside another http_session decorator") # Make sure there's a session key session_key = None if "get" in message.content: @@ -43,7 +107,7 @@ def http_session(func): session = session_engine.SessionStore(session_key=session_key) else: session = None - message.session = session + message.http_session = session # Run the consumer result = func(message, *args, **kwargs) # Persist session if needed (won't be saved if error happens) @@ -51,65 +115,3 @@ def http_session(func): session.save() return result return inner - - -def http_django_auth(func): - """ - Wraps a HTTP or WebSocket consumer (or any consumer of messages - that provides a "COOKIES" attribute) to provide both a "session" - attribute and a "user" attibute, like AuthMiddleware does. - - This runs http_session() to get a session to hook auth off of. - If the user does not have a session cookie set, both "session" - and "user" will be None. - """ - @http_session - @functools.wraps(func) - def inner(message, *args, **kwargs): - # If we didn't get a session, then we don't get a user - if not hasattr(message, "session"): - raise ValueError("Did not see a session to get auth from") - if message.session is None: - message.user = None - # Otherwise, be a bit naughty and make a fake Request with just - # a "session" attribute (later on, perhaps refactor contrib.auth to - # pass around session rather than request) - else: - fake_request = type("FakeRequest", (object, ), {"session": message.session}) - message.user = auth.get_user(fake_request) - # Run the consumer - return func(message, *args, **kwargs) - return inner - - -def channel_session(func): - """ - Provides a session-like object called "channel_session" to consumers - as a message attribute that will auto-persist across consumers with - the same incoming "reply_channel" value. - """ - @functools.wraps(func) - def inner(message, *args, **kwargs): - # Make sure there's a reply_channel in kwargs - if not message.reply_channel: - raise ValueError("No reply_channel sent to consumer; this decorator can only be used on messages containing it.") - # Turn the reply_channel into a valid session key length thing. - # We take the last 24 bytes verbatim, as these are the random section, - # and then hash the remaining ones onto the start, and add a prefix - # TODO: See if there's a better way of doing this - reply_name = message.reply_channel.name - session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:] - # Make a session storage - session_engine = import_module(settings.SESSION_ENGINE) - session = session_engine.SessionStore(session_key=session_key) - # If the session does not already exist, save to force our session key to be valid - if not session.exists(session.session_key): - session.save() - message.channel_session = session - # Run the consumer - result = func(message, *args, **kwargs) - # Persist session if needed (won't be saved if error happens) - if session.modified: - session.save() - return result - return inner diff --git a/channels/interfaces/websocket_twisted.py b/channels/interfaces/websocket_twisted.py index 478d03b..ff5e947 100644 --- a/channels/interfaces/websocket_twisted.py +++ b/channels/interfaces/websocket_twisted.py @@ -18,7 +18,7 @@ class InterfaceProtocol(WebSocketServerProtocol): self.channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND] self.request_info = { "path": request.path, - "GET": request.params, + "get": request.params, } def onOpen(self): diff --git a/channels/message.py b/channels/message.py index bc7eda5..c9a80bd 100644 --- a/channels/message.py +++ b/channels/message.py @@ -10,6 +10,13 @@ class Message(object): to use to reply to this message's end user, if that makes sense. """ + class Requeue(Exception): + """ + Raise this while processing a message to requeue it back onto the + channel. Useful if you're manually ensuring partial ordering, etc. + """ + pass + def __init__(self, content, channel, channel_backend, reply_channel=None): self.content = content self.channel = channel diff --git a/channels/worker.py b/channels/worker.py index db0a896..50de9bd 100644 --- a/channels/worker.py +++ b/channels/worker.py @@ -1,5 +1,6 @@ import traceback from .message import Message +from .utils import name_that_thing class Worker(object): @@ -31,5 +32,8 @@ class Worker(object): self.callback(channel, message) try: consumer(message) + except Message.Requeue: + self.channel_backend.send(channel, content) except: + print "Error processing message with consumer %s:" % name_that_thing(consumer) traceback.print_exc()