fix: add default param _variables to parse_literal #1419

This is to match the `graphql-core` API. If it's not respected
the `parse_literal` method will produce an error event though
dealing with a valid value.
This commit is contained in:
René Birrer 2022-05-03 13:51:14 +02:00
parent 03277a5512
commit 181d9f76da
7 changed files with 64 additions and 11 deletions

View File

@ -0,0 +1,53 @@
import pytest
from ...types.base64 import Base64
from ...types.datetime import Date, DateTime
from ...types.decimal import Decimal
from ...types.generic import GenericScalar
from ...types.json import JSONString
from ...types.objecttype import ObjectType
from ...types.scalars import ID, BigInt, Boolean, Float, Int, String
from ...types.schema import Schema
from ...types.uuid import UUID
@pytest.mark.parametrize(
"input_type,input_value",
[
(Date, '"2022-02-02"'),
(GenericScalar, '"foo"'),
(Int, "1"),
(BigInt, "12345678901234567890"),
(Float, "1.1"),
(String, '"foo"'),
(Boolean, "true"),
(ID, "1"),
(DateTime, '"2022-02-02T11:11:11"'),
(UUID, '"cbebbc62-758e-4f75-a890-bc73b5017d81"'),
(Decimal, "1.1"),
(JSONString, '{key:"foo",value:"bar"}'),
(Base64, '"Q2hlbG8gd29ycmxkCg=="'),
],
)
def test_parse_literal_with_variables(input_type, input_value):
# input_b needs to be evaluated as literal while the variable dict for
# input_a is passed along.
class Query(ObjectType):
generic = GenericScalar(input_a=GenericScalar(), input_b=input_type())
def resolve_generic(self, info, input_a=None, input_b=None):
return input
schema = Schema(query=Query)
query = f"""
query Test($a: GenericScalar){{
generic(inputA: $a, inputB: {input_value})
}}
"""
result = schema.execute(
query,
variables={"a": "bar"},
)
assert not result.errors

View File

@ -22,7 +22,7 @@ class Base64(Scalar):
return b64encode(value).decode("utf-8")
@classmethod
def parse_literal(cls, node):
def parse_literal(cls, node, _variables=None):
if not isinstance(node, StringValueNode):
raise GraphQLError(
f"Base64 cannot represent non-string value: {print_ast(node)}"

View File

@ -22,7 +22,7 @@ class Decimal(Scalar):
return str(dec)
@classmethod
def parse_literal(cls, node):
def parse_literal(cls, node, _variables=None):
if isinstance(node, (StringValueNode, IntValueNode)):
return cls.parse_value(node.value)

View File

@ -29,7 +29,7 @@ class GenericScalar(Scalar):
parse_value = identity
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, (StringValueNode, BooleanValueNode)):
return ast.value
elif isinstance(ast, IntValueNode):

View File

@ -20,7 +20,7 @@ class JSONString(Scalar):
return json.dumps(dt)
@staticmethod
def parse_literal(node):
def parse_literal(node, _variables=None):
if isinstance(node, StringValueNode):
return json.loads(node.value)

View File

@ -75,7 +75,7 @@ class Int(Scalar):
parse_value = coerce_int
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, IntValueNode):
num = int(ast.value)
if MIN_INT <= num <= MAX_INT:
@ -104,7 +104,7 @@ class BigInt(Scalar):
parse_value = coerce_int
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, IntValueNode):
return int(ast.value)
@ -128,7 +128,7 @@ class Float(Scalar):
parse_value = coerce_float
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, (FloatValueNode, IntValueNode)):
return float(ast.value)
@ -150,7 +150,7 @@ class String(Scalar):
parse_value = coerce_string
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, StringValueNode):
return ast.value
@ -164,7 +164,7 @@ class Boolean(Scalar):
parse_value = bool
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, BooleanValueNode):
return ast.value
@ -182,6 +182,6 @@ class ID(Scalar):
parse_value = str
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, (StringValueNode, IntValueNode)):
return ast.value

View File

@ -21,7 +21,7 @@ class UUID(Scalar):
return str(uuid)
@staticmethod
def parse_literal(node):
def parse_literal(node, _variables=None):
if isinstance(node, StringValueNode):
return _UUID(node.value)