Change to scope-based code

This commit is contained in:
Andrew Godwin 2017-11-12 16:32:30 -08:00
parent 01f174bf26
commit 017797c05b
4 changed files with 51 additions and 52 deletions

View File

@ -42,7 +42,7 @@ class CommandLineInterface(object):
'--websocket_timeout', '--websocket_timeout',
type=int, type=int,
help='Maximum time to allow a websocket to be connected. -1 for infinite.', help='Maximum time to allow a websocket to be connected. -1 for infinite.',
default=None, default=86400,
) )
self.parser.add_argument( self.parser.add_argument(
'--websocket_connect_timeout', '--websocket_connect_timeout',

View File

@ -30,8 +30,6 @@ class WebRequest(http.Request):
GET and POST out. GET and POST out.
""" """
application_type = "http"
error_template = """ error_template = """
<html> <html>
<head> <head>
@ -138,8 +136,6 @@ class WebRequest(http.Request):
# Boring old HTTP. # Boring old HTTP.
else: else:
# Create application to handle this connection
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers, potentially extracting root path # Sanitize and decode headers, potentially extracting root path
self.clean_headers = [] self.clean_headers = []
self.root_path = self.server.root_path self.root_path = self.server.root_path
@ -154,10 +150,9 @@ class WebRequest(http.Request):
self.clean_headers.append((name.lower(), value)) self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.client_addr) logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0) self.content.seek(0, 0)
# Run application against request # Work out the application scope and create application
self.application_queue.put_nowait( self.application_queue = self.server.create_application(self, {
{ "type": "http",
"type": "http.request",
# TODO: Correctly say if it's 1.1 or 1.0 # TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"), "http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
"method": self.method.decode("ascii"), "method": self.method.decode("ascii"),
@ -166,9 +161,14 @@ class WebRequest(http.Request):
"scheme": "https" if self.isSecure() else "http", "scheme": "https" if self.isSecure() else "http",
"query_string": self.query_string, "query_string": self.query_string,
"headers": self.clean_headers, "headers": self.clean_headers,
"body": self.content.read(),
"client": self.client_addr, "client": self.client_addr,
"server": self.server_addr, "server": self.server_addr,
})
# Run application against request
self.application_queue.put_nowait(
{
"type": "http.request",
"body": self.content.read(),
}, },
) )
except Exception: except Exception:

View File

@ -91,7 +91,7 @@ class Server(object):
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints: for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description) logger.info("Listening on endpoint %s", socket_description)
ep = serverFromString(reactor, str(socket_description)) ep = serverFromString(reactor, str(socket_description))
self.listeners.append(ep.listen(self.http_factory)) self.listeners.append(ep.listen(self.http_factory))
@ -123,25 +123,23 @@ class Server(object):
### Internal event/message handling ### Internal event/message handling
def create_application(self, protocol): def create_application(self, protocol, scope):
""" """
Creates a new application instance that fronts a Protocol instance Creates a new application instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol, for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and and it will work out the type, supply appropriate callables, and
return you the application's input queue return you the application's input queue
""" """
# Make sure the protocol defines a application type
assert protocol.application_type is not None
# Make sure the protocol has not had another application made for it # Make sure the protocol has not had another application made for it
assert protocol not in self.application_instances assert protocol not in self.application_instances
# Make an instance of the application # Make an instance of the application
input_queue = asyncio.Queue() input_queue = asyncio.Queue()
application_instance = asyncio.ensure_future(self.application( application_instance = self.application(scope=scope)
type=protocol.application_type, # Run it, and stash the future for later checking
next=input_queue.get, self.application_instances[protocol] = asyncio.ensure_future(application_instance(
reply=lambda message: self.handle_reply(protocol, message), receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message),
), loop=asyncio.get_event_loop()) ), loop=asyncio.get_event_loop())
self.application_instances[protocol] = application_instance
return input_queue return input_queue
async def handle_reply(self, protocol, message): async def handle_reply(self, protocol, message):
@ -188,7 +186,7 @@ class Server(object):
if not application_instance.done(): if not application_instance.done():
application_instance.cancel() application_instance.cancel()
wait_for.append(application_instance) wait_for.append(application_instance)
logging.info("Killed %i pending application instances" % len(wait_for)) logging.info("Killed %i pending application instances", len(wait_for))
# Make Twisted wait until they're all dead # Make Twisted wait until they're all dead
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for)) wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
wait_deferred.addErrback(lambda x: None) wait_deferred.addErrback(lambda x: None)
@ -199,7 +197,7 @@ class Server(object):
Called periodically to enforce timeout rules on all connections. Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time. Also checks pings at the same time.
""" """
for protocol in self.protocols: for protocol in list(self.protocols):
protocol.check_timeouts() protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)

