From eae1ff0df4603843d412feb9f0a27dff5f76c471 Mon Sep 17 00:00:00 2001 From: Marcin Muszynski Date: Mon, 14 Feb 2022 15:12:56 +0000 Subject: [PATCH] Set default attributes on WebRequest (#406) --- daphne/http_protocol.py | 5 ++-- tests/test_http_protocol.py | 49 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 tests/test_http_protocol.py diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 7df7bae..a289e93 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -50,6 +50,8 @@ class WebRequest(http.Request): ) # Shorten it a bit, bytes wise def __init__(self, *args, **kwargs): + self.client_addr = None + self.server_addr = None try: http.Request.__init__(self, *args, **kwargs) # Easy server link @@ -77,9 +79,6 @@ class WebRequest(http.Request): # requires unicode string. self.client_addr = [str(self.client.host), self.client.port] self.server_addr = [str(self.host.host), self.host.port] - else: - self.client_addr = None - self.server_addr = None self.client_scheme = "https" if self.isSecure() else "http" diff --git a/tests/test_http_protocol.py b/tests/test_http_protocol.py new file mode 100644 index 0000000..024479d --- /dev/null +++ b/tests/test_http_protocol.py @@ -0,0 +1,49 @@ +import unittest + +from daphne.http_protocol import WebRequest + + +class MockServer: + """ + Mock server object for testing. + """ + + def protocol_connected(self, *args, **kwargs): + pass + + +class MockFactory: + """ + Mock factory object for testing. + """ + + def __init__(self): + self.server = MockServer() + + +class MockChannel: + """ + Mock channel object for testing. + """ + + def __init__(self): + self.factory = MockFactory() + self.transport = None + + def getPeer(self, *args, **kwargs): + return "peer" + + def getHost(self, *args, **kwargs): + return "host" + + +class TestHTTPProtocol(unittest.TestCase): + """ + Tests the HTTP protocol classes. + """ + + def test_web_request_initialisation(self): + channel = MockChannel() + request = WebRequest(channel) + self.assertIsNone(request.client_addr) + self.assertIsNone(request.server_addr)