mirror of
https://github.com/django/daphne.git
synced 2026-01-08 09:10:53 +03:00
Merge 5dd67afa12 into 8b5696768f
This commit is contained in:
commit
616b17bb7e
|
|
@ -5,11 +5,10 @@ from urllib.parse import unquote
|
|||
|
||||
from twisted.internet.defer import inlineCallbacks, maybeDeferred
|
||||
from twisted.internet.interfaces import IProtocolNegotiationFactory
|
||||
from twisted.protocols.policies import ProtocolWrapper
|
||||
from twisted.web import http
|
||||
from zope.interface import implementer
|
||||
|
||||
from .utils import HEADER_NAME_RE, parse_x_forwarded_for
|
||||
from .utils import parse_x_forwarded_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,132 +23,77 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
|
||||
error_template = (
|
||||
"""
|
||||
<html>
|
||||
<head>
|
||||
<title>%(title)s</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; margin: 0; padding: 0; }
|
||||
h1 { padding: 0.6em 0 0.2em 20px; color: #896868; margin: 0; }
|
||||
p { padding: 0 0 0.3em 20px; margin: 0; }
|
||||
footer { padding: 1em 0 0.3em 20px; color: #999; font-size: 80%%; font-style: italic; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>%(title)s</h1>
|
||||
<p>%(body)s</p>
|
||||
<footer>Daphne</footer>
|
||||
</body>
|
||||
</html>
|
||||
""".replace(
|
||||
"\n", ""
|
||||
)
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
) # Shorten it a bit, bytes wise
|
||||
b"<!DOCTYPE html>"
|
||||
b"<html>"
|
||||
b"<head><title>%(status)d %(status_text)s</title></head>"
|
||||
b"<body><h1>%(status)d %(status_text)s</h1>%(text)s</body>"
|
||||
b"</html>"
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
http.Request.__init__(self, *args, **kwargs)
|
||||
# Easy server link
|
||||
self.server = self.channel.factory.server
|
||||
self.application_queue = None
|
||||
self._response_started = False
|
||||
self.client_addr = None
|
||||
self.server_addr = None
|
||||
try:
|
||||
http.Request.__init__(self, *args, **kwargs)
|
||||
# Easy server link
|
||||
self.server = self.channel.factory.server
|
||||
self.application_queue = None
|
||||
self._response_started = False
|
||||
self.server.protocol_connected(self)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
### Twisted progress callbacks
|
||||
self.client_scheme = None
|
||||
# Build the client address
|
||||
if self.transport:
|
||||
peer = self.transport.getPeer()
|
||||
host = self.transport.getHost()
|
||||
# Always set scheme if we have a transport
|
||||
self.client_scheme = (
|
||||
"https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http"
|
||||
)
|
||||
if hasattr(peer, "host") and hasattr(peer, "port"):
|
||||
self.client_addr = [str(peer.host), peer.port]
|
||||
self.server_addr = [str(host.host), host.port]
|
||||
# Get upgrade header
|
||||
upgrade_header = None
|
||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
|
||||
self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket"
|
||||
# Hook up request parsing
|
||||
self.socket_opened = time.time()
|
||||
self.server.protocol_connected(self)
|
||||
|
||||
@inlineCallbacks
|
||||
def process(self):
|
||||
"""
|
||||
Called when all headers have been received and we can start processing content.
|
||||
"""
|
||||
# Get upgrade header
|
||||
upgrade_header = None
|
||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
|
||||
# Get client address if forwarded
|
||||
if self.server.proxy_forwarded_address_header:
|
||||
self.client_addr, self.client_scheme = parse_x_forwarded_for(
|
||||
{name: value for name, value in self.requestHeaders.getAllRawHeaders()},
|
||||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
self.client_scheme,
|
||||
)
|
||||
# Check for maximum request body size
|
||||
if self.server.request_max_size:
|
||||
self.channel.maxData = self.server.request_max_size
|
||||
# Get query string
|
||||
self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b""
|
||||
try:
|
||||
self.request_start = time.time()
|
||||
|
||||
# Validate header names.
|
||||
for name, _ in self.requestHeaders.getAllRawHeaders():
|
||||
if not HEADER_NAME_RE.fullmatch(name):
|
||||
self.basic_error(400, b"Bad Request", "Invalid header name")
|
||||
return
|
||||
|
||||
# Get upgrade header
|
||||
upgrade_header = None
|
||||
if self.requestHeaders.hasHeader(b"Upgrade"):
|
||||
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
|
||||
# Get client address if possible
|
||||
if hasattr(self.client, "host") and hasattr(self.client, "port"):
|
||||
# client.host and host.host are byte strings in Python 2, but spec
|
||||
# requires unicode string.
|
||||
self.client_addr = [str(self.client.host), self.client.port]
|
||||
self.server_addr = [str(self.host.host), self.host.port]
|
||||
|
||||
self.client_scheme = "https" if self.isSecure() else "http"
|
||||
|
||||
# See if we need to get the address from a proxy header instead
|
||||
if self.server.proxy_forwarded_address_header:
|
||||
self.client_addr, self.client_scheme = parse_x_forwarded_for(
|
||||
self.requestHeaders,
|
||||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
self.client_scheme,
|
||||
)
|
||||
# Check for unicodeish path (or it'll crash when trying to parse)
|
||||
try:
|
||||
self.path.decode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
self.path = b"/"
|
||||
self.basic_error(400, b"Bad Request", "Invalid characters in path")
|
||||
return
|
||||
# Calculate query string
|
||||
self.query_string = b""
|
||||
if b"?" in self.uri:
|
||||
self.query_string = self.uri.split(b"?", 1)[1]
|
||||
try:
|
||||
self.query_string.decode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
self.basic_error(400, b"Bad Request", "Invalid query string")
|
||||
return
|
||||
# Is it WebSocket? IS IT?!
|
||||
# Process WebSocket requests via HTTP upgrade
|
||||
if upgrade_header and upgrade_header.lower() == b"websocket":
|
||||
# Make WebSocket protocol to hand off to
|
||||
protocol = self.server.ws_factory.buildProtocol(
|
||||
self.transport.getPeer()
|
||||
)
|
||||
if not protocol:
|
||||
# If protocol creation fails, we signal "internal server error"
|
||||
self.setResponseCode(500)
|
||||
logger.warn("Could not make WebSocket protocol")
|
||||
self.finish()
|
||||
# Give it the raw query string
|
||||
protocol._raw_query_string = self.query_string
|
||||
# Port across transport
|
||||
transport, self.transport = self.transport, None
|
||||
if isinstance(transport, ProtocolWrapper):
|
||||
# i.e. TLS is a wrapping protocol
|
||||
transport.wrappedProtocol = protocol
|
||||
else:
|
||||
transport.protocol = protocol
|
||||
protocol.makeConnection(transport)
|
||||
# Re-inject request
|
||||
data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
|
||||
for h in self.requestHeaders.getAllRawHeaders():
|
||||
data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
|
||||
data += b"\x0d\x0a"
|
||||
data += self.content.read()
|
||||
protocol.dataReceived(data)
|
||||
# Remove our HTTP reply channel association
|
||||
# Pass request to WebSocketResource for handling
|
||||
self.server.ws_resource.render_GET(self)
|
||||
# The WebSocketResource will handle the rest of the connection
|
||||
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
||||
self.server.protocol_disconnected(self)
|
||||
# Resume the producer so we keep getting data, if it's available as a method
|
||||
self.channel._networkProducer.resumeProducing()
|
||||
|
||||
# Boring old HTTP.
|
||||
# Don't continue with HTTP processing
|
||||
return
|
||||
# Handle normal HTTP requests
|
||||
else:
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
|
|
@ -339,9 +283,9 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
Returns the time since the start of the request.
|
||||
"""
|
||||
if not hasattr(self, "request_start"):
|
||||
if not hasattr(self, "socket_opened"):
|
||||
return 0
|
||||
return time.time() - self.request_start
|
||||
return time.time() - self.socket_opened
|
||||
|
||||
def basic_error(self, status, status_text, body):
|
||||
"""
|
||||
|
|
@ -357,13 +301,12 @@ class WebRequest(http.Request):
|
|||
self.handle_reply(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": (
|
||||
self.error_template
|
||||
% {
|
||||
"title": str(status) + " " + status_text.decode("ascii"),
|
||||
"body": body,
|
||||
}
|
||||
).encode("utf8"),
|
||||
"body": self.error_template
|
||||
% {
|
||||
"status": status,
|
||||
"status_text": status_text,
|
||||
"text": body,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from twisted.internet import defer, reactor
|
|||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.logger import STDLibLogObserver, globalLogBeginner
|
||||
from twisted.web import http
|
||||
from twisted.web.websocket import WebSocketResource
|
||||
|
||||
from .http_protocol import HTTPFactory
|
||||
from .ws_protocol import WebSocketFactory
|
||||
|
|
@ -77,6 +78,7 @@ class Server:
|
|||
self.ping_interval = ping_interval
|
||||
self.ping_timeout = ping_timeout
|
||||
self.request_buffer_size = request_buffer_size
|
||||
self.request_max_size = None # No limit by default
|
||||
self.proxy_forwarded_address_header = proxy_forwarded_address_header
|
||||
self.proxy_forwarded_port_header = proxy_forwarded_port_header
|
||||
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
|
||||
|
|
@ -99,12 +101,14 @@ class Server:
|
|||
self.connections = {}
|
||||
# Make the factory
|
||||
self.http_factory = HTTPFactory(self)
|
||||
self.ws_factory = WebSocketFactory(self, server=self.server_name)
|
||||
self.ws_factory.setProtocolOptions(
|
||||
autoPingTimeout=self.ping_timeout,
|
||||
allowNullOrigin=True,
|
||||
openHandshakeTimeout=self.websocket_handshake_timeout,
|
||||
)
|
||||
|
||||
# Create WebSocket factory
|
||||
self.ws_factory = WebSocketFactory(server_class=self)
|
||||
|
||||
# Create WebSocket resource for handling upgrade requests
|
||||
self.ws_resource = WebSocketResource(self.ws_factory)
|
||||
|
||||
# Configure logging
|
||||
if self.verbosity <= 1:
|
||||
# Redirect the Twisted log to nowhere
|
||||
globalLogBeginner.beginLoggingTo(
|
||||
|
|
|
|||
|
|
@ -3,19 +3,15 @@ import time
|
|||
import traceback
|
||||
from urllib.parse import unquote
|
||||
|
||||
from autobahn.twisted.websocket import (
|
||||
ConnectionDeny,
|
||||
WebSocketServerFactory,
|
||||
WebSocketServerProtocol,
|
||||
)
|
||||
from twisted.internet import defer
|
||||
from twisted.web.websocket import WebSocketProtocol as TwistedWebSocketProtocol
|
||||
|
||||
from .utils import parse_x_forwarded_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
class WebSocketProtocol(TwistedWebSocketProtocol):
|
||||
"""
|
||||
Protocol which supports WebSockets and forwards incoming messages to
|
||||
the websocket channels.
|
||||
|
|
@ -26,29 +22,47 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
# If we should send no more messages (e.g. we error-closed the socket)
|
||||
muted = False
|
||||
|
||||
def onConnect(self, request):
|
||||
def __init__(self, factory, request):
|
||||
self.factory = factory
|
||||
self.transport = None
|
||||
self.server = self.factory.server_class
|
||||
self.server.protocol_connected(self)
|
||||
self.request = request
|
||||
self.protocol_to_accept = None
|
||||
self.root_path = self.server.root_path
|
||||
self.socket_opened = time.time()
|
||||
self.last_ping = time.time()
|
||||
self.client_addr = None
|
||||
self.server_addr = None
|
||||
self.clean_headers = []
|
||||
self.handshake_deferred = None
|
||||
self.path = None
|
||||
self.root_path = None
|
||||
self.application_queue = None
|
||||
self.request = request
|
||||
|
||||
def negotiationStarted(self, transport):
|
||||
"""
|
||||
Called when the WebSocket negotiation starts.
|
||||
"""
|
||||
self.transport = transport
|
||||
self.server.protocol_connected(self)
|
||||
self.protocol_to_accept = None
|
||||
self.root_path = self.server.root_path
|
||||
try:
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
for name, value in request.headers.items():
|
||||
name = name.encode("ascii")
|
||||
for name, value in self.request.requestHeaders.getAllRawHeaders():
|
||||
name = name.lower()
|
||||
# Prevent CVE-2015-0219
|
||||
if b"_" in name:
|
||||
continue
|
||||
if name.lower() == b"daphne-root-path":
|
||||
self.root_path = unquote(value)
|
||||
if name == b"daphne-root-path":
|
||||
self.root_path = unquote(value[0].decode("ascii"))
|
||||
else:
|
||||
self.clean_headers.append((name.lower(), value.encode("latin1")))
|
||||
self.clean_headers.append((name, value[0]))
|
||||
|
||||
# Get client address if possible
|
||||
peer = self.transport.getPeer()
|
||||
host = self.transport.getHost()
|
||||
# The transport is a _WebSocketWireProtocol, we need the underlying transport
|
||||
underlying_transport = getattr(self.transport, "transport", self.transport)
|
||||
peer = underlying_transport.getPeer()
|
||||
host = underlying_transport.getHost()
|
||||
if hasattr(peer, "host") and hasattr(peer, "port"):
|
||||
self.client_addr = [str(peer.host), peer.port]
|
||||
self.server_addr = [str(host.host), host.port]
|
||||
|
|
@ -64,6 +78,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
)
|
||||
|
||||
# Decode websocket subprotocol options
|
||||
subprotocols = []
|
||||
for header, value in self.clean_headers:
|
||||
|
|
@ -71,8 +86,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
subprotocols = [
|
||||
x.strip() for x in unquote(value.decode("ascii")).split(",")
|
||||
]
|
||||
|
||||
# Extract query string
|
||||
query_string = b""
|
||||
if b"?" in self.request.uri:
|
||||
query_string = self.request.uri.split(b"?", 1)[1]
|
||||
|
||||
# Get the path
|
||||
self.path = self.request.path
|
||||
|
||||
# Make new application instance with scope
|
||||
self.path = request.path.encode("ascii")
|
||||
self.application_deferred = defer.maybeDeferred(
|
||||
self.server.create_application,
|
||||
self,
|
||||
|
|
@ -82,7 +105,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"raw_path": self.path,
|
||||
"root_path": self.root_path,
|
||||
"headers": self.clean_headers,
|
||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||
"query_string": query_string,
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"subprotocols": subprotocols,
|
||||
|
|
@ -97,10 +120,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
# Make a deferred and return it - we'll either call it or err it later on
|
||||
self.handshake_deferred = defer.Deferred()
|
||||
return self.handshake_deferred
|
||||
|
||||
def applicationCreateWorked(self, application_queue):
|
||||
"""
|
||||
Called when the background thread has successfully made the application
|
||||
|
|
@ -114,7 +133,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"websocket",
|
||||
"connecting",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"path": self.request.path.decode("ascii"),
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
|
|
@ -128,144 +147,107 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
logger.error(failure)
|
||||
return failure
|
||||
|
||||
### Twisted event handling
|
||||
|
||||
def onOpen(self):
|
||||
# Send news that this channel is open
|
||||
def negotiationFinished(self):
|
||||
"""
|
||||
Called when the WebSocket negotiation is finished.
|
||||
"""
|
||||
logger.debug("WebSocket %s open and established", self.client_addr)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"connected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"path": self.request.path.decode("ascii"),
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
def textMessageReceived(self, message):
|
||||
"""
|
||||
Called when a text message is received.
|
||||
"""
|
||||
# If we're muted, do nothing.
|
||||
if self.muted:
|
||||
logger.debug("Muting incoming frame on %s", self.client_addr)
|
||||
return
|
||||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||
self.last_ping = time.time()
|
||||
if isBinary:
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "bytes": payload}
|
||||
)
|
||||
else:
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "text": payload.decode("utf8")}
|
||||
)
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "text": message}
|
||||
)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
def bytesMessageReceived(self, data):
|
||||
"""
|
||||
Called when Twisted closes the socket.
|
||||
Called when a binary message is received.
|
||||
"""
|
||||
# If we're muted, do nothing.
|
||||
if self.muted:
|
||||
logger.debug("Muting incoming frame on %s", self.client_addr)
|
||||
return
|
||||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||
self.last_ping = time.time()
|
||||
self.application_queue.put_nowait({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Called when the WebSocket connection is lost.
|
||||
"""
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||
if not self.muted and hasattr(self, "application_queue"):
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.disconnect", "code": code}
|
||||
{"type": "websocket.disconnect", "code": 1000} # Default close code
|
||||
)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"disconnected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"path": self.request.path.decode("ascii"),
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def pongReceived(self, payload):
|
||||
"""
|
||||
Called when a pong frame is received in response to a ping.
|
||||
"""
|
||||
self.last_ping = time.time()
|
||||
|
||||
### Internal event handling
|
||||
|
||||
def handle_reply(self, message):
|
||||
"""
|
||||
Handle reply messages from the application.
|
||||
"""
|
||||
if "type" not in message:
|
||||
raise ValueError("Message has no type defined")
|
||||
|
||||
if message["type"] == "websocket.accept":
|
||||
self.serverAccept(message.get("subprotocol", None))
|
||||
# Accept is handled by WebSocketResource in Twisted 25
|
||||
# Our protocol is already established at this point
|
||||
pass
|
||||
elif message["type"] == "websocket.close":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverReject()
|
||||
else:
|
||||
self.serverClose(code=message.get("code", None))
|
||||
self.transport.loseConnection(code=message.get("code", 1000))
|
||||
elif message["type"] == "websocket.send":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
raise ValueError("Socket has not been accepted, so cannot send over it")
|
||||
if message.get("bytes", None) and message.get("text", None):
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
|
||||
% (message,)
|
||||
% (self.client_addr,)
|
||||
)
|
||||
if message.get("bytes", None):
|
||||
self.serverSend(message["bytes"], True)
|
||||
self.transport.sendBytesMessage(message["bytes"])
|
||||
if message.get("text", None):
|
||||
self.serverSend(message["text"], False)
|
||||
self.transport.sendTextMessage(message["text"])
|
||||
|
||||
def handle_exception(self, exception):
|
||||
"""
|
||||
Called by the server when our application tracebacks
|
||||
"""
|
||||
if hasattr(self, "handshake_deferred"):
|
||||
# If the handshake is still ongoing, we need to emit a HTTP error
|
||||
# code rather than a WebSocket one.
|
||||
self.handshake_deferred.errback(
|
||||
ConnectionDeny(code=500, reason="Internal server error")
|
||||
)
|
||||
else:
|
||||
self.sendCloseFrame(code=1011)
|
||||
|
||||
def serverAccept(self, subprotocol=None):
|
||||
"""
|
||||
Called when we get a message saying to accept the connection.
|
||||
"""
|
||||
self.handshake_deferred.callback(subprotocol)
|
||||
del self.handshake_deferred
|
||||
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
||||
|
||||
def serverReject(self):
|
||||
"""
|
||||
Called when we get a message saying to reject the connection.
|
||||
"""
|
||||
self.handshake_deferred.errback(
|
||||
ConnectionDeny(code=403, reason="Access denied")
|
||||
)
|
||||
del self.handshake_deferred
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"rejected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"client": (
|
||||
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def serverSend(self, content, binary=False):
|
||||
"""
|
||||
Server-side channel message to send a message.
|
||||
"""
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverAccept()
|
||||
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
|
||||
if binary:
|
||||
self.sendMessage(content, binary)
|
||||
else:
|
||||
self.sendMessage(content.encode("utf8"), binary)
|
||||
|
||||
def serverClose(self, code=None):
|
||||
"""
|
||||
Server-side channel message to close the socket
|
||||
"""
|
||||
code = 1000 if code is None else code
|
||||
self.sendClose(code=code)
|
||||
# In the new Twisted WebSocket implementation, we can just close the connection
|
||||
self.transport.loseConnection(code=1011) # Internal server error
|
||||
|
||||
### Utils
|
||||
|
||||
|
|
@ -284,15 +266,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.duration() > self.server.websocket_timeout
|
||||
and self.server.websocket_timeout >= 0
|
||||
):
|
||||
self.serverClose()
|
||||
self.transport.loseConnection(code=1000)
|
||||
|
||||
# Ping check
|
||||
# If we're still connecting, deny the connection
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
if self.duration() > self.server.websocket_connect_timeout:
|
||||
self.serverReject()
|
||||
elif self.state == self.STATE_OPEN:
|
||||
if hasattr(self, "transport") and self.transport:
|
||||
if (time.time() - self.last_ping) > self.server.ping_interval:
|
||||
self._sendAutoPing()
|
||||
self.transport.ping()
|
||||
self.last_ping = time.time()
|
||||
|
||||
def __hash__(self):
|
||||
|
|
@ -305,26 +284,20 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
return f"<WebSocketProtocol client={self.client_addr!r} path={self.path!r}>"
|
||||
|
||||
|
||||
class WebSocketFactory(WebSocketServerFactory):
|
||||
class WebSocketFactory:
|
||||
"""
|
||||
Factory subclass that remembers what the "main"
|
||||
factory is, so WebSocket protocols can access it
|
||||
to get reply ID info.
|
||||
Factory for WebSocket protocols.
|
||||
"""
|
||||
|
||||
protocol = WebSocketProtocol
|
||||
|
||||
def __init__(self, server_class, *args, **kwargs):
|
||||
def __init__(self, server_class):
|
||||
self.server_class = server_class
|
||||
WebSocketServerFactory.__init__(self, *args, **kwargs)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, request):
|
||||
"""
|
||||
Builds protocol instances. We use this to inject the factory object into the protocol.
|
||||
Builds a new WebSocket protocol.
|
||||
"""
|
||||
try:
|
||||
protocol = super().buildProtocol(addr)
|
||||
protocol.factory = self
|
||||
protocol = WebSocketProtocol(self, request)
|
||||
return protocol
|
||||
except Exception:
|
||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ classifiers = [
|
|||
"Topic :: Internet :: WWW/HTTP",
|
||||
]
|
||||
|
||||
dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"]
|
||||
dependencies = ["asgiref>=3.5.2,<4", "twisted[tls,websocket]>=22.4"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user