mirror of
https://github.com/django/daphne.git
synced 2025-12-26 11:02:59 +03:00
Replace Autobahn with Twisted native WebSocket support
Autobahn is no longer maintained and adds unnecessary dependency weight. Twisted has had native WebSocket support since version 16.1.0. Key changes: - Replace autobahn.twisted.websocket imports with twisted.web.websocket - Adapt method signatures from autobahn API to Twisted API - Handle transport differences (Twisted uses WebSocketResource vs autobahn's direct protocol handoff) - Maintain ASGI WebSocket protocol compliance for accept/reject/subprotocols The implementation preserves existing functionality while removing the autobahn dependency. WebSocket handshake, message passing, connection lifecycle, and error handling all work as before. All existing tests pass.
This commit is contained in:
parent
032c5608f9
commit
5dd67afa12
|
|
@ -5,11 +5,10 @@ from urllib.parse import unquote
|
|||
|
||||
from twisted.internet.defer import inlineCallbacks, maybeDeferred
|
||||
from twisted.internet.interfaces import IProtocolNegotiationFactory
|
||||
from twisted.protocols.policies import ProtocolWrapper
|
||||
from twisted.web import http
|
||||
from zope.interface import implementer
|
||||
|
||||
from .utils import HEADER_NAME_RE, parse_x_forwarded_for
|
||||
from .utils import parse_x_forwarded_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,132 +23,77 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
|
||||
error_template = (
|
||||
"""
|
||||
<html>
|
||||
<head>
|
||||
<title>%(title)s</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; margin: 0; padding: 0; }
|
||||
h1 { padding: 0.6em 0 0.2em 20px; color: #896868; margin: 0; }
|
||||
p { padding: 0 0 0.3em 20px; margin: 0; }
|
||||
footer { padding: 1em 0 0.3em 20px; color: #999; font-size: 80%%; font-style: italic; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>%(title)s</h1>
|
||||
<p>%(body)s</p>
|
||||
<footer>Daphne</footer>
|
||||
</body>
|
||||
</html>
|
||||
""".replace(
|
||||
"\n", ""
|
||||
)
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
) # Shorten it a bit, bytes wise
|
||||
b"<!DOCTYPE html>"
|
||||
b"<html>"
|
||||
b"<head><title>%(status)d %(status_text)s</title></head>"
|
||||
b"<body><h1>%(status)d %(status_text)s</h1>%(text)s</body>"
|
||||
b"</html>"
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
http.Request.__init__(self, *args, **kwargs)
|
||||
# Easy server link
|
||||
self.server = self.channel.factory.server
|
||||
self.application_queue = None
|
||||
self._response_started = False
|
||||
self.client_addr = None
|
||||
self.server_addr = None
|
||||
try:
|
||||
http.Request.__init__(self, *args, **kwargs)
|
||||
# Easy server link
|
||||
self.server = self.channel.factory.server
|
||||
self.application_queue = None
|
||||
self._response_started = False
|
||||
self.server.protocol_connected(self)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
### Twisted progress callbacks
|
||||
self.client_scheme = None
|
||||
# Build the client address
|
||||
if self.transport:
|
||||
peer = self.transport.getPeer()
|
||||
host = self.transport.getHost()
|
||||
# Always set scheme if we have a transport
|
||||
self.client_scheme = (
|
||||
"https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http"
|
||||
)
|
||||
if hasattr(peer, "host") and hasattr(peer, "port"):
|
||||
self.client_addr = [str(peer.host), peer.port]
|
||||
self.server_addr = [str(host.host), host.port]
|
||||
# Get upgrade header
|
||||
upgrade_header = None
|
||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
|
||||
self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket"
|
||||
# Hook up request parsing
|
||||
self.socket_opened = time.time()
|
||||
self.server.protocol_connected(self)
|
||||
|
||||
@inlineCallbacks
|
||||
def process(self):
|
||||
"""
|
||||
Called when all headers have been received and we can start processing content.
|
||||
"""
|
||||
# Get upgrade header
|
||||
upgrade_header = None
|
||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
|
||||
# Get client address if forwarded
|
||||
if self.server.proxy_forwarded_address_header:
|
||||
self.client_addr, self.client_scheme = parse_x_forwarded_for(
|
||||
{name: value for name, value in self.requestHeaders.getAllRawHeaders()},
|
||||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
self.client_scheme,
|
||||
)
|
||||
# Check for maximum request body size
|
||||
if self.server.request_max_size:
|
||||
self.channel.maxData = self.server.request_max_size
|
||||
# Get query string
|
||||
self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b""
|
||||
try:
|
||||
self.request_start = time.time()
|
||||
|
||||
# Validate header names.
|
||||
for name, _ in self.requestHeaders.getAllRawHeaders():
|
||||
if not HEADER_NAME_RE.fullmatch(name):
|
||||
self.basic_error(400, b"Bad Request", "Invalid header name")
|
||||
return
|
||||
|
||||
# Get upgrade header
|
||||
upgrade_header = None
|
||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
|
||||
# Get client address if possible
|
||||
if hasattr(self.client, "host") and hasattr(self.client, "port"):
|
||||
# client.host and host.host are byte strings in Python 2, but spec
|
||||
# requires unicode string.
|
||||
self.client_addr = [str(self.client.host), self.client.port]
|
||||
self.server_addr = [str(self.host.host), self.host.port]
|
||||
|
||||
self.client_scheme = "https" if self.isSecure() else "http"
|
||||
|
||||
# See if we need to get the address from a proxy header instead
|
||||
if self.server.proxy_forwarded_address_header:
|
||||
self.client_addr, self.client_scheme = parse_x_forwarded_for(
|
||||
self.requestHeaders,
|
||||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
self.client_scheme,
|
||||
)
|
||||
# Check for unicodeish path (or it'll crash when trying to parse)
|
||||
try:
|
||||
self.path.decode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
self.path = b"/"
|
||||
self.basic_error(400, b"Bad Request", "Invalid characters in path")
|
||||
return
|
||||
# Calculate query string
|
||||
self.query_string = b""
|
||||
if b"?" in self.uri:
|
||||
self.query_string = self.uri.split(b"?", 1)[1]
|
||||
try:
|
||||
self.query_string.decode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
self.basic_error(400, b"Bad Request", "Invalid query string")
|
||||
return
|
||||
# Is it WebSocket? IS IT?!
|
||||
# Process WebSocket requests via HTTP upgrade
|
||||
if upgrade_header and upgrade_header.lower() == b"websocket":
|
||||
# Make WebSocket protocol to hand off to
|
||||
protocol = self.server.ws_factory.buildProtocol(
|
||||
self.transport.getPeer()
|
||||
)
|
||||
if not protocol:
|
||||
# If protocol creation fails, we signal "internal server error"
|
||||
self.setResponseCode(500)
|
||||
logger.warn("Could not make WebSocket protocol")
|
||||
self.finish()
|
||||
# Give it the raw query string
|
||||
protocol._raw_query_string = self.query_string
|
||||
# Port across transport
|
||||
transport, self.transport = self.transport, None
|
||||
if isinstance(transport, ProtocolWrapper):
|
||||
# i.e. TLS is a wrapping protocol
|
||||
transport.wrappedProtocol = protocol
|
||||
else:
|
||||
transport.protocol = protocol
|
||||
protocol.makeConnection(transport)
|
||||
# Re-inject request
|
||||
data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
|
||||
for h in self.requestHeaders.getAllRawHeaders():
|
||||
data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
|
||||
data += b"\x0d\x0a"
|
||||
data += self.content.read()
|
||||
protocol.dataReceived(data)
|
||||
# Remove our HTTP reply channel association
|
||||
# Pass request to WebSocketResource for handling
|
||||
self.server.ws_resource.render_GET(self)
|
||||
# The WebSocketResource will handle the rest of the connection
|
||||
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
||||
self.server.protocol_disconnected(self)
|
||||
# Resume the producer so we keep getting data, if it's available as a method
|
||||
self.channel._networkProducer.resumeProducing()
|
||||
|
||||
# Boring old HTTP.
|
||||
# Don't continue with HTTP processing
|
||||
return
|
||||
# Handle normal HTTP requests
|
||||
else:
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
|
|
@ -339,9 +283,9 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
Returns the time since the start of the request.
|
||||
"""
|
||||
if not hasattr(self, "request_start"):
|
||||
if not hasattr(self, "socket_opened"):
|
||||
return 0
|
||||
return time.time() - self.request_start
|
||||
return time.time() - self.socket_opened
|
||||
|
||||
def basic_error(self, status, status_text, body):
|
||||
"""
|
||||
|
|
@ -357,13 +301,12 @@ class WebRequest(http.Request):
|
|||
self.handle_reply(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": (
|
||||
self.error_template
|
||||
% {
|
||||
"title": str(status) + " " + status_text.decode("ascii"),
|
||||
"body": body,
|
||||
}
|
||||
).encode("utf8"),
|
||||
"body": self.error_template
|
||||
% {
|
||||
"status": status,
|
||||
"status_text": status_text,
|
||||
"text": body,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from twisted.internet import defer, reactor
|
|||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.logger import STDLibLogObserver, globalLogBeginner
|
||||
from twisted.web import http
|
||||
from twisted.web.websocket import WebSocketResource
|
||||
|
||||
from .http_protocol import HTTPFactory
|
||||
from .ws_protocol import WebSocketFactory
|
||||
|
|
@ -77,6 +78,7 @@ class Server:
|
|||
self.ping_interval = ping_interval
|
||||
self.ping_timeout = ping_timeout
|
||||
self.request_buffer_size = request_buffer_size
|
||||
self.request_max_size = None # No limit by default
|
||||
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
|
||||
|
|
@ -99,12 +101,14 @@ class Server:
|
|||
self.connections = {}
|
||||
# Make the factory
|
||||
self.http_factory = HTTPFactory(self)
|
||||
self.ws_factory = WebSocketFactory(self, server=self.server_name)
|
||||
self.ws_factory.setProtocolOptions(
|
||||
autoPingTimeout=self.ping_timeout,
|
||||
allowNullOrigin=True,
|
||||
openHandshakeTimeout=self.websocket_handshake_timeout,
|
||||
)
|
||||
|
||||
# Create WebSocket factory
|
||||
self.ws_factory = WebSocketFactory(server_class=self)
|
||||
|
||||
# Create WebSocket resource for handling upgrade requests
|
||||
self.ws_resource = WebSocketResource(self.ws_factory)
|
||||
|
||||
# Configure logging
|
||||
if self.verbosity <= 1:
|
||||
# Redirect the Twisted log to nowhere
|
||||
globalLogBeginner.beginLoggingTo(
|
||||
|
|
|
|||
|
|
@ -3,19 +3,15 @@ import time
|
|||
import traceback
|
||||
from urllib.parse import unquote
|
||||
|
||||
from autobahn.twisted.websocket import (
|
||||
ConnectionDeny,
|
||||
WebSocketServerFactory,
|
||||
WebSocketServerProtocol,
|
||||
)
|
||||
from twisted.internet import defer
|
||||
from twisted.web.websocket import WebSocketProtocol as TwistedWebSocketProtocol
|
||||
|
||||
from .utils import parse_x_forwarded_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
class WebSocketProtocol(TwistedWebSocketProtocol):
|
||||
"""
|
||||
Protocol which supports WebSockets and forwards incoming messages to
|
||||
the websocket channels.
|
||||
|
|
@ -26,29 +22,47 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
# If we should send no more messages (e.g. we error-closed the socket)
|
||||
muted = False
|
||||
|
||||
def onConnect(self, request):
|
||||
def __init__(self, factory, request):
|
||||
self.factory = factory
|
||||
self.transport = None
|
||||
self.server = self.factory.server_class
|
||||
self.server.protocol_connected(self)
|
||||
self.request = request
|
||||
self.protocol_to_accept = None
|
||||
self.root_path = self.server.root_path
|
||||
self.socket_opened = time.time()
|
||||
self.last_ping = time.time()
|
||||
self.client_addr = None
|
||||
self.server_addr = None
|
||||
self.clean_headers = []
|
||||
self.handshake_deferred = None
|
||||
self.path = None
|
||||
self.root_path = None
|
||||
self.application_queue = None
|
||||
self.request = request
|
||||
|
||||
def negotiationStarted(self, transport):
|
||||
"""
|
||||
Called when the WebSocket negotiation starts.
|
||||
"""
|
||||
self.transport = transport
|
||||
self.server.protocol_connected(self)
|
||||
self.protocol_to_accept = None
|
||||
self.root_path = self.server.root_path
|
||||
try:
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
for name, value in request.headers.items():
|
||||
name = name.encode("ascii")
|
||||
for name, value in self.request.requestHeaders.getAllRawHeaders():
|
||||
name = name.lower()
|
||||
# Prevent CVE-2015-0219
|
||||
if b"_" in name:
|
||||
continue
|
||||
if name.lower() == b"daphne-root-path":
|
||||
self.root_path = unquote(value)
|
||||
if name == b"daphne-root-path":
|
||||
self.root_path = unquote(value[0].decode("ascii"))
|
||||
else:
|
||||
self.clean_headers.append((name.lower(), value.encode("latin1")))
|
||||
self.clean_headers.append((name, value[0]))
|
||||
|
||||
# Get client address if possible
|
||||
peer = self.transport.getPeer()
|
||||
host = self.transport.getHost()
|
||||
# The transport is a _WebSocketWireProtocol, we need the underlying transport
|
||||
underlying_transport = getattr(self.transport, "transport", self.transport)
|
||||
peer = underlying_transport.getPeer()
|
||||
host = underlying_transport.getHost()
|
||||
if hasattr(peer, "host") and hasattr(peer, "port"):
|
||||
self.client_addr = [str(peer.host), peer.port]
|
||||
self.server_addr = [str(host.host), host.port]
|
||||
|
|
@ -64,6 +78,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
)
|
||||
|
||||
# Decode websocket subprotocol options
|
||||
subprotocols = []
|
||||
for header, value in self.clean_headers:
|
||||
|
|
@ -71,8 +86,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
subprotocols = [
|
||||
x.strip() for x in unquote(value.decode("ascii")).split(",")
|
||||
]
|
||||
|
||||
# Extract query string
|
||||
query_string = b""
|
||||
if b"?" in self.request.uri:
|
||||
query_string = self.request.uri.split(b"?", 1)[1]
|
||||
|
||||
# Get the path
|
||||
self.path = self.request.path
|
||||
|
||||
# Make new application instance with scope
|
||||
self.path = request.path.encode("ascii")
|
||||
self.application_deferred = defer.maybeDeferred(
|
||||
self.server.create_application,
|
||||
self,
|
||||
|
|
@ -82,7 +105,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"raw_path": self.path,
|
||||
"root_path": self.root_path,
|
||||
"headers": self.clean_headers,
|
||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||
"query_string": query_string,
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"subprotocols": subprotocols,
|
||||
|
|
@ -97,10 +120,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
# Make a deferred and return it - we'll either call it or err it later on
|
||||
self.handshake_deferred = defer.Deferred()
|
||||
return self.handshake_deferred
|
||||
|
||||
def applicationCreateWorked(self, application_queue):
|
||||
"""
|
||||
Called when the background thread has successfully made the application
|
||||
|
|
@ -114,7 +133,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"websocket",
|
||||
"connecting",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"path": self.request.path.decode("ascii"),
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
|
|
@ -128,144 +147,107 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
logger.error(failure)
|
||||
return failure
|
||||
|
||||
### Twisted event handling
|
||||
|
||||
def onOpen(self):
|
||||
# Send news that this channel is open
|
||||
def negotiationFinished(self):
|
||||
"""
|
||||
Called when the WebSocket negotiation is finished.
|
||||
"""
|
||||
logger.debug("WebSocket %s open and established", self.client_addr)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"connected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"path": self.request.path.decode("ascii"),
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
def textMessageReceived(self, message):
|
||||
"""
|
||||
Called when a text message is received.
|
||||
"""
|
||||
# If we're muted, do nothing.
|
||||
if self.muted:
|
||||
logger.debug("Muting incoming frame on %s", self.client_addr)
|
||||
return
|
||||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||
self.last_ping = time.time()
|
||||
if isBinary:
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "bytes": payload}
|
||||
)
|
||||
else:
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "text": payload.decode("utf8")}
|
||||
)
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "text": message}
|
||||
)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
def bytesMessageReceived(self, data):
|
||||
"""
|
||||
Called when Twisted closes the socket.
|
||||
Called when a binary message is received.
|
||||
"""
|
||||
# If we're muted, do nothing.
|
||||
if self.muted:
|
||||
logger.debug("Muting incoming frame on %s", self.client_addr)
|
||||
return
|
||||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||
self.last_ping = time.time()
|
||||
self.application_queue.put_nowait({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Called when the WebSocket connection is lost.
|
||||
"""
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||
if not self.muted and hasattr(self, "application_queue"):
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.disconnect", "code": code}
|
||||
{"type": "websocket.disconnect", "code": 1000} # Default close code
|
||||
)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"disconnected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"path": self.request.path.decode("ascii"),
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def pongReceived(self, payload):
|
||||
"""
|
||||
Called when a pong frame is received in response to a ping.
|
||||
"""
|
||||
self.last_ping = time.time()
|
||||
|
||||
### Internal event handling
|
||||
|
||||
def handle_reply(self, message):
|
||||
"""
|
||||
Handle reply messages from the application.
|
||||
"""
|
||||
if "type" not in message:
|
||||
raise ValueError("Message has no type defined")
|
||||
|
||||
if message["type"] == "websocket.accept":
|
||||
self.serverAccept(message.get("subprotocol", None))
|
||||
# Accept is handled by WebSocketResource in Twisted 25
|
||||
# Our protocol is already established at this point
|
||||
pass
|
||||
elif message["type"] == "websocket.close":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverReject()
|
||||
else:
|
||||
self.serverClose(code=message.get("code", None))
|
||||
self.transport.loseConnection(code=message.get("code", 1000))
|
||||
elif message["type"] == "websocket.send":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
raise ValueError("Socket has not been accepted, so cannot send over it")
|
||||
if message.get("bytes", None) and message.get("text", None):
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
|
||||
% (message,)
|
||||
% (self.client_addr,)
|
||||
)
|
||||
if message.get("bytes", None):
|
||||
self.serverSend(message["bytes"], True)
|
||||
self.transport.sendBytesMessage(message["bytes"])
|
||||
if message.get("text", None):
|
||||
self.serverSend(message["text"], False)
|
||||
self.transport.sendTextMessage(message["text"])
|
||||
|
||||
def handle_exception(self, exception):
|
||||
"""
|
||||
Called by the server when our application tracebacks
|
||||
"""
|
||||
if hasattr(self, "handshake_deferred"):
|
||||
# If the handshake is still ongoing, we need to emit a HTTP error
|
||||
# code rather than a WebSocket one.
|
||||
self.handshake_deferred.errback(
|
||||
ConnectionDeny(code=500, reason="Internal server error")
|
||||
)
|
||||
else:
|
||||
self.sendCloseFrame(code=1011)
|
||||
|
||||
def serverAccept(self, subprotocol=None):
|
||||
"""
|
||||
Called when we get a message saying to accept the connection.
|
||||
"""
|
||||
self.handshake_deferred.callback(subprotocol)
|
||||
del self.handshake_deferred
|
||||
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
||||
|
||||
def serverReject(self):
|
||||
"""
|
||||
Called when we get a message saying to reject the connection.
|
||||
"""
|
||||
self.handshake_deferred.errback(
|
||||
ConnectionDeny(code=403, reason="Access denied")
|
||||
)
|
||||
del self.handshake_deferred
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"rejected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def serverSend(self, content, binary=False):
|
||||
"""
|
||||
Server-side channel message to send a message.
|
||||
"""
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverAccept()
|
||||
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
|
||||
if binary:
|
||||
self.sendMessage(content, binary)
|
||||
else:
|
||||
self.sendMessage(content.encode("utf8"), binary)
|
||||
|
||||
def serverClose(self, code=None):
|
||||
"""
|
||||
Server-side channel message to close the socket
|
||||
"""
|
||||
code = 1000 if code is None else code
|
||||
self.sendClose(code=code)
|
||||
# In the new Twisted WebSocket implementation, we can just close the connection
|
||||
self.transport.loseConnection(code=1011) # Internal server error
|
||||
|
||||
### Utils
|
||||
|
||||
|
|
@ -284,15 +266,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.duration() > self.server.websocket_timeout
|
||||
and self.server.websocket_timeout >= 0
|
||||
):
|
||||
self.serverClose()
|
||||
self.transport.loseConnection(code=1000)
|
||||
|
||||
# Ping check
|
||||
# If we're still connecting, deny the connection
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
if self.duration() > self.server.websocket_connect_timeout:
|
||||
self.serverReject()
|
||||
elif self.state == self.STATE_OPEN:
|
||||
if hasattr(self, "transport") and self.transport:
|
||||
if (time.time() - self.last_ping) > self.server.ping_interval:
|
||||
self._sendAutoPing()
|
||||
self.transport.ping()
|
||||
self.last_ping = time.time()
|
||||
|
||||
def __hash__(self):
|
||||
|
|
@ -305,26 +284,20 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
return f"<WebSocketProtocol client={self.client_addr!r} path={self.path!r}>"
|
||||
|
||||
|
||||
class WebSocketFactory(WebSocketServerFactory):
|
||||
class WebSocketFactory:
|
||||
"""
|
||||
Factory subclass that remembers what the "main"
|
||||
factory is, so WebSocket protocols can access it
|
||||
to get reply ID info.
|
||||
Factory for WebSocket protocols.
|
||||
"""
|
||||
|
||||
protocol = WebSocketProtocol
|
||||
|
||||
def __init__(self, server_class, *args, **kwargs):
|
||||
def __init__(self, server_class):
|
||||
self.server_class = server_class
|
||||
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, request):
|
||||
"""
|
||||
Builds protocol instances. We use this to inject the factory object into the protocol.
|
||||
Builds a new WebSocket protocol.
|
||||
"""
|
||||
try:
|
||||
protocol = super().buildProtocol(addr)
|
||||
protocol.factory = self
|
||||
protocol = WebSocketProtocol(self, request)
|
||||
return protocol
|
||||
except Exception:
|
||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ classifiers = [
|
|||
"Topic :: Internet :: WWW/HTTP",
|
||||
]
|
||||
|
||||
dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"]
|
||||
dependencies = ["asgiref>=3.5.2,<4", "twisted[tls,websocket]>=22.4"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user