diff --git a/daphne/endpoints.py b/daphne/endpoints.py index b97364f..6b89953 100644 --- a/daphne/endpoints.py +++ b/daphne/endpoints.py @@ -1,22 +1,42 @@ -def build_endpoint_description_strings( - host=None, port=None, unix_socket=None, file_descriptor=None -): +from abc import ABC, abstractmethod + +class Endpoint(ABC): + @abstractmethod + def parse(self, options): + pass + +class TCPEndpoint(Endpoint): + def parse(self, options): + if options.get("port") and options.get("host"): + host = options["host"].strip("[]").replace(":", r"\:") + return f"tcp:port={int(options['port'])}:interface={host}" + elif options.get("port") or options.get("host"): + raise ValueError("TCP binding requires both port and host kwargs.") + return None + +class UNIXEndpoint(Endpoint): + def parse(self, options): + if options.get("unix_socket"): + return f"unix:{options['unix_socket']}" + return None + +class FileDescriptorEndpoint(Endpoint): + def parse(self, options): + if options.get("file_descriptor") is not None: + return f"fd:fileno={int(options['file_descriptor'])}" + return None + +endpoint_parsers = [TCPEndpoint(), UNIXEndpoint(), FileDescriptorEndpoint()] + +def build_endpoint_description_strings(**kwargs): """ Build a list of twisted endpoint description strings that the server will listen on. This is to streamline the generation of twisted endpoint description strings from easier to use command line args such as host, port, unix sockets etc. """ socket_descriptions = [] - if host and port is not None: - host = host.strip("[]").replace(":", r"\:") - socket_descriptions.append("tcp:port=%d:interface=%s" % (int(port), host)) - elif any([host, port]): - raise ValueError("TCP binding requires both port and host kwargs.") - - if unix_socket: - socket_descriptions.append("unix:%s" % unix_socket) - - if file_descriptor is not None: - socket_descriptions.append("fd:fileno=%d" % int(file_descriptor)) - + for parser in endpoint_parsers: + description = parser.parse(kwargs) + if description: + socket_descriptions.append(description) return socket_descriptions diff --git a/tests/test_cli.py b/tests/test_cli.py index 59f44e7..aee7f8b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,7 +4,11 @@ from argparse import ArgumentError from unittest import TestCase, skipUnless from daphne.cli import CommandLineInterface -from daphne.endpoints import build_endpoint_description_strings as build +from daphne.endpoints import ( + build_endpoint_description_strings as build, + endpoint_parsers, + Endpoint, +) class TestEndpointDescriptions(TestCase): @@ -64,6 +68,23 @@ class TestEndpointDescriptions(TestCase): ), ) + def test_custom_endpoint(self): + class CustomEndpoint(Endpoint): + def parse(self, options): + if options.get("custom"): + return f"custom:{options['custom']}" + return None + + endpoint_parsers.append(CustomEndpoint()) + + self.assertEqual( + build(custom="myprotocol"), + ["custom:myprotocol"], + ) + + # Cleanup custom endpoint parser + endpoint_parsers.pop() + class TestCLIInterface(TestCase): """