use twisted endpoint description strings to bind to ports and sockets

This commit is contained in:
Sean Mc Allister 2016-08-11 17:52:27 +02:00
parent fca52d4850
commit 95351ffebb
6 changed files with 273 additions and 27 deletions

View File

@ -8,6 +8,8 @@ from .access import AccessLogGenerator
logger = logging.getLogger(__name__)
DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 8000
class CommandLineInterface(object):
"""
@ -25,14 +27,14 @@ class CommandLineInterface(object):
'--port',
type=int,
help='Port number to listen on',
default=8000,
default=None,
)
self.parser.add_argument(
'-b',
'--bind',
dest='host',
help='The host/address to bind to',
default="127.0.0.1",
default=None,
)
self.parser.add_argument(
'-u',
@ -48,6 +50,14 @@ class CommandLineInterface(object):
help='Bind to a file descriptor rather than a TCP host/port or named unix socket',
default=None,
)
self.parser.add_argument(
'-e',
'--endpoint',
dest='socket_strings',
action='append',
help='Use raw server strings passed directly to twisted',
default=[],
)
self.parser.add_argument(
'-v',
'--verbosity',
@ -135,22 +145,30 @@ class CommandLineInterface(object):
channel_layer = importlib.import_module(module_path)
for bit in object_path.split("."):
channel_layer = getattr(channel_layer, bit)
if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
# no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST
args.port = DEFAULT_PORT
elif args.host and not args.port:
args.port = DEFAULT_PORT
elif args.port and not args.host:
args.host = DEFAULT_HOST
# Run server
logger.info(
"Starting server at %s, channel layer %s",
(args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)),
args.channel_layer,
)
Server(
self.server = Server(
channel_layer=channel_layer,
host=args.host,
port=args.port,
unix_socket=args.unix_socket,
file_descriptor=args.file_descriptor,
endpoints=args.socket_strings,
http_timeout=args.http_timeout,
ping_interval=args.ping_interval,
ping_timeout=args.ping_timeout,
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
ws_protocols=args.ws_protocols,
root_path=args.root_path,
).run()
)
self.server.run()

View File

@ -1,11 +1,12 @@
import logging
import socket
from twisted.internet import reactor, defer
from twisted.logger import globalLogBeginner
from twisted.internet.endpoints import serverFromString
from .http_protocol import HTTPFactory
logger = logging.getLogger(__name__)
@ -14,8 +15,9 @@ class Server(object):
def __init__(
self,
channel_layer,
host="127.0.0.1",
port=8000,
host=None,
port=None,
endpoints=[],
unix_socket=None,
file_descriptor=None,
signal_handlers=True,
@ -28,15 +30,23 @@ class Server(object):
root_path="",
):
self.channel_layer = channel_layer
self.host = host
self.port = port
self.unix_socket = unix_socket
self.file_descriptor = file_descriptor
self.endpoints = sorted(endpoints + self.build_endpoint_description_strings(
host=host,
port=port,
unix_socket=unix_socket,
file_descriptor=file_descriptor
))
if len(self.endpoints) == 0:
raise UserWarning("No endpoints. This server will not listen on anything.")
self.signal_handlers = signal_handlers
self.action_logger = action_logger
self.http_timeout = http_timeout
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
# If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
@ -54,17 +64,9 @@ class Server(object):
ws_protocols=self.ws_protocols,
root_path=self.root_path,
)
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True)
# Listen on a socket
if self.unix_socket:
reactor.listenUNIX(self.unix_socket, self.factory)
elif self.file_descriptor:
# socket returns the same socket if supplied with a fileno
sock = socket.socket(fileno=self.file_descriptor)
reactor.adoptStreamPort(self.file_descriptor, sock.family, self.factory)
else:
reactor.listenTCP(self.port, self.factory, interface=self.host)
if "twisted" in self.channel_layer.extensions and False:
logger.info("Using native Twisted mode on channel layer")
@ -73,6 +75,12 @@ class Server(object):
logger.info("Using busy-loop synchronous mode on channel layer")
reactor.callLater(0, self.backend_reader_sync)
reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description)
ep = serverFromString(reactor, socket_description)
ep.listen(self.factory)
reactor.run(installSignalHandlers=self.signal_handlers)
def backend_reader_sync(self):
@ -135,3 +143,34 @@ class Server(object):
"""
self.factory.check_timeouts()
reactor.callLater(2, self.timeout_checker)
@staticmethod
def build_endpoint_description_strings(
host=None,
port=None,
unix_socket=None,
file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on
"""
socket_descriptions = []
if host and port:
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
elif any([host, port]):
raise ValueError('TCP binding requires both port and host kwargs.')
if unix_socket:
socket_descriptions.append('unix:%s' % unix_socket)
if file_descriptor:
socket_descriptions.append('fd:domain=INET:fileno=%d' % int(file_descriptor))
return socket_descriptions

3
daphne/tests/asgi.py Normal file
View File

@ -0,0 +1,3 @@
# coding=utf-8
channel_layer = {}

View File

@ -0,0 +1,162 @@
# coding: utf8
from __future__ import unicode_literals
from unittest import TestCase
from six import string_types
from ..server import Server
from ..cli import CommandLineInterface
# this is the callable that will be tested here
build = Server.build_endpoint_description_strings
# patch out the servers run function
Server.run = lambda x: x
class TestEndpointDescriptions(TestCase):
def testBasics(self):
self.assertEqual(build(), [], msg="Empty list returned when no kwargs given")
def testTcpPortBindings(self):
self.assertEqual(
build(port=1234, host='example.com'),
['tcp:port=1234:interface=example.com']
)
self.assertEqual(
build(port=8000, host='127.0.0.1'),
['tcp:port=8000:interface=127.0.0.1']
)
# incomplete port/host kwargs raise errors
self.assertRaises(
ValueError,
build, port=123
)
self.assertRaises(
ValueError,
build, host='example.com'
)
def testUnixSocketBinding(self):
self.assertEqual(
build(unix_socket='/tmp/daphne.sock'),
['unix:/tmp/daphne.sock']
)
def testFileDescriptorBinding(self):
self.assertEqual(
build(file_descriptor=5),
['fd:domain=INET:fileno=5']
)
def testMultipleEnpoints(self):
self.assertEqual(
sorted(
build(
file_descriptor=123,
unix_socket='/tmp/daphne.sock',
port=8080,
host='10.0.0.1'
)
),
sorted([
'tcp:port=8080:interface=10.0.0.1',
'unix:/tmp/daphne.sock',
'fd:domain=INET:fileno=123'
])
)
class TestCLIInterface(TestCase):
# construct a string that will be accepted as the channel_layer argument
_import_channel_layer_string = '.'.join(
__loader__.name.split('.')[:-1] +
['asgi:channel_layer']
)
def build_cli(self, cli_args=''):
# split the string and append the channel_layer positional argument
if isinstance(cli_args, string_types):
cli_args = cli_args.split()
args = cli_args + [self._import_channel_layer_string]
cli = CommandLineInterface()
cli.run(args)
return cli
def get_endpoints(self, cli_args=''):
cli = self.build_cli(cli_args=cli_args)
return cli.server.endpoints
def checkCLI(self, args='', endpoints=[], msg='Expected endpoints do not match.'):
cli = self.build_cli(cli_args=args)
generated_endpoints = sorted(cli.server.endpoints)
endpoints.sort()
self.assertEqual(
generated_endpoints,
endpoints,
msg=msg
)
def testCLIBasics(self):
self.checkCLI(
'',
['tcp:port=8000:interface=127.0.0.1']
)
self.checkCLI(
'-p 123',
['tcp:port=123:interface=127.0.0.1']
)
self.checkCLI(
'-b 10.0.0.1',
['tcp:port=8000:interface=10.0.0.1']
)
self.checkCLI(
'-p 8080 -b example.com',
['tcp:port=8080:interface=example.com']
)
def testMixedCLIEndpoints(self):
self.checkCLI(
'-p 8080 -u /tmp/daphne.sock',
[
'tcp:port=8080:interface=127.0.0.1',
'unix:/tmp/daphne.sock',
],
'Default binding host patched in when only port given'
)
self.checkCLI(
'-b example.com -u /tmp/daphne.sock',
[
'tcp:port=8000:interface=example.com',
'unix:/tmp/daphne.sock',
],
'Default port patched in when missing.'
)
self.checkCLI(
'-u /tmp/daphne.sock --fd 5',
[
'fd:domain=INET:fileno=5',
'unix:/tmp/daphne.sock'
],
'File descriptor and unix socket bound, TCP ignored.'
)
def testCustomEndpoints(self):
self.checkCLI(
'-e imap:',
['imap:']
)

View File

@ -0,0 +1,24 @@
from twisted.plugin import IPlugin
from zope.interface import implementer
from twisted.internet.interfaces import IStreamServerEndpointStringParser
from twisted.internet import endpoints
import socket
@implementer(IPlugin, IStreamServerEndpointStringParser)
class _FDParser(object):
prefix = "fd"
def _parseServer(self, reactor, fileno, domain=socket.AF_INET):
fileno = int(fileno)
return endpoints.AdoptedStreamServerEndpoint(reactor, fileno, domain)
def parseStreamServer(self, reactor, *args, **kwargs):
# Delegate to another function with a sane signature. This function has
# an insane signature to trick zope.interface into believing the
# interface is correctly implemented.
return self._parseServer(reactor, *args, **kwargs)
parser = _FDParser()

View File

@ -3,7 +3,6 @@ import sys
from setuptools import find_packages, setup
from daphne import __version__
# We use the README as the long_description
readme_path = os.path.join(os.path.dirname(__file__), "README.rst")
@ -18,7 +17,8 @@ setup(
long_description=open(readme_path).read(),
license='BSD',
zip_safe=False,
packages=find_packages(),
package_dir={'twisted': 'daphne/twisted'},
packages=find_packages() + ['twisted.plugins'],
include_package_data=True,
install_requires=[
'asgiref>=0.13',