diff --git a/graphene_django/utils/testing.py b/graphene_django/utils/testing.py index 5b694b2..8a9b994 100644 --- a/graphene_django/utils/testing.py +++ b/graphene_django/utils/testing.py @@ -24,7 +24,7 @@ class GraphQLTestCase(TestCase): cls._client = Client() - def query(self, query, op_name=None, input_data=None, variables=None): + def query(self, query, op_name=None, input_data=None, variables=None, headers=None): """ Args: query (string) - GraphQL query to run @@ -36,7 +36,9 @@ class GraphQLTestCase(TestCase): are provided, the ``input`` field in the ``variables`` dict will be overwritten with this value. variables (dict) - If provided, the "variables" field in GraphQL will be - set to this value. + set to this value. + headers (dict) - If provided, the headers in POST request to GRAPHQL_URL + will be set to this value. Returns: Response object from client @@ -51,10 +53,17 @@ class GraphQLTestCase(TestCase): body["variables"]["input"] = input_data else: body["variables"] = {"input": input_data} - - resp = self._client.post( - self.GRAPHQL_URL, json.dumps(body), content_type="application/json" - ) + if headers: + resp = self._client.post( + self.GRAPHQL_URL, + json.dumps(body), + content_type="application/json", + **headers + ) + else: + resp = self._client.post( + self.GRAPHQL_URL, json.dumps(body), content_type="application/json" + ) return resp def assertResponseNoErrors(self, resp):