diff --git a/daphne/cli.py b/daphne/cli.py index 60daed7..68a8cb8 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -1,4 +1,7 @@ +import asyncio import argparse +import functools +import inspect import logging import sys from argparse import ArgumentError, Namespace @@ -14,6 +17,17 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 +class ASGI3Middleware: + def __init__(self, app): + self.app = app + + def __call__(self, scope): + return functools.partial(self.asgi, scope=scope) + + async def asgi(self, receive, send, scope): + await self.app(scope, receive, send) + + class CommandLineInterface(object): """ Acts as the main CLI entry point for running the server. @@ -113,6 +127,13 @@ class CommandLineInterface(object): help="The WebSocket protocols you wish to support", default=None, ) + self.parser.add_argument( + "--asgi-protocol", + dest="asgi_protocol", + help="The version of the ASGI protocol to use", + default="auto", + choices=["asgi2", "asgi3", "auto"] + ) self.parser.add_argument( "--root-path", dest="root_path", @@ -198,6 +219,23 @@ class CommandLineInterface(object): if args.proxy_headers: return "X-Forwarded-Port" + def _guess_asgi_protocol(self, application): + if getattr(application, "_asgi_single_callable", False): + return "asgi3" + if getattr(application, "_asgi_double_callable", False): + return "asgi2" + # Uninstanted classes are double-callable + if inspect.isclass(application): + return "asgi2" + # Instanted classes depend on their __call__ + if hasattr(application, "__call__"): + # We only check to see if its __call__ is a coroutine function - + # if it's not, it still might be a coroutine function itself. + if asyncio.iscoroutinefunction(application.__call__): + return "asgi3" + # Non-classes we just check directly + return "asgi3" if asyncio.iscoroutinefunction(application) else "asgi2" + def run(self, args): """ Pass in raw argument list and it will decode them @@ -227,6 +265,14 @@ class CommandLineInterface(object): # Import application sys.path.insert(0, ".") application = import_by_path(args.application) + + asgi_protocol = args.asgi_protocol + if asgi_protocol == "auto": + asgi_protocol = self._guess_asgi_protocol(application) + + if asgi_protocol == "asgi3": + application = ASGI3Middleware(application) + # Set up port/host bindings if not any( [