Support ASGI3

This commit is contained in:
Tom Christie 2019-03-20 10:01:42 +00:00
parent 67cfa98b00
commit 55f7e6713d

View File

@ -1,4 +1,7 @@
import asyncio
import argparse
import functools
import inspect
import logging
import sys
from argparse import ArgumentError, Namespace
@ -14,6 +17,17 @@ DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000
class ASGI3Middleware:
def __init__(self, app):
self.app = app
def __call__(self, scope):
return functools.partial(self.asgi, scope=scope)
async def asgi(self, receive, send, scope):
await self.app(scope, receive, send)
class CommandLineInterface(object):
"""
Acts as the main CLI entry point for running the server.
@ -113,6 +127,13 @@ class CommandLineInterface(object):
help="The WebSocket protocols you wish to support",
default=None,
)
self.parser.add_argument(
"--asgi-protocol",
dest="asgi_protocol",
help="The version of the ASGI protocol to use",
default="auto",
choices=["asgi2", "asgi3", "auto"]
)
self.parser.add_argument(
"--root-path",
dest="root_path",
@ -198,6 +219,23 @@ class CommandLineInterface(object):
if args.proxy_headers:
return "X-Forwarded-Port"
def _guess_asgi_protocol(self, application):
if getattr(application, "_asgi_single_callable", False):
return "asgi3"
if getattr(application, "_asgi_double_callable", False):
return "asgi2"
# Uninstanted classes are double-callable
if inspect.isclass(application):
return "asgi2"
# Instanted classes depend on their __call__
if hasattr(application, "__call__"):
# We only check to see if its __call__ is a coroutine function -
# if it's not, it still might be a coroutine function itself.
if asyncio.iscoroutinefunction(application.__call__):
return "asgi3"
# Non-classes we just check directly
return "asgi3" if asyncio.iscoroutinefunction(application) else "asgi2"
def run(self, args):
"""
Pass in raw argument list and it will decode them
@ -227,6 +265,14 @@ class CommandLineInterface(object):
# Import application
sys.path.insert(0, ".")
application = import_by_path(args.application)
asgi_protocol = args.asgi_protocol
if asgi_protocol == "auto":
asgi_protocol = self._guess_asgi_protocol(application)
if asgi_protocol == "asgi3":
application = ASGI3Middleware(application)
# Set up port/host bindings
if not any(
[