From 15ba5c6776bcefc0f3e4a27b06636883ae5b7cf4 Mon Sep 17 00:00:00 2001 From: Avinash Raj Date: Wed, 28 Oct 2020 00:20:50 +0530 Subject: [PATCH] Updated to use ASGI v3 applications internally. (#275) Used guarantee_single_callable(). Removed unneeded --asgi-protocol CLI option. Updated tests. Co-authored-by: Carlton Gibson --- daphne/cli.py | 32 +++----------------------------- daphne/server.py | 12 +++++++----- daphne/testing.py | 9 ++++----- tests/test_http_request.py | 7 +++---- tests/test_websocket.py | 10 +++++++++- 5 files changed, 26 insertions(+), 44 deletions(-) diff --git a/daphne/cli.py b/daphne/cli.py index 2e65b12..8c42c43 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -1,10 +1,9 @@ import argparse -import functools import logging import sys from argparse import ArgumentError, Namespace -from asgiref.compatibility import is_double_callable +from asgiref.compatibility import guarantee_single_callable from .access import AccessLogGenerator from .endpoints import build_endpoint_description_strings @@ -17,19 +16,6 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 -class ASGI3Middleware: - def __init__(self, app): - self.app = app - - def __call__(self, scope): - scope.setdefault("asgi", {}) - scope["asgi"]["version"] = "3.0" - return functools.partial(self.asgi, scope=scope) - - async def asgi(self, receive, send, scope): - await self.app(scope, receive, send) - - class CommandLineInterface(object): """ Acts as the main CLI entry point for running the server. @@ -129,13 +115,6 @@ class CommandLineInterface(object): help="The WebSocket protocols you wish to support", default=None, ) - self.parser.add_argument( - "--asgi-protocol", - dest="asgi_protocol", - help="The version of the ASGI protocol to use", - default="auto", - choices=["asgi2", "asgi3", "auto"], - ) self.parser.add_argument( "--root-path", dest="root_path", @@ -247,16 +226,11 @@ class CommandLineInterface(object): access_log_stream = open(args.access_log, "a", 1) elif args.verbosity >= 1: access_log_stream = sys.stdout + # Import application sys.path.insert(0, ".") application = import_by_path(args.application) - - asgi_protocol = args.asgi_protocol - if asgi_protocol == "auto": - asgi_protocol = "asgi2" if is_double_callable(application) else "asgi3" - - if asgi_protocol == "asgi3": - application = ASGI3Middleware(application) + application = guarantee_single_callable(application) # Set up port/host bindings if not any( diff --git a/daphne/server.py b/daphne/server.py index 9ca01d4..f367e06 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -199,15 +199,17 @@ class Server(object): assert "application_instance" not in self.connections[protocol] # Make an instance of the application input_queue = asyncio.Queue() - application_instance = self.application(scope=scope) + scope.setdefault("asgi", {"version": "3.0"}) + application_instance = self.application( + scope=scope, + receive=input_queue.get, + send=lambda message: self.handle_reply(protocol, message), + ) # Run it, and stash the future for later checking if protocol not in self.connections: return None self.connections[protocol]["application_instance"] = asyncio.ensure_future( - application_instance( - receive=input_queue.get, - send=lambda message: self.handle_reply(protocol, message), - ), + application_instance, loop=asyncio.get_event_loop(), ) return input_queue diff --git a/daphne/testing.py b/daphne/testing.py index 3fd27ee..1632516 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -5,7 +5,6 @@ import pickle import tempfile import traceback from concurrent.futures import CancelledError -from functools import partial class DaphneTestingInstance: @@ -43,7 +42,7 @@ class DaphneTestingInstance: # Start up process self.process = DaphneProcess( host=self.host, - application=partial(TestApplication, lock=self.lock), + application=TestApplication(lock=self.lock), kwargs=kwargs, setup=self.process_setup, teardown=self.process_teardown, @@ -173,12 +172,12 @@ class TestApplication: setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio") result_storage = os.path.join(tempfile.gettempdir(), "result.testio") - def __init__(self, scope, lock): - self.scope = scope + def __init__(self, lock): self.lock = lock self.messages = [] - async def __call__(self, send, receive): + async def __call__(self, scope, receive, send): + self.scope = scope # Receive input and send output logging.debug("test app coroutine alive") try: diff --git a/tests/test_http_request.py b/tests/test_http_request.py index fbe3f6c..7048326 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -24,6 +24,7 @@ class TestHTTPRequest(DaphneTestCase): # Check overall keys self.assert_key_sets( required_keys={ + "asgi", "type", "http_version", "method", @@ -35,6 +36,7 @@ class TestHTTPRequest(DaphneTestCase): optional_keys={"scheme", "root_path", "client", "server"}, actual_keys=scope.keys(), ) + self.assertEqual(scope["asgi"]["version"], "3.0") # Check that it is the right type self.assertEqual(scope["type"], "http") # Method (uppercased unicode string) @@ -120,10 +122,7 @@ class TestHTTPRequest(DaphneTestCase): self.assert_valid_http_scope(scope, "GET", request_path, params=request_params) self.assert_valid_http_request_message(messages[0], body=b"") - @given( - request_path=http_strategies.http_path(), - chunk_size=integers(min_value=1), - ) + @given(request_path=http_strategies.http_path(), chunk_size=integers(min_value=1)) @settings(max_examples=5, deadline=5000) def test_request_body_chunking(self, request_path, chunk_size): """ diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c27e7a9..862e71c 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -23,10 +23,18 @@ class TestWebsocket(DaphneTestCase): """ # Check overall keys self.assert_key_sets( - required_keys={"type", "path", "raw_path", "query_string", "headers"}, + required_keys={ + "asgi", + "type", + "path", + "raw_path", + "query_string", + "headers", + }, optional_keys={"scheme", "root_path", "client", "server", "subprotocols"}, actual_keys=scope.keys(), ) + self.assertEqual(scope["asgi"]["version"], "3.0") # Check that it is the right type self.assertEqual(scope["type"], "websocket") # Path