Updated to use ASGI v3 applications internally. (#275)

Used guarantee_single_callable().
Removed unneeded --asgi-protocol CLI option.
Updated tests.

Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
This commit is contained in:
Avinash Raj 2020-10-28 00:20:50 +05:30 committed by GitHub
parent e1b77e930b
commit 15ba5c6776
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 26 additions and 44 deletions

View File

@ -1,10 +1,9 @@
import argparse import argparse
import functools
import logging import logging
import sys import sys
from argparse import ArgumentError, Namespace from argparse import ArgumentError, Namespace
from asgiref.compatibility import is_double_callable from asgiref.compatibility import guarantee_single_callable
from .access import AccessLogGenerator from .access import AccessLogGenerator
from .endpoints import build_endpoint_description_strings from .endpoints import build_endpoint_description_strings
@ -17,19 +16,6 @@ DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000 DEFAULT_PORT = 8000
class ASGI3Middleware:
def __init__(self, app):
self.app = app
def __call__(self, scope):
scope.setdefault("asgi", {})
scope["asgi"]["version"] = "3.0"
return functools.partial(self.asgi, scope=scope)
async def asgi(self, receive, send, scope):
await self.app(scope, receive, send)
class CommandLineInterface(object): class CommandLineInterface(object):
""" """
Acts as the main CLI entry point for running the server. Acts as the main CLI entry point for running the server.
@ -129,13 +115,6 @@ class CommandLineInterface(object):
help="The WebSocket protocols you wish to support", help="The WebSocket protocols you wish to support",
default=None, default=None,
) )
self.parser.add_argument(
"--asgi-protocol",
dest="asgi_protocol",
help="The version of the ASGI protocol to use",
default="auto",
choices=["asgi2", "asgi3", "auto"],
)
self.parser.add_argument( self.parser.add_argument(
"--root-path", "--root-path",
dest="root_path", dest="root_path",
@ -247,16 +226,11 @@ class CommandLineInterface(object):
access_log_stream = open(args.access_log, "a", 1) access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1: elif args.verbosity >= 1:
access_log_stream = sys.stdout access_log_stream = sys.stdout
# Import application # Import application
sys.path.insert(0, ".") sys.path.insert(0, ".")
application = import_by_path(args.application) application = import_by_path(args.application)
application = guarantee_single_callable(application)
asgi_protocol = args.asgi_protocol
if asgi_protocol == "auto":
asgi_protocol = "asgi2" if is_double_callable(application) else "asgi3"
if asgi_protocol == "asgi3":
application = ASGI3Middleware(application)
# Set up port/host bindings # Set up port/host bindings
if not any( if not any(

View File

@ -199,15 +199,17 @@ class Server(object):
assert "application_instance" not in self.connections[protocol] assert "application_instance" not in self.connections[protocol]
# Make an instance of the application # Make an instance of the application
input_queue = asyncio.Queue() input_queue = asyncio.Queue()
application_instance = self.application(scope=scope) scope.setdefault("asgi", {"version": "3.0"})
application_instance = self.application(
scope=scope,
receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message),
)
# Run it, and stash the future for later checking # Run it, and stash the future for later checking
if protocol not in self.connections: if protocol not in self.connections:
return None return None
self.connections[protocol]["application_instance"] = asyncio.ensure_future( self.connections[protocol]["application_instance"] = asyncio.ensure_future(
application_instance( application_instance,
receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message),
),
loop=asyncio.get_event_loop(), loop=asyncio.get_event_loop(),
) )
return input_queue return input_queue

View File

@ -5,7 +5,6 @@ import pickle
import tempfile import tempfile
import traceback import traceback
from concurrent.futures import CancelledError from concurrent.futures import CancelledError
from functools import partial
class DaphneTestingInstance: class DaphneTestingInstance:
@ -43,7 +42,7 @@ class DaphneTestingInstance:
# Start up process # Start up process
self.process = DaphneProcess( self.process = DaphneProcess(
host=self.host, host=self.host,
application=partial(TestApplication, lock=self.lock), application=TestApplication(lock=self.lock),
kwargs=kwargs, kwargs=kwargs,
setup=self.process_setup, setup=self.process_setup,
teardown=self.process_teardown, teardown=self.process_teardown,
@ -173,12 +172,12 @@ class TestApplication:
setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio") setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
result_storage = os.path.join(tempfile.gettempdir(), "result.testio") result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
def __init__(self, scope, lock): def __init__(self, lock):
self.scope = scope
self.lock = lock self.lock = lock
self.messages = [] self.messages = []
async def __call__(self, send, receive): async def __call__(self, scope, receive, send):
self.scope = scope
# Receive input and send output # Receive input and send output
logging.debug("test app coroutine alive") logging.debug("test app coroutine alive")
try: try:

View File

@ -24,6 +24,7 @@ class TestHTTPRequest(DaphneTestCase):
# Check overall keys # Check overall keys
self.assert_key_sets( self.assert_key_sets(
required_keys={ required_keys={
"asgi",
"type", "type",
"http_version", "http_version",
"method", "method",
@ -35,6 +36,7 @@ class TestHTTPRequest(DaphneTestCase):
optional_keys={"scheme", "root_path", "client", "server"}, optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(), actual_keys=scope.keys(),
) )
self.assertEqual(scope["asgi"]["version"], "3.0")
# Check that it is the right type # Check that it is the right type
self.assertEqual(scope["type"], "http") self.assertEqual(scope["type"], "http")
# Method (uppercased unicode string) # Method (uppercased unicode string)
@ -120,10 +122,7 @@ class TestHTTPRequest(DaphneTestCase):
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params) self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@given( @given(request_path=http_strategies.http_path(), chunk_size=integers(min_value=1))
request_path=http_strategies.http_path(),
chunk_size=integers(min_value=1),
)
@settings(max_examples=5, deadline=5000) @settings(max_examples=5, deadline=5000)
def test_request_body_chunking(self, request_path, chunk_size): def test_request_body_chunking(self, request_path, chunk_size):
""" """

View File

@ -23,10 +23,18 @@ class TestWebsocket(DaphneTestCase):
""" """
# Check overall keys # Check overall keys
self.assert_key_sets( self.assert_key_sets(
required_keys={"type", "path", "raw_path", "query_string", "headers"}, required_keys={
"asgi",
"type",
"path",
"raw_path",
"query_string",
"headers",
},
optional_keys={"scheme", "root_path", "client", "server", "subprotocols"}, optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
actual_keys=scope.keys(), actual_keys=scope.keys(),
) )
self.assertEqual(scope["asgi"]["version"], "3.0")
# Check that it is the right type # Check that it is the right type
self.assertEqual(scope["type"], "websocket") self.assertEqual(scope["type"], "websocket")
# Path # Path