diff --git a/graphene/__init__.py b/graphene/__init__.py index e75f0b08..c922aea2 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -18,6 +18,7 @@ from .types import ( Date, DateTime, Time, + Decimal, JSONString, UUID, List, @@ -65,6 +66,7 @@ __all__ = [ "Date", "DateTime", "Time", + "Decimal", "JSONString", "UUID", "List", diff --git a/graphene/types/__init__.py b/graphene/types/__init__.py index 570664db..96591605 100644 --- a/graphene/types/__init__.py +++ b/graphene/types/__init__.py @@ -6,6 +6,7 @@ from .interface import Interface from .mutation import Mutation from .scalars import Scalar, String, ID, Int, Float, Boolean from .datetime import Date, DateTime, Time +from .decimal import Decimal from .json import JSONString from .uuid import UUID from .schema import Schema @@ -40,6 +41,7 @@ __all__ = [ "Date", "DateTime", "Time", + "Decimal", "JSONString", "UUID", "Boolean", diff --git a/graphene/types/decimal.py b/graphene/types/decimal.py new file mode 100644 index 00000000..2f99134d --- /dev/null +++ b/graphene/types/decimal.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import + +from decimal import Decimal as _Decimal + +from graphql.language import ast + +from .scalars import Scalar + + +class Decimal(Scalar): + """ + The `Decimal` scalar type represents a python Decimal. + """ + + @staticmethod + def serialize(dec): + if isinstance(dec, str): + dec = _Decimal(dec) + assert isinstance(dec, _Decimal), 'Received not compatible Decimal "{}"'.format( + repr(dec) + ) + return str(dec) + + @classmethod + def parse_literal(cls, node): + if isinstance(node, ast.StringValue): + return cls.parse_value(node.value) + + @staticmethod + def parse_value(value): + try: + return _Decimal(value) + except ValueError: + return None diff --git a/graphene/types/tests/test_decimal.py b/graphene/types/tests/test_decimal.py new file mode 100644 index 00000000..abc4a6c4 --- /dev/null +++ b/graphene/types/tests/test_decimal.py @@ -0,0 +1,43 @@ +import decimal + +from ..decimal import Decimal +from ..objecttype import ObjectType +from ..schema import Schema + + +class Query(ObjectType): + decimal = Decimal(input=Decimal()) + + def resolve_decimal(self, info, input): + return input + + +schema = Schema(query=Query) + + +def test_decimal_string_query(): + decimal_value = decimal.Decimal("1969.1974") + result = schema.execute("""{ decimal(input: "%s") }""" % decimal_value) + assert not result.errors + assert result.data == {"decimal": str(decimal_value)} + assert decimal.Decimal(result.data["decimal"]) == decimal_value + + +def test_decimal_string_query_variable(): + decimal_value = decimal.Decimal("1969.1974") + + result = schema.execute( + """query Test($decimal: Decimal){ decimal(input: $decimal) }""", + variable_values={"decimal": decimal_value}, + ) + assert not result.errors + assert result.data == {"decimal": str(decimal_value)} + assert decimal.Decimal(result.data["decimal"]) == decimal_value + + +def test_bad_decimal_query(): + not_a_decimal = "Nobody expects the Spanish Inquisition!" + + result = schema.execute("""{ decimal(input: "%s") }""" % not_a_decimal) + assert len(result.errors) == 1 + assert result.data is None