From 806c957d2a7ad7b7486bff2ac2a685f86b12b281 Mon Sep 17 00:00:00 2001 From: Jonathan Kim Date: Fri, 13 Mar 2020 20:03:40 +0000 Subject: [PATCH] Improve enum compatibility by supporting return enum as well as values and names --- graphene/types/definitions.py | 13 ++- graphene/types/schema.py | 2 +- graphene/types/tests/test_enum.py | 143 ++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 2 deletions(-) diff --git a/graphene/types/definitions.py b/graphene/types/definitions.py index 00916920..4dddaa90 100644 --- a/graphene/types/definitions.py +++ b/graphene/types/definitions.py @@ -1,3 +1,5 @@ +from enum import Enum as PyEnum + from graphql import ( GraphQLEnumType, GraphQLInputObjectType, @@ -36,7 +38,16 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType): class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType): - pass + def serialize(self, value): + if not isinstance(value, PyEnum): + enum = self.graphene_type._meta.enum + try: + # Try and get enum by value + value = enum(value) + except ValueError: + # Try ang get enum by name + value = enum[value] + return super(GrapheneEnumType, self).serialize(value) class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType): diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 29ead4a7..ce0c7439 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -172,7 +172,7 @@ class TypeMap(dict): deprecation_reason = graphene_type._meta.deprecation_reason(value) values[name] = GraphQLEnumValue( - value=value.value, + value=value, description=description, deprecation_reason=deprecation_reason, ) diff --git a/graphene/types/tests/test_enum.py b/graphene/types/tests/test_enum.py index 1b618120..7f9fcc4a 100644 --- a/graphene/types/tests/test_enum.py +++ b/graphene/types/tests/test_enum.py @@ -1,8 +1,11 @@ +from textwrap import dedent + from ..argument import Argument from ..enum import Enum, PyEnum from ..field import Field from ..inputfield import InputField from ..schema import ObjectType, Schema +from ..mutation import Mutation def test_enum_construction(): @@ -224,3 +227,143 @@ def test_enum_skip_meta_from_members(): "GREEN": RGB1.GREEN, "BLUE": RGB1.BLUE, } + + +def test_enum_types(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + + def resolve_color(_, info): + return Color.RED.value + + schema = Schema(query=Query) + + assert str(schema) == dedent( + '''\ + """An enumeration.""" + enum Color { + RED + GREEN + BLUE + } + + type Query { + color: Color! + } + ''' + ) + + +def test_enum_resolver(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + + def resolve_color(_, info): + return Color.RED + + schema = Schema(query=Query) + + results = schema.execute("query { color }") + assert not results.errors + + assert results.data["color"] == Color.RED.name + + +def test_enum_resolver_compat(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + color_by_name = GColor(required=True) + + def resolve_color(_, info): + return Color.RED.value + + def resolve_color_by_name(_, info): + return Color.RED.name + + schema = Schema(query=Query) + + results = schema.execute( + """query { + color + colorByName + }""" + ) + assert not results.errors + + assert results.data["color"] == Color.RED.name + assert results.data["colorByName"] == Color.RED.name + + +def test_enum_mutation(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + my_fav_color = None + + class Query(ObjectType): + fav_color = GColor(required=True) + + def resolve_fav_color(_, info): + return my_fav_color + + class SetFavColor(Mutation): + class Arguments: + fav_color = Argument(GColor, required=True) + + Output = Query + + def mutate(self, info, fav_color): + nonlocal my_fav_color + my_fav_color = fav_color + return Query() + + class MyMutations(ObjectType): + set_fav_color = SetFavColor.Field() + + schema = Schema(query=Query, mutation=MyMutations) + + results = schema.execute( + """mutation { + setFavColor(favColor: RED) { + favColor + } + }""" + ) + assert not results.errors + + assert my_fav_color == Color.RED + + assert results.data["setFavColor"]["favColor"] == Color.RED.name