diff --git a/channels/auth.py b/channels/auth.py index d0c786a..1ae7e1d 100644 --- a/channels/auth.py +++ b/channels/auth.py @@ -1,8 +1,9 @@ import functools from django.contrib import auth +from django.contrib.auth.models import AnonymousUser -from .decorators import channel_session, http_session +from .sessions import channel_session, http_session def transfer_user(from_session, to_session): @@ -26,7 +27,7 @@ def channel_session_user(func): if not hasattr(message, "channel_session"): raise ValueError("Did not see a channel session to get auth from") if message.channel_session is None: - message.user = None + message.user = AnonymousUser() # Otherwise, be a bit naughty and make a fake Request with just # a "session" attribute (later on, perhaps refactor contrib.auth to # pass around session rather than request) @@ -55,7 +56,7 @@ def http_session_user(func): if not hasattr(message, "http_session"): raise ValueError("Did not see a http session to get auth from") if message.http_session is None: - message.user = None + message.user = AnonymousUser() # Otherwise, be a bit naughty and make a fake Request with just # a "session" attribute (later on, perhaps refactor contrib.auth to # pass around session rather than request) @@ -65,3 +66,18 @@ def http_session_user(func): # Run the consumer return func(message, *args, **kwargs) return inner + + +def channel_session_user_from_http(func): + """ + Decorator that automatically transfers the user from HTTP sessions to + channel-based sessions, and returns the user as message.user as well. + Useful for things that consume e.g. websocket.connect + """ + @http_session_user + @channel_session + def inner(message, *args, **kwargs): + if message.http_session is not None: + transfer_user(message.http_session, message.channel_session) + return func(message, *args, **kwargs) + return inner diff --git a/channels/decorators.py b/channels/decorators.py index 063447a..279a7bd 100644 --- a/channels/decorators.py +++ b/channels/decorators.py @@ -1,8 +1,4 @@ import functools -import hashlib -from importlib import import_module - -from django.conf import settings def linearize(func): @@ -32,95 +28,3 @@ def linearize(func): # TODO: Release lock here pass return inner - - -def channel_session(func): - """ - Provides a session-like object called "channel_session" to consumers - as a message attribute that will auto-persist across consumers with - the same incoming "reply_channel" value. - - Use this to persist data across the lifetime of a connection. - """ - @functools.wraps(func) - def inner(message, *args, **kwargs): - # Make sure there's a reply_channel - if not message.reply_channel: - raise ValueError( - "No reply_channel sent to consumer; @channel_session " + - "can only be used on messages containing it." - ) - - # Make sure there's NOT a channel_session already - if hasattr(message, "channel_session"): - raise ValueError("channel_session decorator wrapped inside another channel_session decorator") - - # Turn the reply_channel into a valid session key length thing. - # We take the last 24 bytes verbatim, as these are the random section, - # and then hash the remaining ones onto the start, and add a prefix - reply_name = message.reply_channel.name - hashed = hashlib.md5(reply_name[:-24].encode()).hexdigest()[:8] - session_key = "skt" + hashed + reply_name[-24:] - # Make a session storage - session_engine = import_module(settings.SESSION_ENGINE) - session = session_engine.SessionStore(session_key=session_key) - # If the session does not already exist, save to force our - # session key to be valid. - if not session.exists(session.session_key): - session.save(must_create=True) - message.channel_session = session - # Run the consumer - try: - return func(message, *args, **kwargs) - finally: - # Persist session if needed - if session.modified: - session.save() - return inner - - -def http_session(func): - """ - Wraps a HTTP or WebSocket connect consumer (or any consumer of messages - that provides a "cookies" or "get" attribute) to provide a "http_session" - attribute that behaves like request.session; that is, it's hung off of - a per-user session key that is saved in a cookie or passed as the - "session_key" GET parameter. - - It won't automatically create and set a session cookie for users who - don't have one - that's what SessionMiddleware is for, this is a simpler - read-only version for more low-level code. - - If a message does not have a session we can inflate, the "session" attribute - will be None, rather than an empty session you can write to. - """ - @functools.wraps(func) - def inner(message, *args, **kwargs): - if "cookies" not in message.content and "get" not in message.content: - raise ValueError("No cookies or get sent to consumer - cannot initialise http_session") - # Make sure there's NOT a http_session already - if hasattr(message, "http_session"): - raise ValueError("http_session decorator wrapped inside another http_session decorator") - # Make sure there's a session key - session_key = None - if "get" in message.content: - try: - session_key = message.content['get'].get("session_key", [])[0] - except IndexError: - pass - if "cookies" in message.content and session_key is None: - session_key = message.content['cookies'].get(settings.SESSION_COOKIE_NAME) - # Make a session storage - if session_key: - session_engine = import_module(settings.SESSION_ENGINE) - session = session_engine.SessionStore(session_key=session_key) - else: - session = None - message.http_session = session - # Run the consumer - result = func(message, *args, **kwargs) - # Persist session if needed (won't be saved if error happens) - if session is not None and session.modified: - session.save() - return result - return inner diff --git a/channels/message.py b/channels/message.py index 53a36cc..5830fdd 100644 --- a/channels/message.py +++ b/channels/message.py @@ -12,9 +12,12 @@ class Message(object): to use to reply to this message's end user, if that makes sense. """ - def __init__(self, content, channel, channel_layer): + def __init__(self, content, channel_name, channel_layer): self.content = content - self.channel = channel + self.channel = Channel( + channel_name, + channel_layer=channel_layer, + ) self.channel_layer = channel_layer if content.get("reply_channel", None): self.reply_channel = Channel( diff --git a/channels/sessions.py b/channels/sessions.py new file mode 100644 index 0000000..598b79e --- /dev/null +++ b/channels/sessions.py @@ -0,0 +1,101 @@ +import functools +import hashlib +from importlib import import_module + +from django.conf import settings + +from .handler import AsgiRequest + + +def channel_session(func): + """ + Provides a session-like object called "channel_session" to consumers + as a message attribute that will auto-persist across consumers with + the same incoming "reply_channel" value. + + Use this to persist data across the lifetime of a connection. + """ + @functools.wraps(func) + def inner(message, *args, **kwargs): + # Make sure there's a reply_channel + if not message.reply_channel: + raise ValueError( + "No reply_channel sent to consumer; @channel_session " + + "can only be used on messages containing it." + ) + + # Make sure there's NOT a channel_session already + if hasattr(message, "channel_session"): + raise ValueError("channel_session decorator wrapped inside another channel_session decorator") + + # Turn the reply_channel into a valid session key length thing. + # We take the last 24 bytes verbatim, as these are the random section, + # and then hash the remaining ones onto the start, and add a prefix + reply_name = message.reply_channel.name + hashed = hashlib.md5(reply_name[:-24].encode()).hexdigest()[:8] + session_key = "skt" + hashed + reply_name[-24:] + # Make a session storage + session_engine = import_module(settings.SESSION_ENGINE) + session = session_engine.SessionStore(session_key=session_key) + # If the session does not already exist, save to force our + # session key to be valid. + if not session.exists(session.session_key): + session.save(must_create=True) + message.channel_session = session + # Run the consumer + try: + return func(message, *args, **kwargs) + finally: + # Persist session if needed + if session.modified: + session.save() + return inner + + +def http_session(func): + """ + Wraps a HTTP or WebSocket connect consumer (or any consumer of messages + that provides a "cookies" or "get" attribute) to provide a "http_session" + attribute that behaves like request.session; that is, it's hung off of + a per-user session key that is saved in a cookie or passed as the + "session_key" GET parameter. + + It won't automatically create and set a session cookie for users who + don't have one - that's what SessionMiddleware is for, this is a simpler + read-only version for more low-level code. + + If a message does not have a session we can inflate, the "session" attribute + will be None, rather than an empty session you can write to. + """ + @functools.wraps(func) + def inner(message, *args, **kwargs): + try: + # We want to parse the WebSocket (or similar HTTP-lite) message + # to get cookies and GET, but we need to add in a few things that + # might not have been there. + if "method" not in message.content: + message.content['method'] = "FAKE" + request = AsgiRequest(message) + except Exception as e: + raise ValueError("Cannot parse HTTP message - are you sure this is a HTTP consumer? %s" % e) + # Make sure there's NOT a http_session already + if hasattr(message, "http_session"): + raise ValueError("http_session decorator wrapped inside another http_session decorator") + # Make sure there's a session key + session_key = request.GET.get("session_key", None) + if session_key is None: + session_key = request.COOKIES.get(settings.SESSION_COOKIE_NAME, None) + # Make a session storage + if session_key: + session_engine = import_module(settings.SESSION_ENGINE) + session = session_engine.SessionStore(session_key=session_key) + else: + session = None + message.http_session = session + # Run the consumer + result = func(message, *args, **kwargs) + # Persist session if needed (won't be saved if error happens) + if session is not None and session.modified: + session.save() + return result + return inner diff --git a/channels/worker.py b/channels/worker.py index 8e61d80..ba0ae35 100644 --- a/channels/worker.py +++ b/channels/worker.py @@ -34,7 +34,7 @@ class Worker(object): # Create message wrapper message = Message( content=content, - channel=channel, + channel_name=channel, channel_layer=self.channel_layer, ) # Handle the message