View File

@ -33,8 +33,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.socket_opened = time.time() self.socket_opened = time.time()
self.last_data = time.time() self.last_data = time.time()
try: try:
# Make new application instance
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers # Sanitize and decode headers
self.clean_headers = [] self.clean_headers = []
for name, value in request.headers.items(): for name, value in request.headers.items():
@ -60,41 +58,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.client_addr self.client_addr
) )
# Decode websocket subprotocol options
# Make initial request info dict from request (we only have it here) subprotocols = []
for header, value in self.clean_headers:
if header == b'sec-websocket-protocol':
subprotocols = [x.strip() for x in self.unquote(value).split(",")]
# Make new application instance with scope
self.path = request.path.encode("ascii") self.path = request.path.encode("ascii")
self.connect_message = { self.application_queue = self.server.create_application(self, {
"type": "websocket.connect", "type": "websocket",
"path": self.unquote(self.path), "path": self.unquote(self.path),
"headers": self.clean_headers, "headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol "query_string": self._raw_query_string, # Passed by HTTP protocol
"client": self.client_addr, "client": self.client_addr,
"server": self.server_addr, "server": self.server_addr,
"subprotocols": subprotocols,
"order": 0, "order": 0,
} })
except: except:
# Exceptions here are not displayed right, just 500. # Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log. # Turn them into an ERROR log.
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
ws_protocol = None
for header, value in self.clean_headers:
if header == b'sec-websocket-protocol':
protocols = [x.strip() for x in self.unquote(value).split(",")]
for protocol in protocols:
if protocol in self.server.websocket_protocols:
ws_protocol = protocol
break
# Work out what subprotocol we will accept, if any
if ws_protocol and ws_protocol in self.server.websocket_protocols:
self.protocol_to_accept = ws_protocol
else:
self.protocol_to_accept = None
# Send over the connect message # Send over the connect message
self.application_queue.put_nowait(self.connect_message) self.application_queue.put_nowait({"type": "websocket.connect"})
self.server.log_action("websocket", "connecting", { self.server.log_action("websocket", "connecting", {
"path": self.request.path, "path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
@ -154,7 +142,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
if "type" not in message: if "type" not in message:
raise ValueError("Message has no type defined") raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept": if message["type"] == "websocket.accept":
self.serverAccept() self.serverAccept(message.get("subprotocol", None))
elif message["type"] == "websocket.close": elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING: if self.state == self.STATE_CONNECTING:
self.serverReject() self.serverReject()
@ -174,11 +162,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
if message.get("text", None): if message.get("text", None):
self.serverSend(message["text"], False) self.serverSend(message["text"], False)
def serverAccept(self): def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
""" """
Called when we get a message saying to accept the connection. Called when we get a message saying to accept the connection.
""" """
self.handshake_deferred.callback(self.protocol_to_accept) self.handshake_deferred.callback(subprotocol)
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr) logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self): def serverReject(self):
@ -186,6 +186,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
Called when we get a message saying to reject the connection. Called when we get a message saying to reject the connection.
""" """
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
del self.handshake_deferred
self.server.discard_protocol(self) self.server.discard_protocol(self)
logger.debug("WebSocket %s rejected by application", self.client_addr) logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", { self.server.log_action("websocket", "rejected", {