Merge pull request #37 from mcallistersean/ticket_10

use twisted endpoint description strings to bind to ports and sockets
This commit is contained in:
Andrew Godwin 2016-11-14 09:29:32 -08:00 committed by GitHub
commit fd83678276
7 changed files with 323 additions and 30 deletions

View File

@ -37,6 +37,17 @@ To achieve this you can use the --fd flag::
daphne --fd 5 django_project.asgi:channel_layer daphne --fd 5 django_project.asgi:channel_layer
If you want more control over the port/socket bindings you can fall back to
using `twisted's endpoint description strings
<http://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#serverFromString>`_
by using the `--endpoint (-e)` flag, which can be used multiple times.
This line would start a SSL server on port 443, assuming that `key.pem` and `crt.pem`
exist in the current directory (requires pyopenssl to be installed)::
daphne -e ssl:443:privateKey=key.pem:certKey=crt.pem django_project.asgi:channel_layer
To see all available command line options run daphne with the *-h* flag. To see all available command line options run daphne with the *-h* flag.
Root Path (SCRIPT_NAME) Root Path (SCRIPT_NAME)

View File

@ -2,12 +2,14 @@ import sys
import argparse import argparse
import logging import logging
import importlib import importlib
from .server import Server from .server import Server, build_endpoint_description_strings
from .access import AccessLogGenerator from .access import AccessLogGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 8000
class CommandLineInterface(object): class CommandLineInterface(object):
""" """
@ -25,14 +27,14 @@ class CommandLineInterface(object):
'--port', '--port',
type=int, type=int,
help='Port number to listen on', help='Port number to listen on',
default=8000, default=None,
) )
self.parser.add_argument( self.parser.add_argument(
'-b', '-b',
'--bind', '--bind',
dest='host', dest='host',
help='The host/address to bind to', help='The host/address to bind to',
default="127.0.0.1", default=None,
) )
self.parser.add_argument( self.parser.add_argument(
'-u', '-u',
@ -48,6 +50,14 @@ class CommandLineInterface(object):
help='Bind to a file descriptor rather than a TCP host/port or named unix socket', help='Bind to a file descriptor rather than a TCP host/port or named unix socket',
default=None, default=None,
) )
self.parser.add_argument(
'-e',
'--endpoint',
dest='socket_strings',
action='append',
help='Use raw server strings passed directly to twisted',
default=[],
)
self.parser.add_argument( self.parser.add_argument(
'-v', '-v',
'--verbosity', '--verbosity',
@ -105,6 +115,8 @@ class CommandLineInterface(object):
action='store_true', action='store_true',
) )
self.server = None
@classmethod @classmethod
def entrypoint(cls): def entrypoint(cls):
""" """
@ -143,18 +155,34 @@ class CommandLineInterface(object):
channel_layer = importlib.import_module(module_path) channel_layer = importlib.import_module(module_path)
for bit in object_path.split("."): for bit in object_path.split("."):
channel_layer = getattr(channel_layer, bit) channel_layer = getattr(channel_layer, bit)
# Run server
logger.info( if not any([args.host, args.port, args.unix_socket, args.file_descriptor, args.socket_strings]):
"Starting server at %s, channel layer %s", # no advanced binding options passed, patch in defaults
(args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)), args.host = DEFAULT_HOST
args.channel_layer, args.port = DEFAULT_PORT
) elif args.host and not args.port:
Server( args.port = DEFAULT_PORT
channel_layer=channel_layer, elif args.port and not args.host:
args.host = DEFAULT_HOST
# build endpoint description strings from (optional) cli arguments
endpoints = build_endpoint_description_strings(
host=args.host, host=args.host,
port=args.port, port=args.port,
unix_socket=args.unix_socket, unix_socket=args.unix_socket,
file_descriptor=args.file_descriptor, file_descriptor=args.file_descriptor
)
endpoints = sorted(
args.socket_strings + endpoints
)
logger.info(
'Starting server at %s, channel layer %s.' %
(', '.join(endpoints), args.channel_layer)
)
self.server = Server(
channel_layer=channel_layer,
endpoints=endpoints,
http_timeout=args.http_timeout, http_timeout=args.http_timeout,
ping_interval=args.ping_interval, ping_interval=args.ping_interval,
ping_timeout=args.ping_timeout, ping_timeout=args.ping_timeout,
@ -164,4 +192,5 @@ class CommandLineInterface(object):
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
).run() )
self.server.run()

View File

