Improve enum compatibility (#1153)

* Improve enum compatibility by supporting return enum as well as values and names

* Handle invalid enum values

* Rough implementation of compat middleware

* Move enum middleware into compat module

* Fix tests

* Tweak enum examples

* Add some tests for the middleware

* Clean up tests

* Add missing imports

* Remove enum compat middleware

* Use custom dedent function and pin graphql-core to >3.1.2
This commit is contained in:
Jonathan Kim 2020-07-13 23:40:57 +01:00 committed by GitHub
parent d042d5e95a
commit 81fff0f1b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 290 additions and 8 deletions

View File

@ -61,7 +61,8 @@ you can add description etc. to your enum without changing the original:
graphene.Enum.from_enum(
AlreadyExistingPyEnum,
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar')
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar'
)
Notes
@ -76,6 +77,7 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu
.. code:: python
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
@ -89,6 +91,7 @@ However, in Graphene ``Enum`` you need to call get to have the same effect:
.. code:: python
from graphene import Enum
class Color(Enum):
RED = 1
GREEN = 2

View File

@ -1,7 +1,7 @@
import re
from graphql_relay import to_global_id
from graphql.pyutils import dedent
from graphene.tests.utils import dedent
from ...types import ObjectType, Schema, String
from ..node import Node, is_node

View File

@ -1,5 +1,6 @@
from graphql import graphql_sync
from graphql.pyutils import dedent
from graphene.tests.utils import dedent
from ...types import Interface, ObjectType, Schema
from ...types.scalars import Int, String

9
graphene/tests/utils.py Normal file
View File

@ -0,0 +1,9 @@
from textwrap import dedent as _dedent
def dedent(text: str) -> str:
"""Fix indentation of given text by removing leading spaces and tabs.
Also removes leading newlines and trailing spaces and tabs, but keeps trailing
newlines.
"""
return _dedent(text.lstrip("\n").rstrip(" \t"))

View File

@ -1,3 +1,5 @@
from enum import Enum as PyEnum
from graphql import (
GraphQLEnumType,
GraphQLInputObjectType,
@ -5,6 +7,7 @@ from graphql import (
GraphQLObjectType,
GraphQLScalarType,
GraphQLUnionType,
Undefined,
)
@ -36,7 +39,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):
class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
pass
def serialize(self, value):
if not isinstance(value, PyEnum):
enum = self.graphene_type._meta.enum
try:
# Try and get enum by value
value = enum(value)
except ValueError:
# Try and get enum by name
try:
value = enum[value]
except KeyError:
return Undefined
return super(GrapheneEnumType, self).serialize(value)
class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):

View File

@ -172,7 +172,7 @@ class TypeMap(dict):
deprecation_reason = graphene_type._meta.deprecation_reason(value)
values[name] = GraphQLEnumValue(
value=value.value,
value=value,
description=description,
deprecation_reason=deprecation_reason,
)

View File

@ -1,7 +1,12 @@
from textwrap import dedent
from ..argument import Argument
from ..enum import Enum, PyEnum
from ..field import Field
from ..inputfield import InputField
from ..inputobjecttype import InputObjectType
from ..mutation import Mutation
from ..scalars import String
from ..schema import ObjectType, Schema
@ -224,3 +229,245 @@ def test_enum_skip_meta_from_members():
"GREEN": RGB1.GREEN,
"BLUE": RGB1.BLUE,
}
def test_enum_types():
from enum import Enum as PyEnum
class Color(PyEnum):
"""Primary colors"""
RED = 1
YELLOW = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
def resolve_color(_, info):
return Color.RED
schema = Schema(query=Query)
assert str(schema) == dedent(
'''\
type Query {
color: Color!
}
"""Primary colors"""
enum Color {
RED
YELLOW
BLUE
}
'''
)
def test_enum_resolver():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
def resolve_color(_, info):
return Color.RED
schema = Schema(query=Query)
results = schema.execute("query { color }")
assert not results.errors
assert results.data["color"] == Color.RED.name
def test_enum_resolver_compat():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
color_by_name = GColor(required=True)
def resolve_color(_, info):
return Color.RED.value
def resolve_color_by_name(_, info):
return Color.RED.name
schema = Schema(query=Query)
results = schema.execute(
"""query {
color
colorByName
}"""
)
assert not results.errors
assert results.data["color"] == Color.RED.name
assert results.data["colorByName"] == Color.RED.name
def test_enum_resolver_invalid():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
def resolve_color(_, info):
return "BLACK"
schema = Schema(query=Query)
results = schema.execute("query { color }")
assert results.errors
assert (
results.errors[0].message
== "Expected a value of type 'Color' but received: 'BLACK'"
)
def test_field_enum_argument():
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
class Brick(ObjectType):
color = Color(required=True)
color_filter = None
class Query(ObjectType):
bricks_by_color = Field(Brick, color=Color(required=True))
def resolve_bricks_by_color(_, info, color):
nonlocal color_filter
color_filter = color
return Brick(color=color)
schema = Schema(query=Query)
results = schema.execute(
"""
query {
bricksByColor(color: RED) {
color
}
}
"""
)
assert not results.errors
assert results.data == {"bricksByColor": {"color": "RED"}}
assert color_filter == Color.RED
def test_mutation_enum_input():
class RGB(Enum):
"""Available colors"""
RED = 1
GREEN = 2
BLUE = 3
color_input = None
class CreatePaint(Mutation):
class Arguments:
color = RGB(required=True)
color = RGB(required=True)
def mutate(_, info, color):
nonlocal color_input
color_input = color
return CreatePaint(color=color)
class MyMutation(ObjectType):
create_paint = CreatePaint.Field()
class Query(ObjectType):
a = String()
schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(
""" mutation MyMutation {
createPaint(color: RED) {
color
}
}
"""
)
assert not result.errors
assert result.data == {"createPaint": {"color": "RED"}}
assert color_input == RGB.RED
def test_mutation_enum_input_type():
class RGB(Enum):
"""Available colors"""
RED = 1
GREEN = 2
BLUE = 3
class ColorInput(InputObjectType):
color = RGB(required=True)
color_input_value = None
class CreatePaint(Mutation):
class Arguments:
color_input = ColorInput(required=True)
color = RGB(required=True)
def mutate(_, info, color_input):
nonlocal color_input_value
color_input_value = color_input.color
return CreatePaint(color=color_input.color)
class MyMutation(ObjectType):
create_paint = CreatePaint.Field()
class Query(ObjectType):
a = String()
schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(
""" mutation MyMutation {
createPaint(colorInput: { color: RED }) {
color
}
}
""",
)
assert not result.errors
assert result.data == {"createPaint": {"color": "RED"}}
assert color_input_value == RGB.RED

View File

@ -1,7 +1,7 @@
from graphql.type import GraphQLObjectType, GraphQLSchema
from pytest import raises
from graphql.type import GraphQLObjectType, GraphQLSchema
from graphql.pyutils import dedent
from graphene.tests.utils import dedent
from ..field import Field
from ..objecttype import ObjectType

View File

@ -41,3 +41,10 @@ def get_type(_type):
if inspect.isfunction(_type) or isinstance(_type, partial):
return _type()
return _type
def get_underlying_type(_type):
"""Get the underlying type even if it is wrapped in structures like NonNull"""
while hasattr(_type, "of_type"):
_type = _type.of_type
return _type

View File

@ -82,7 +82,7 @@ setup(
keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["examples*"]),
install_requires=[
"graphql-core>=3.1.1,<4",
"graphql-core>=3.1.2,<4",
"graphql-relay>=3.0,<4",
"aniso8601>=8,<9",
],