mirror of
https://github.com/django/daphne.git
synced 2025-04-21 17:22:03 +03:00
Fixed #6: Linearise decorator and better user session stuff
This commit is contained in:
parent
eed6e5e607
commit
638bf260f8
66
channels/auth.py
Normal file
66
channels/auth.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import functools
|
||||
|
||||
from django.contrib import auth
|
||||
from .decorators import channel_session, http_session
|
||||
|
||||
|
||||
def transfer_user(from_session, to_session):
|
||||
"""
|
||||
Transfers user from HTTP session to channel session.
|
||||
"""
|
||||
to_session[auth.BACKEND_SESSION_KEY] = from_session[auth.BACKEND_SESSION_KEY]
|
||||
to_session[auth.SESSION_KEY] = from_session[auth.SESSION_KEY]
|
||||
to_session[auth.HASH_SESSION_KEY] = from_session[auth.HASH_SESSION_KEY]
|
||||
|
||||
|
||||
def channel_session_user(func):
|
||||
"""
|
||||
Presents a message.user attribute obtained from a user ID in the channel
|
||||
session, rather than in the http_session. Turns on channel session implicitly.
|
||||
"""
|
||||
@channel_session
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
# If we didn't get a session, then we don't get a user
|
||||
if not hasattr(message, "channel_session"):
|
||||
raise ValueError("Did not see a channel session to get auth from")
|
||||
if message.channel_session is None:
|
||||
message.user = None
|
||||
# Otherwise, be a bit naughty and make a fake Request with just
|
||||
# a "session" attribute (later on, perhaps refactor contrib.auth to
|
||||
# pass around session rather than request)
|
||||
else:
|
||||
fake_request = type("FakeRequest", (object, ), {"session": message.channel_session})
|
||||
message.user = auth.get_user(fake_request)
|
||||
# Run the consumer
|
||||
return func(message, *args, **kwargs)
|
||||
return inner
|
||||
|
||||
|
||||
def http_session_user(func):
|
||||
"""
|
||||
Wraps a HTTP or WebSocket consumer (or any consumer of messages
|
||||
that provides a "COOKIES" attribute) to provide both a "session"
|
||||
attribute and a "user" attibute, like AuthMiddleware does.
|
||||
|
||||
This runs http_session() to get a session to hook auth off of.
|
||||
If the user does not have a session cookie set, both "session"
|
||||
and "user" will be None.
|
||||
"""
|
||||
@http_session
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
# If we didn't get a session, then we don't get a user
|
||||
if not hasattr(message, "http_session"):
|
||||
raise ValueError("Did not see a http session to get auth from")
|
||||
if message.http_session is None:
|
||||
message.user = None
|
||||
# Otherwise, be a bit naughty and make a fake Request with just
|
||||
# a "session" attribute (later on, perhaps refactor contrib.auth to
|
||||
# pass around session rather than request)
|
||||
else:
|
||||
fake_request = type("FakeRequest", (object, ), {"session": message.http_session})
|
||||
message.user = auth.get_user(fake_request)
|
||||
# Run the consumer
|
||||
return func(message, *args, **kwargs)
|
||||
return inner
|
|
@ -93,3 +93,16 @@ class BaseChannelBackend(object):
|
|||
|
||||
def __str__(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def lock_channel(self, channel):
|
||||
"""
|
||||
Attempts to get a lock on the named channel. Returns True if lock
|
||||
obtained, False if lock not obtained.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def unlock_channel(self, channel):
|
||||
"""
|
||||
Unlocks the named channel. Always succeeds.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
import datetime
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import models, connections, DEFAULT_DB_ALIAS
|
||||
from django.db import models, connections, DEFAULT_DB_ALIAS, IntegrityError
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.timezone import now
|
||||
|
||||
|
@ -71,6 +71,26 @@ class DatabaseChannelBackend(BaseChannelBackend):
|
|||
editor.create_model(Group)
|
||||
return Group
|
||||
|
||||
@cached_property
|
||||
def lock_model(self):
|
||||
"""
|
||||
Initialises a new model to store groups; not done as part of a
|
||||
models.py as we don't want to make it for most installs.
|
||||
"""
|
||||
# Make the model class
|
||||
class Lock(models.Model):
|
||||
channel = models.CharField(max_length=200, unique=True)
|
||||
expiry = models.DateTimeField(db_index=True)
|
||||
class Meta:
|
||||
apps = Apps()
|
||||
app_label = "channels"
|
||||
db_table = "django_channel_locks"
|
||||
# Ensure its table exists
|
||||
if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
|
||||
with self.connection.schema_editor() as editor:
|
||||
editor.create_model(Lock)
|
||||
return Lock
|
||||
|
||||
def send(self, channel, message):
|
||||
self.channel_model.objects.create(
|
||||
channel = channel,
|
||||
|
@ -97,6 +117,7 @@ class DatabaseChannelBackend(BaseChannelBackend):
|
|||
# Include a 10-second grace period because that solves some clock sync
|
||||
self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
||||
self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
||||
self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
||||
|
||||
def group_add(self, group, channel, expiry=None):
|
||||
"""
|
||||
|
@ -123,5 +144,27 @@ class DatabaseChannelBackend(BaseChannelBackend):
|
|||
self._clean_expired()
|
||||
return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True))
|
||||
|
||||
def lock_channel(self, channel, expiry=None):
|
||||
"""
|
||||
Attempts to get a lock on the named channel. Returns True if lock
|
||||
obtained, False if lock not obtained.
|
||||
"""
|
||||
# We rely on the UNIQUE constraint for only-one-thread-wins on locks
|
||||
try:
|
||||
self.lock_model.objects.create(
|
||||
channel = channel,
|
||||
expiry = now() + datetime.timedelta(seconds=expiry or self.expiry),
|
||||
)
|
||||
except IntegrityError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def unlock_channel(self, channel):
|
||||
"""
|
||||
Unlocks the named channel. Always succeeds.
|
||||
"""
|
||||
self.lock_model.objects.filter(channel=channel).delete()
|
||||
|
||||
def __str__(self):
|
||||
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
|
||||
|
|
|
@ -5,6 +5,7 @@ from .base import BaseChannelBackend
|
|||
|
||||
queues = {}
|
||||
groups = {}
|
||||
locks = set()
|
||||
|
||||
class InMemoryChannelBackend(BaseChannelBackend):
|
||||
"""
|
||||
|
@ -72,3 +73,22 @@ class InMemoryChannelBackend(BaseChannelBackend):
|
|||
"""
|
||||
self._clean_expired()
|
||||
return groups.get(group, {}).keys()
|
||||
|
||||
def lock_channel(self, channel):
|
||||
"""
|
||||
Attempts to get a lock on the named channel. Returns True if lock
|
||||
obtained, False if lock not obtained.
|
||||
"""
|
||||
# Probably not perfect for race conditions, but close enough considering
|
||||
# it shouldn't be used.
|
||||
if channel not in locks:
|
||||
locks.add(channel)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def unlock_channel(self, channel):
|
||||
"""
|
||||
Unlocks the named channel. Always succeeds.
|
||||
"""
|
||||
locks.discard(channel)
|
||||
|
|
|
@ -101,5 +101,20 @@ class RedisChannelBackend(BaseChannelBackend):
|
|||
|
||||
# TODO: send_group efficient implementation using Lua
|
||||
|
||||
def lock_channel(self, channel, expiry=None):
|
||||
"""
|
||||
Attempts to get a lock on the named channel. Returns True if lock
|
||||
obtained, False if lock not obtained.
|
||||
"""
|
||||
key = "%s:lock:%s" % (self.prefix, channel)
|
||||
return bool(self.connection.setnx(key, "1"))
|
||||
|
||||
def unlock_channel(self, channel):
|
||||
"""
|
||||
Unlocks the named channel. Always succeeds.
|
||||
"""
|
||||
key = "%s:lock:%s" % (self.prefix, channel)
|
||||
self.connection.delete(key)
|
||||
|
||||
def __str__(self):
|
||||
return "%s(host=%s, port=%s)" % (self.__class__.__name__, self.host, self.port)
|
||||
|
|
|
@ -3,16 +3,77 @@ import hashlib
|
|||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils import six
|
||||
from django.contrib import auth
|
||||
|
||||
from channels import channel_backends, DEFAULT_CHANNEL_BACKEND
|
||||
|
||||
def linearize(func):
|
||||
"""
|
||||
Makes sure the contained consumer does not run at the same time other
|
||||
consumers are running on messages with the same reply_channel.
|
||||
|
||||
Required if you don't want weird things like a second consumer starting
|
||||
up before the first has exited and saved its session. Doesn't guarantee
|
||||
ordering, just linearity.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
# Make sure there's a reply channel
|
||||
if not message.reply_channel:
|
||||
raise ValueError("No reply_channel sent to consumer; @no_overlap can only be used on messages containing it.")
|
||||
# Get the lock, or re-queue
|
||||
locked = message.channel_backend.lock_channel(message.reply_channel)
|
||||
if not locked:
|
||||
raise message.Requeue()
|
||||
# OK, keep going
|
||||
try:
|
||||
return func(message, *args, **kwargs)
|
||||
finally:
|
||||
message.channel_backend.unlock_channel(message.reply_channel)
|
||||
return inner
|
||||
|
||||
|
||||
def channel_session(func):
|
||||
"""
|
||||
Provides a session-like object called "channel_session" to consumers
|
||||
as a message attribute that will auto-persist across consumers with
|
||||
the same incoming "reply_channel" value.
|
||||
|
||||
Use this to persist data across the lifetime of a connection.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
# Make sure there's a reply_channel
|
||||
if not message.reply_channel:
|
||||
raise ValueError("No reply_channel sent to consumer; @channel_session can only be used on messages containing it.")
|
||||
# Make sure there's NOT a channel_session already
|
||||
if hasattr(message, "channel_session"):
|
||||
raise ValueError("channel_session decorator wrapped inside another channel_session decorator")
|
||||
# Turn the reply_channel into a valid session key length thing.
|
||||
# We take the last 24 bytes verbatim, as these are the random section,
|
||||
# and then hash the remaining ones onto the start, and add a prefix
|
||||
reply_name = message.reply_channel.name
|
||||
session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:]
|
||||
# Make a session storage
|
||||
session_engine = import_module(settings.SESSION_ENGINE)
|
||||
session = session_engine.SessionStore(session_key=session_key)
|
||||
# If the session does not already exist, save to force our
|
||||
# session key to be valid.
|
||||
if not session.exists(session.session_key):
|
||||
session.save(must_create=True)
|
||||
message.channel_session = session
|
||||
# Run the consumer
|
||||
try:
|
||||
return func(message, *args, **kwargs)
|
||||
finally:
|
||||
# Persist session if needed
|
||||
if session.modified:
|
||||
session.save()
|
||||
return inner
|
||||
|
||||
|
||||
def http_session(func):
|
||||
"""
|
||||
Wraps a HTTP or WebSocket consumer (or any consumer of messages
|
||||
that provides a "COOKIES" or "GET" attribute) to provide a "session"
|
||||
Wraps a HTTP or WebSocket connect consumer (or any consumer of messages
|
||||
that provides a "cooikies" or "get" attribute) to provide a "http_session"
|
||||
attribute that behaves like request.session; that is, it's hung off of
|
||||
a per-user session key that is saved in a cookie or passed as the
|
||||
"session_key" GET parameter.
|
||||
|
@ -21,13 +82,16 @@ def http_session(func):
|
|||
don't have one - that's what SessionMiddleware is for, this is a simpler
|
||||
read-only version for more low-level code.
|
||||
|
||||
If a user does not have a session we can inflate, the "session" attribute will
|
||||
be None, rather than an empty session you can write to.
|
||||
If a message does not have a session we can inflate, the "session" attribute
|
||||
will be None, rather than an empty session you can write to.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
if "cookies" not in message.content and "get" not in message.content:
|
||||
raise ValueError("No cookies or get sent to consumer; this decorator can only be used on messages containing at least one.")
|
||||
raise ValueError("No cookies or get sent to consumer - cannot initialise http_session")
|
||||
# Make sure there's NOT a http_session already
|
||||
if hasattr(message, "http_session"):
|
||||
raise ValueError("http_session decorator wrapped inside another http_session decorator")
|
||||
# Make sure there's a session key
|
||||
session_key = None
|
||||
if "get" in message.content:
|
||||
|
@ -43,7 +107,7 @@ def http_session(func):
|
|||
session = session_engine.SessionStore(session_key=session_key)
|
||||
else:
|
||||
session = None
|
||||
message.session = session
|
||||
message.http_session = session
|
||||
# Run the consumer
|
||||
result = func(message, *args, **kwargs)
|
||||
# Persist session if needed (won't be saved if error happens)
|
||||
|
@ -51,65 +115,3 @@ def http_session(func):
|
|||
session.save()
|
||||
return result
|
||||
return inner
|
||||
|
||||
|
||||
def http_django_auth(func):
|
||||
"""
|
||||
Wraps a HTTP or WebSocket consumer (or any consumer of messages
|
||||
that provides a "COOKIES" attribute) to provide both a "session"
|
||||
attribute and a "user" attibute, like AuthMiddleware does.
|
||||
|
||||
This runs http_session() to get a session to hook auth off of.
|
||||
If the user does not have a session cookie set, both "session"
|
||||
and "user" will be None.
|
||||
"""
|
||||
@http_session
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
# If we didn't get a session, then we don't get a user
|
||||
if not hasattr(message, "session"):
|
||||
raise ValueError("Did not see a session to get auth from")
|
||||
if message.session is None:
|
||||
message.user = None
|
||||
# Otherwise, be a bit naughty and make a fake Request with just
|
||||
# a "session" attribute (later on, perhaps refactor contrib.auth to
|
||||
# pass around session rather than request)
|
||||
else:
|
||||
fake_request = type("FakeRequest", (object, ), {"session": message.session})
|
||||
message.user = auth.get_user(fake_request)
|
||||
# Run the consumer
|
||||
return func(message, *args, **kwargs)
|
||||
return inner
|
||||
|
||||
|
||||
def channel_session(func):
|
||||
"""
|
||||
Provides a session-like object called "channel_session" to consumers
|
||||
as a message attribute that will auto-persist across consumers with
|
||||
the same incoming "reply_channel" value.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def inner(message, *args, **kwargs):
|
||||
# Make sure there's a reply_channel in kwargs
|
||||
if not message.reply_channel:
|
||||
raise ValueError("No reply_channel sent to consumer; this decorator can only be used on messages containing it.")
|
||||
# Turn the reply_channel into a valid session key length thing.
|
||||
# We take the last 24 bytes verbatim, as these are the random section,
|
||||
# and then hash the remaining ones onto the start, and add a prefix
|
||||
# TODO: See if there's a better way of doing this
|
||||
reply_name = message.reply_channel.name
|
||||
session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:]
|
||||
# Make a session storage
|
||||
session_engine = import_module(settings.SESSION_ENGINE)
|
||||
session = session_engine.SessionStore(session_key=session_key)
|
||||
# If the session does not already exist, save to force our session key to be valid
|
||||
if not session.exists(session.session_key):
|
||||
session.save()
|
||||
message.channel_session = session
|
||||
# Run the consumer
|
||||
result = func(message, *args, **kwargs)
|
||||
# Persist session if needed (won't be saved if error happens)
|
||||
if session.modified:
|
||||
session.save()
|
||||
return result
|
||||
return inner
|
||||
|
|
|
@ -18,7 +18,7 @@ class InterfaceProtocol(WebSocketServerProtocol):
|
|||
self.channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND]
|
||||
self.request_info = {
|
||||
"path": request.path,
|
||||
"GET": request.params,
|
||||
"get": request.params,
|
||||
}
|
||||
|
||||
def onOpen(self):
|
||||
|
|
|
@ -10,6 +10,13 @@ class Message(object):
|
|||
to use to reply to this message's end user, if that makes sense.
|
||||
"""
|
||||
|
||||
class Requeue(Exception):
|
||||
"""
|
||||
Raise this while processing a message to requeue it back onto the
|
||||
channel. Useful if you're manually ensuring partial ordering, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init__(self, content, channel, channel_backend, reply_channel=None):
|
||||
self.content = content
|
||||
self.channel = channel
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import traceback
|
||||
from .message import Message
|
||||
from .utils import name_that_thing
|
||||
|
||||
|
||||
class Worker(object):
|
||||
|
@ -31,5 +32,8 @@ class Worker(object):
|
|||
self.callback(channel, message)
|
||||
try:
|
||||
consumer(message)
|
||||
except Message.Requeue:
|
||||
self.channel_backend.send(channel, content)
|
||||
except:
|
||||
print "Error processing message with consumer %s:" % name_that_thing(consumer)
|
||||
traceback.print_exc()
|
||||
|
|
Loading…
Reference in New Issue
Block a user