diff --git a/hooks/post_gen_project.py b/hooks/post_gen_project.py index 13d0ff00a..deb392534 100644 --- a/hooks/post_gen_project.py +++ b/hooks/post_gen_project.py @@ -104,7 +104,9 @@ def remove_celery_files(): def remove_async_files(): file_names = [ os.path.join("config", "asgi.py"), - os.path.join("config", "websocket.py"), + os.path.join("{{cookiecutter.project_slug}}", "users", "websocket.py"), + os.path.join("{{cookiecutter.project_slug}}", "users", "tests", "async_server.py"), + os.path.join("{{cookiecutter.project_slug}}", "users", "tests", "test_socket.py"), ] for file_name in file_names: os.remove(file_name) diff --git a/{{cookiecutter.project_slug}}/__init__.py b/{{cookiecutter.project_slug}}/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/{{cookiecutter.project_slug}}/config/asgi.py b/{{cookiecutter.project_slug}}/config/asgi.py index 8c99bbf53..2ac024b82 100644 --- a/{{cookiecutter.project_slug}}/config/asgi.py +++ b/{{cookiecutter.project_slug}}/config/asgi.py @@ -28,7 +28,8 @@ django_application = get_asgi_application() # application = HelloWorldApplication(application) # Import websocket application here, so apps from django_application are loaded first -from config.websocket import websocket_application # noqa isort:skip +# This is necessary since you may need to develop with your apps, but they need to be loaded first. +from {{ cookiecutter.project_slug }}.users.websocket import websocket_application # noqa isort:skip async def application(scope, receive, send): diff --git a/{{cookiecutter.project_slug}}/config/settings/test.py b/{{cookiecutter.project_slug}}/config/settings/test.py index 667bb20d8..18d634f24 100644 --- a/{{cookiecutter.project_slug}}/config/settings/test.py +++ b/{{cookiecutter.project_slug}}/config/settings/test.py @@ -15,6 +15,11 @@ SECRET_KEY = env( # https://docs.djangoproject.com/en/dev/ref/settings/#test-runner TEST_RUNNER = "django.test.runner.DiscoverRunner" +{%- if cookiecutter.use_async == 'y' %} +# Needed for socket testing that also needs HTTP to go along with it +ALLOWED_HOSTS = ["localhost", "0.0.0.0", "127.0.0.1"] +{%- endif %} + # CACHES # ------------------------------------------------------------------------------ # https://docs.djangoproject.com/en/dev/ref/settings/#caches diff --git a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/async_server.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/async_server.py new file mode 100644 index 000000000..fcb5de6bc --- /dev/null +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/async_server.py @@ -0,0 +1,42 @@ +import asyncio +import functools +import threading +import time +from contextlib import contextmanager + +from uvicorn.config import Config +from uvicorn.main import ServerState +from uvicorn.protocols.http.h11_impl import H11Protocol +from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + + +def run_loop(loop): + loop.run_forever() + loop.close() + + +@contextmanager +def run_server(path="/"): + asyncio.set_event_loop(None) + loop = asyncio.new_event_loop() + config = Config( + app="config.asgi:application", ws=WebSocketProtocol, http=H11Protocol + ) + server_state = ServerState() + protocol = functools.partial(H11Protocol, config=config, server_state=server_state) + create_server_task = loop.create_server(protocol, host="127.0.0.1") + server = loop.run_until_complete(create_server_task) + port = server.sockets[0].getsockname()[1] + url = f"ws://127.0.0.1:{port}{path}" + try: + # Run the event loop in a new thread. + thread = threading.Thread(target=run_loop, args=[loop]) + thread.start() + # Return the contextmanager state. + yield url + finally: + # Close the loop from our main thread. + while server_state.tasks: + time.sleep(0.01) + loop.call_soon_threadsafe(loop.stop) + thread.join() diff --git a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_socket.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_socket.py new file mode 100644 index 000000000..56132c571 --- /dev/null +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_socket.py @@ -0,0 +1,24 @@ +""" +Refer to Uvicorn's tests to know how to write your own. +https://github.com/encode/uvicorn/blob/master/tests/protocols/test_websocket.py + +""" +from asyncio import new_event_loop +from websockets import connect + +from {{ cookiecutter.project_slug }}.users.tests.async_server import run_server + + +def test_accept_connection(): + """ + If you want to communicate over HTTP, add live_server fixture + """ + async def open_connection(url): + async with connect(url) as websocket: + return websocket.open + + with run_server() as url: + loop = new_event_loop() + is_open = loop.run_until_complete(open_connection(url)) + assert is_open + loop.close() diff --git a/{{cookiecutter.project_slug}}/config/websocket.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/websocket.py similarity index 83% rename from {{cookiecutter.project_slug}}/config/websocket.py rename to {{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/websocket.py index 81adfbc66..4818e24cd 100644 --- a/{{cookiecutter.project_slug}}/config/websocket.py +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/websocket.py @@ -3,6 +3,7 @@ async def websocket_application(scope, receive, send): event = await receive() if event["type"] == "websocket.connect": + # TODO Add authentication using DRF-SimpleJWT or other token methods await send({"type": "websocket.accept"}) if event["type"] == "websocket.disconnect": @@ -10,4 +11,4 @@ async def websocket_application(scope, receive, send): if event["type"] == "websocket.receive": if event["text"] == "ping": - await send({"type": "websocket.send", "text": "pong!"}) + await send({"type": "websocket.send", "text": "pong"})