From 708b3c0179e63ee36631f868278669e1985b6aa4 Mon Sep 17 00:00:00 2001 From: Maik Hoepfel Date: Wed, 8 Mar 2017 11:28:11 +0800 Subject: [PATCH] Check order of header values I'm in the process of updating the ASGI spec to require that the order of header values is kept. To match that work, I'm adding matching assertions to the tests. The code unfortunately is not as elegant as I'd like, but then it's only a result of the underlying HTTP spec. --- daphne/tests/testcases.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/daphne/tests/testcases.py b/daphne/tests/testcases.py index 2ddd744..6e2e7a8 100644 --- a/daphne/tests/testcases.py +++ b/daphne/tests/testcases.py @@ -2,6 +2,9 @@ Contains a test case class to allow verifying ASGI messages """ from __future__ import unicode_literals + +from collections import defaultdict + import six import socket from six.moves.urllib import parse @@ -67,13 +70,22 @@ class ASGITestCase(unittest.TestCase): # Assert that query_string is a byte string and still url encoded self.assertIsInstance(query_string, six.binary_type) self.assertEqual(query_string, parse.urlencode(request_params or []).encode('ascii')) - # Current implementation doesn't keep ordering - headers = channel_message['headers'] - expected = { - (name.lower().strip().encode('ascii'), value.strip().encode('ascii')) - for name, value in (request_headers or []) - } - self.assertEqual(set(headers), expected) + + # Ordering of header names is not important, but the order of values for a header + # name is. To assert whether that order is kept, we transform both the request + # headers and the channel message headers into a dictionary + # {name: [value1, value2, ...]} and check if they're equal. + transformed_message_headers = defaultdict(list) + for name, value in channel_message['headers']: + transformed_message_headers[name].append(value) + + transformed_request_headers = defaultdict(list) + for name, value in (request_headers or []): + expected_name = name.lower().strip().encode('ascii') + expected_value = value.strip().encode('ascii') + transformed_request_headers[expected_name].append(expected_value) + + self.assertEqual(transformed_message_headers, transformed_request_headers) # == Assertions about optional channel_message fields ==