@ -3,6 +3,7 @@ import socket
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.logger import globalLogBeginner, STDLibLogObserver from twisted.logger import globalLogBeginner, STDLibLogObserver
from twisted.internet.endpoints import serverFromString
from .http_protocol import HTTPFactory from .http_protocol import HTTPFactory
@ -14,8 +15,9 @@ class Server(object):
def __init__( def __init__(
self, self,
channel_layer, channel_layer,
host="127.0.0.1", host=None,
port=8000, port=None,
endpoints=[],
unix_socket=None, unix_socket=None,
file_descriptor=None, file_descriptor=None,
signal_handlers=True, signal_handlers=True,
@ -31,10 +33,23 @@ class Server(object):
verbosity=1 verbosity=1
): ):
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.host = host self.endpoints = endpoints
self.port = port
self.unix_socket = unix_socket if any([host, port, unix_socket, file_descriptor]):
self.file_descriptor = file_descriptor raise DeprecationWarning('''
The host/port/unix_socket/file_descriptor keyword arguments to %s are deprecated.
''' % self.__class__.__name__)
# build endpoint description strings from deprecated kwargs
self.endpoints = sorted(self.endpoints + build_endpoint_description_strings(
host=host,
port=port,
unix_socket=unix_socket,
file_descriptor=file_descriptor
))
if len(self.endpoints) == 0:
raise UserWarning("No endpoints. This server will not listen on anything.")
self.signal_handlers = signal_handlers self.signal_handlers = signal_handlers
self.action_logger = action_logger self.action_logger = action_logger
self.http_timeout = http_timeout self.http_timeout = http_timeout
@ -67,15 +82,6 @@ class Server(object):
globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True) globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True)
else: else:
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)]) globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
# Listen on a socket
if self.unix_socket:
reactor.listenUNIX(self.unix_socket, self.factory)
elif self.file_descriptor:
# socket returns the same socket if supplied with a fileno
sock = socket.socket(fileno=self.file_descriptor)
reactor.adoptStreamPort(self.file_descriptor, sock.family, self.factory)
else:
reactor.listenTCP(self.port, self.factory, interface=self.host)
if "twisted" in self.channel_layer.extensions and False: if "twisted" in self.channel_layer.extensions and False:
logger.info("Using native Twisted mode on channel layer") logger.info("Using native Twisted mode on channel layer")
@ -84,6 +90,12 @@ class Server(object):
logger.info("Using busy-loop synchronous mode on channel layer") logger.info("Using busy-loop synchronous mode on channel layer")
reactor.callLater(0, self.backend_reader_sync) reactor.callLater(0, self.backend_reader_sync)
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description)
ep = serverFromString(reactor, socket_description)
ep.listen(self.factory)
reactor.run(installSignalHandlers=self.signal_handlers) reactor.run(installSignalHandlers=self.signal_handlers)
def backend_reader_sync(self): def backend_reader_sync(self):
@ -156,3 +168,35 @@ class Server(object):
""" """
self.factory.check_timeouts() self.factory.check_timeouts()
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
def build_endpoint_description_strings(
host=None,
port=None,
unix_socket=None,
file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
This is to streamline the generation of twisted endpoint description strings from easier
to use command line args such as host, port, unix sockets etc.
"""
socket_descriptions = []
if host and port:
socket_descriptions.append('tcp:port=%d:interface=%s' % (int(port), host))
elif any([host, port]):
raise ValueError('TCP binding requires both port and host kwargs.')
if unix_socket:
socket_descriptions.append('unix:%s' % unix_socket)
if file_descriptor:
socket_descriptions.append('fd:domain=INET:fileno=%d' % int(file_descriptor))
return socket_descriptions

3
daphne/tests/asgi.py Normal file
View File

@ -0,0 +1,3 @@
# coding=utf-8
channel_layer = {}

View File

@ -0,0 +1,182 @@
# coding: utf8
from __future__ import unicode_literals
from unittest import TestCase
from six import string_types
import logging
from ..server import Server, build_endpoint_description_strings
from ..cli import CommandLineInterface
# this is the callable that will be tested here
build = build_endpoint_description_strings
class TestEndpointDescriptions(TestCase):
def testBasics(self):
self.assertEqual(build(), [], msg="Empty list returned when no kwargs given")
def testTcpPortBindings(self):
self.assertEqual(
build(port=1234, host='example.com'),
['tcp:port=1234:interface=example.com']
)
self.assertEqual(
build(port=8000, host='127.0.0.1'),
['tcp:port=8000:interface=127.0.0.1']
)
# incomplete port/host kwargs raise errors
self.assertRaises(
ValueError,
build, port=123
)
self.assertRaises(
ValueError,
build, host='example.com'
)
def testUnixSocketBinding(self):
self.assertEqual(
build(unix_socket='/tmp/daphne.sock'),
['unix:/tmp/daphne.sock']
)
def testFileDescriptorBinding(self):
self.assertEqual(
build(file_descriptor=5),
['fd:domain=INET:fileno=5']
)
def testMultipleEnpoints(self):
self.assertEqual(
sorted(
build(
file_descriptor=123,
unix_socket='/tmp/daphne.sock',
port=8080,
host='10.0.0.1'
)
),
sorted([
'tcp:port=8080:interface=10.0.0.1',
'unix:/tmp/daphne.sock',
'fd:domain=INET:fileno=123'
])
)
class TestCLIInterface(TestCase):
# construct a string that will be accepted as the channel_layer argument
_import_channel_layer_string = 'daphne.tests.asgi:channel_layer'
def setUp(self):
logging.disable(logging.CRITICAL)
# patch out the servers run method
self._default_server_run = Server.run
Server.run = lambda x: x
def tearDown(self):
logging.disable(logging.NOTSET)
# restore the original server run method
Server.run = self._default_server_run
def build_cli(self, cli_args=''):
# split the string and append the channel_layer positional argument
if isinstance(cli_args, string_types):
cli_args = cli_args.split()
args = cli_args + [self._import_channel_layer_string]
cli = CommandLineInterface()
cli.run(args)
return cli
def get_endpoints(self, cli_args=''):
cli = self.build_cli(cli_args=cli_args)
return cli.server.endpoints
def checkCLI(self, args='', endpoints=[], msg='Expected endpoints do not match.'):
cli = self.build_cli(cli_args=args)
generated_endpoints = sorted(cli.server.endpoints)
endpoints.sort()
self.assertEqual(
generated_endpoints,
endpoints,
msg=msg
)
def testCLIBasics(self):
self.checkCLI(
'',
['tcp:port=8000:interface=127.0.0.1']
)
self.checkCLI(
'-p 123',
['tcp:port=123:interface=127.0.0.1']
)
self.checkCLI(
'-b 10.0.0.1',
['tcp:port=8000:interface=10.0.0.1']
)
self.checkCLI(
'-p 8080 -b example.com',
['tcp:port=8080:interface=example.com']
)
def testCLIEndpointCreation(self):
self.checkCLI(
'-p 8080 -u /tmp/daphne.sock',
[
'tcp:port=8080:interface=127.0.0.1',
'unix:/tmp/daphne.sock',
],
'Default binding host patched in when only port given'
)
self.checkCLI(
'-b example.com -u /tmp/daphne.sock',
[
'tcp:port=8000:interface=example.com',
'unix:/tmp/daphne.sock',
],
'Default port patched in when missing.'
)
self.checkCLI(
'-u /tmp/daphne.sock --fd 5',
[
'fd:domain=INET:fileno=5',
'unix:/tmp/daphne.sock'
],
'File descriptor and unix socket bound, TCP ignored.'
)
def testMixedCLIEndpointCreation(self):
self.checkCLI(
'-p 8080 -e unix:/tmp/daphne.sock',
[
'tcp:port=8080:interface=127.0.0.1',
'unix:/tmp/daphne.sock'
],
'Mix host/port args with endpoint args'
)
self.checkCLI(
'-p 8080 -e tcp:port=8080:interface=127.0.0.1',
[
'tcp:port=8080:interface=127.0.0.1',
] * 2,
'Do not try to de-duplicate endpoint description strings.'
'This would fail when running the server.'
)
def testCustomEndpoints(self):
self.checkCLI(
'-e imap:',
['imap:']
)

View File

@ -0,0 +1,24 @@
from twisted.plugin import IPlugin
from zope.interface import implementer
from twisted.internet.interfaces import IStreamServerEndpointStringParser
from twisted.internet import endpoints
import socket
@implementer(IPlugin, IStreamServerEndpointStringParser)
class _FDParser(object):
prefix = "fd"
def _parseServer(self, reactor, fileno, domain=socket.AF_INET):
fileno = int(fileno)
return endpoints.AdoptedStreamServerEndpoint(reactor, fileno, domain)
def parseStreamServer(self, reactor, *args, **kwargs):
# Delegate to another function with a sane signature. This function has
# an insane signature to trick zope.interface into believing the
# interface is correctly implemented.
return self._parseServer(reactor, *args, **kwargs)
parser = _FDParser()

View File

@ -4,7 +4,6 @@ from setuptools import find_packages, setup
from daphne import __version__ from daphne import __version__
# We use the README as the long_description # We use the README as the long_description
readme_path = os.path.join(os.path.dirname(__file__), "README.rst") readme_path = os.path.join(os.path.dirname(__file__), "README.rst")
with open(readme_path) as fp: with open(readme_path) as fp:
@ -20,7 +19,8 @@ setup(
long_description=long_description, long_description=long_description,
license='BSD', license='BSD',
zip_safe=False, zip_safe=False,
packages=find_packages(), package_dir={'twisted': 'daphne/twisted'},
packages=find_packages() + ['twisted.plugins'],
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'asgiref>=0.13', 'asgiref>=0.13',