Merge pull request #1420

fix: add default param _variables to parse_literal #1419
This commit is contained in:
Christoph Zwerschke 2022-05-06 21:57:57 +02:00 committed by GitHub
commit 4e8a1e6057
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 64 additions and 11 deletions

View File

@ -0,0 +1,53 @@
import pytest
from ...types.base64 import Base64
from ...types.datetime import Date, DateTime
from ...types.decimal import Decimal
from ...types.generic import GenericScalar
from ...types.json import JSONString
from ...types.objecttype import ObjectType
from ...types.scalars import ID, BigInt, Boolean, Float, Int, String
from ...types.schema import Schema
from ...types.uuid import UUID
@pytest.mark.parametrize(
"input_type,input_value",
[
(Date, '"2022-02-02"'),
(GenericScalar, '"foo"'),
(Int, "1"),
(BigInt, "12345678901234567890"),
(Float, "1.1"),
(String, '"foo"'),
(Boolean, "true"),
(ID, "1"),
(DateTime, '"2022-02-02T11:11:11"'),
(UUID, '"cbebbc62-758e-4f75-a890-bc73b5017d81"'),
(Decimal, "1.1"),
(JSONString, '{key:"foo",value:"bar"}'),
(Base64, '"Q2hlbG8gd29ycmxkCg=="'),
],
)
def test_parse_literal_with_variables(input_type, input_value):
# input_b needs to be evaluated as literal while the variable dict for
# input_a is passed along.
class Query(ObjectType):
generic = GenericScalar(input_a=GenericScalar(), input_b=input_type())
def resolve_generic(self, info, input_a=None, input_b=None):
return input
schema = Schema(query=Query)
query = f"""
query Test($a: GenericScalar){{
generic(inputA: $a, inputB: {input_value})
}}
"""
result = schema.execute(
query,
variables={"a": "bar"},
)
assert not result.errors

View File

@ -22,7 +22,7 @@ class Base64(Scalar):
return b64encode(value).decode("utf-8")
@classmethod
def parse_literal(cls, node):
def parse_literal(cls, node, _variables=None):
if not isinstance(node, StringValueNode):
raise GraphQLError(
f"Base64 cannot represent non-string value: {print_ast(node)}"

View File

@ -22,7 +22,7 @@ class Decimal(Scalar):
return str(dec)
@classmethod
def parse_literal(cls, node):
def parse_literal(cls, node, _variables=None):
if isinstance(node, (StringValueNode, IntValueNode)):
return cls.parse_value(node.value)

View File

@ -29,7 +29,7 @@ class GenericScalar(Scalar):
parse_value = identity
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, (StringValueNode, BooleanValueNode)):
return ast.value
elif isinstance(ast, IntValueNode):

View File

@ -20,7 +20,7 @@ class JSONString(Scalar):
return json.dumps(dt)
@staticmethod
def parse_literal(node):
def parse_literal(node, _variables=None):
if isinstance(node, StringValueNode):
return json.loads(node.value)

View File

@ -75,7 +75,7 @@ class Int(Scalar):
parse_value = coerce_int
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, IntValueNode):
num = int(ast.value)
if MIN_INT <= num <= MAX_INT:
@ -104,7 +104,7 @@ class BigInt(Scalar):
parse_value = coerce_int
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, IntValueNode):
return int(ast.value)
@ -128,7 +128,7 @@ class Float(Scalar):
parse_value = coerce_float
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, (FloatValueNode, IntValueNode)):
return float(ast.value)
@ -150,7 +150,7 @@ class String(Scalar):
parse_value = coerce_string
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, StringValueNode):
return ast.value
@ -164,7 +164,7 @@ class Boolean(Scalar):
parse_value = bool
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, BooleanValueNode):
return ast.value
@ -182,6 +182,6 @@ class ID(Scalar):
parse_value = str
@staticmethod
def parse_literal(ast):
def parse_literal(ast, _variables=None):
if isinstance(ast, (StringValueNode, IntValueNode)):
return ast.value

View File

@ -21,7 +21,7 @@ class UUID(Scalar):
return str(uuid)
@staticmethod
def parse_literal(node):
def parse_literal(node, _variables=None):
if isinstance(node, StringValueNode):
return _UUID(node.value)