Change to scope-based code

This commit is contained in:
Andrew Godwin 2017-11-12 16:32:30 -08:00
parent 01f174bf26
commit 017797c05b
4 changed files with 51 additions and 52 deletions

View File

@ -42,7 +42,7 @@ class CommandLineInterface(object):
'--websocket_timeout',
type=int,
help='Maximum time to allow a websocket to be connected. -1 for infinite.',
default=None,
default=86400,
)
self.parser.add_argument(
'--websocket_connect_timeout',

View File

@ -30,8 +30,6 @@ class WebRequest(http.Request):
GET and POST out.
"""
application_type = "http"
error_template = """
<html>
<head>
@ -138,8 +136,6 @@ class WebRequest(http.Request):
# Boring old HTTP.
else:
# Create application to handle this connection
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
self.root_path = self.server.root_path
@ -154,21 +150,25 @@ class WebRequest(http.Request):
self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0)
# Work out the application scope and create application
self.application_queue = self.server.create_application(self, {
"type": "http",
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
"method": self.method.decode("ascii"),
"path": self.unquote(self.path),
"root_path": self.root_path,
"scheme": "https" if self.isSecure() else "http",
"query_string": self.query_string,
"headers": self.clean_headers,
"client": self.client_addr,
"server": self.server_addr,
})
# Run application against request
self.application_queue.put_nowait(
{
"type": "http.request",
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
"method": self.method.decode("ascii"),
"path": self.unquote(self.path),
"root_path": self.root_path,
"scheme": "https" if self.isSecure() else "http",
"query_string": self.query_string,
"headers": self.clean_headers,
"body": self.content.read(),
"client": self.client_addr,
"server": self.server_addr,
},
)
except Exception:

View File

@ -91,7 +91,7 @@ class Server(object):
reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints:
logger.info("Listening on endpoint %s" % socket_description)
logger.info("Listening on endpoint %s", socket_description)
ep = serverFromString(reactor, str(socket_description))
self.listeners.append(ep.listen(self.http_factory))
@ -123,25 +123,23 @@ class Server(object):
### Internal event/message handling
def create_application(self, protocol):
def create_application(self, protocol, scope):
"""
Creates a new application instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and
return you the application's input queue
"""
# Make sure the protocol defines a application type
assert protocol.application_type is not None
# Make sure the protocol has not had another application made for it
assert protocol not in self.application_instances
# Make an instance of the application
input_queue = asyncio.Queue()
application_instance = asyncio.ensure_future(self.application(
type=protocol.application_type,
next=input_queue.get,
reply=lambda message: self.handle_reply(protocol, message),
application_instance = self.application(scope=scope)
# Run it, and stash the future for later checking
self.application_instances[protocol] = asyncio.ensure_future(application_instance(
receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message),
), loop=asyncio.get_event_loop())
self.application_instances[protocol] = application_instance
return input_queue
async def handle_reply(self, protocol, message):
@ -188,7 +186,7 @@ class Server(object):
if not application_instance.done():
application_instance.cancel()
wait_for.append(application_instance)
logging.info("Killed %i pending application instances" % len(wait_for))
logging.info("Killed %i pending application instances", len(wait_for))
# Make Twisted wait until they're all dead
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
wait_deferred.addErrback(lambda x: None)
@ -199,7 +197,7 @@ class Server(object):
Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time.
"""
for protocol in self.protocols:
for protocol in list(self.protocols):
protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker)

View File

@ -33,8 +33,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.socket_opened = time.time()
self.last_data = time.time()
try:
# Make new application instance
self.application_queue = self.server.create_application(self)
# Sanitize and decode headers
self.clean_headers = []
for name, value in request.headers.items():
@ -60,41 +58,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_port_header,
self.client_addr
)
# Make initial request info dict from request (we only have it here)
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
if header == b'sec-websocket-protocol':
subprotocols = [x.strip() for x in self.unquote(value).split(",")]
# Make new application instance with scope
self.path = request.path.encode("ascii")
self.connect_message = {
"type": "websocket.connect",
self.application_queue = self.server.create_application(self, {
"type": "websocket",
"path": self.unquote(self.path),
"headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol
"client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
"order": 0,
}
})
except:
# Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log.
logger.error(traceback.format_exc())
raise
ws_protocol = None
for header, value in self.clean_headers:
if header == b'sec-websocket-protocol':
protocols = [x.strip() for x in self.unquote(value).split(",")]
for protocol in protocols:
if protocol in self.server.websocket_protocols:
ws_protocol = protocol
break
# Work out what subprotocol we will accept, if any
if ws_protocol and ws_protocol in self.server.websocket_protocols:
self.protocol_to_accept = ws_protocol
else:
self.protocol_to_accept = None
# Send over the connect message
self.application_queue.put_nowait(self.connect_message)
self.application_queue.put_nowait({"type": "websocket.connect"})
self.server.log_action("websocket", "connecting", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
@ -154,7 +142,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept()
self.serverAccept(message.get("subprotocol", None))
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
@ -174,11 +162,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
if message.get("text", None):
self.serverSend(message["text"], False)
def serverAccept(self):
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(self.protocol_to_accept)
self.handshake_deferred.callback(subprotocol)
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self):
@ -186,6 +186,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
del self.handshake_deferred
self.server.discard_protocol(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", {