mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-25 02:53:54 +03:00
Improve enum compatibility (#1153)
* Improve enum compatibility by supporting return enum as well as values and names * Handle invalid enum values * Rough implementation of compat middleware * Move enum middleware into compat module * Fix tests * Tweak enum examples * Add some tests for the middleware * Clean up tests * Add missing imports * Remove enum compat middleware * Use custom dedent function and pin graphql-core to >3.1.2
This commit is contained in:
parent
d042d5e95a
commit
81fff0f1b5
|
@ -61,7 +61,8 @@ you can add description etc. to your enum without changing the original:
|
|||
|
||||
graphene.Enum.from_enum(
|
||||
AlreadyExistingPyEnum,
|
||||
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar')
|
||||
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar'
|
||||
)
|
||||
|
||||
|
||||
Notes
|
||||
|
@ -76,6 +77,7 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu
|
|||
.. code:: python
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
@ -89,6 +91,7 @@ However, in Graphene ``Enum`` you need to call get to have the same effect:
|
|||
.. code:: python
|
||||
|
||||
from graphene import Enum
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import re
|
||||
from graphql_relay import to_global_id
|
||||
|
||||
from graphql.pyutils import dedent
|
||||
from graphene.tests.utils import dedent
|
||||
|
||||
from ...types import ObjectType, Schema, String
|
||||
from ..node import Node, is_node
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from graphql import graphql_sync
|
||||
from graphql.pyutils import dedent
|
||||
|
||||
from graphene.tests.utils import dedent
|
||||
|
||||
from ...types import Interface, ObjectType, Schema
|
||||
from ...types.scalars import Int, String
|
||||
|
|
9
graphene/tests/utils.py
Normal file
9
graphene/tests/utils.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from textwrap import dedent as _dedent
|
||||
|
||||
|
||||
def dedent(text: str) -> str:
|
||||
"""Fix indentation of given text by removing leading spaces and tabs.
|
||||
Also removes leading newlines and trailing spaces and tabs, but keeps trailing
|
||||
newlines.
|
||||
"""
|
||||
return _dedent(text.lstrip("\n").rstrip(" \t"))
|
|
@ -1,3 +1,5 @@
|
|||
from enum import Enum as PyEnum
|
||||
|
||||
from graphql import (
|
||||
GraphQLEnumType,
|
||||
GraphQLInputObjectType,
|
||||
|
@ -5,6 +7,7 @@ from graphql import (
|
|||
GraphQLObjectType,
|
||||
GraphQLScalarType,
|
||||
GraphQLUnionType,
|
||||
Undefined,
|
||||
)
|
||||
|
||||
|
||||
|
@ -36,7 +39,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):
|
|||
|
||||
|
||||
class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
|
||||
pass
|
||||
def serialize(self, value):
|
||||
if not isinstance(value, PyEnum):
|
||||
enum = self.graphene_type._meta.enum
|
||||
try:
|
||||
# Try and get enum by value
|
||||
value = enum(value)
|
||||
except ValueError:
|
||||
# Try and get enum by name
|
||||
try:
|
||||
value = enum[value]
|
||||
except KeyError:
|
||||
return Undefined
|
||||
return super(GrapheneEnumType, self).serialize(value)
|
||||
|
||||
|
||||
class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):
|
||||
|
|
|
@ -172,7 +172,7 @@ class TypeMap(dict):
|
|||
deprecation_reason = graphene_type._meta.deprecation_reason(value)
|
||||
|
||||
values[name] = GraphQLEnumValue(
|
||||
value=value.value,
|
||||
value=value,
|
||||
description=description,
|
||||
deprecation_reason=deprecation_reason,
|
||||
)
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
from textwrap import dedent
|
||||
|
||||
from ..argument import Argument
|
||||
from ..enum import Enum, PyEnum
|
||||
from ..field import Field
|
||||
from ..inputfield import InputField
|
||||
from ..inputobjecttype import InputObjectType
|
||||
from ..mutation import Mutation
|
||||
from ..scalars import String
|
||||
from ..schema import ObjectType, Schema
|
||||
|
||||
|
||||
|
@ -224,3 +229,245 @@ def test_enum_skip_meta_from_members():
|
|||
"GREEN": RGB1.GREEN,
|
||||
"BLUE": RGB1.BLUE,
|
||||
}
|
||||
|
||||
|
||||
def test_enum_types():
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
class Color(PyEnum):
|
||||
"""Primary colors"""
|
||||
|
||||
RED = 1
|
||||
YELLOW = 2
|
||||
BLUE = 3
|
||||
|
||||
GColor = Enum.from_enum(Color)
|
||||
|
||||
class Query(ObjectType):
|
||||
color = GColor(required=True)
|
||||
|
||||
def resolve_color(_, info):
|
||||
return Color.RED
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
assert str(schema) == dedent(
|
||||
'''\
|
||||
type Query {
|
||||
color: Color!
|
||||
}
|
||||
|
||||
"""Primary colors"""
|
||||
enum Color {
|
||||
RED
|
||||
YELLOW
|
||||
BLUE
|
||||
}
|
||||
'''
|
||||
)
|
||||
|
||||
|
||||
def test_enum_resolver():
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
class Color(PyEnum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
GColor = Enum.from_enum(Color)
|
||||
|
||||
class Query(ObjectType):
|
||||
color = GColor(required=True)
|
||||
|
||||
def resolve_color(_, info):
|
||||
return Color.RED
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
results = schema.execute("query { color }")
|
||||
assert not results.errors
|
||||
|
||||
assert results.data["color"] == Color.RED.name
|
||||
|
||||
|
||||
def test_enum_resolver_compat():
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
class Color(PyEnum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
GColor = Enum.from_enum(Color)
|
||||
|
||||
class Query(ObjectType):
|
||||
color = GColor(required=True)
|
||||
color_by_name = GColor(required=True)
|
||||
|
||||
def resolve_color(_, info):
|
||||
return Color.RED.value
|
||||
|
||||
def resolve_color_by_name(_, info):
|
||||
return Color.RED.name
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
results = schema.execute(
|
||||
"""query {
|
||||
color
|
||||
colorByName
|
||||
}"""
|
||||
)
|
||||
assert not results.errors
|
||||
|
||||
assert results.data["color"] == Color.RED.name
|
||||
assert results.data["colorByName"] == Color.RED.name
|
||||
|
||||
|
||||
def test_enum_resolver_invalid():
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
class Color(PyEnum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
GColor = Enum.from_enum(Color)
|
||||
|
||||
class Query(ObjectType):
|
||||
color = GColor(required=True)
|
||||
|
||||
def resolve_color(_, info):
|
||||
return "BLACK"
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
results = schema.execute("query { color }")
|
||||
assert results.errors
|
||||
assert (
|
||||
results.errors[0].message
|
||||
== "Expected a value of type 'Color' but received: 'BLACK'"
|
||||
)
|
||||
|
||||
|
||||
def test_field_enum_argument():
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
class Brick(ObjectType):
|
||||
color = Color(required=True)
|
||||
|
||||
color_filter = None
|
||||
|
||||
class Query(ObjectType):
|
||||
bricks_by_color = Field(Brick, color=Color(required=True))
|
||||
|
||||
def resolve_bricks_by_color(_, info, color):
|
||||
nonlocal color_filter
|
||||
color_filter = color
|
||||
return Brick(color=color)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
results = schema.execute(
|
||||
"""
|
||||
query {
|
||||
bricksByColor(color: RED) {
|
||||
color
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert not results.errors
|
||||
assert results.data == {"bricksByColor": {"color": "RED"}}
|
||||
assert color_filter == Color.RED
|
||||
|
||||
|
||||
def test_mutation_enum_input():
|
||||
class RGB(Enum):
|
||||
"""Available colors"""
|
||||
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
color_input = None
|
||||
|
||||
class CreatePaint(Mutation):
|
||||
class Arguments:
|
||||
color = RGB(required=True)
|
||||
|
||||
color = RGB(required=True)
|
||||
|
||||
def mutate(_, info, color):
|
||||
nonlocal color_input
|
||||
color_input = color
|
||||
return CreatePaint(color=color)
|
||||
|
||||
class MyMutation(ObjectType):
|
||||
create_paint = CreatePaint.Field()
|
||||
|
||||
class Query(ObjectType):
|
||||
a = String()
|
||||
|
||||
schema = Schema(query=Query, mutation=MyMutation)
|
||||
result = schema.execute(
|
||||
""" mutation MyMutation {
|
||||
createPaint(color: RED) {
|
||||
color
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == {"createPaint": {"color": "RED"}}
|
||||
|
||||
assert color_input == RGB.RED
|
||||
|
||||
|
||||
def test_mutation_enum_input_type():
|
||||
class RGB(Enum):
|
||||
"""Available colors"""
|
||||
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
class ColorInput(InputObjectType):
|
||||
color = RGB(required=True)
|
||||
|
||||
color_input_value = None
|
||||
|
||||
class CreatePaint(Mutation):
|
||||
class Arguments:
|
||||
color_input = ColorInput(required=True)
|
||||
|
||||
color = RGB(required=True)
|
||||
|
||||
def mutate(_, info, color_input):
|
||||
nonlocal color_input_value
|
||||
color_input_value = color_input.color
|
||||
return CreatePaint(color=color_input.color)
|
||||
|
||||
class MyMutation(ObjectType):
|
||||
create_paint = CreatePaint.Field()
|
||||
|
||||
class Query(ObjectType):
|
||||
a = String()
|
||||
|
||||
schema = Schema(query=Query, mutation=MyMutation)
|
||||
result = schema.execute(
|
||||
""" mutation MyMutation {
|
||||
createPaint(colorInput: { color: RED }) {
|
||||
color
|
||||
}
|
||||
}
|
||||
""",
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == {"createPaint": {"color": "RED"}}
|
||||
|
||||
assert color_input_value == RGB.RED
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from graphql.type import GraphQLObjectType, GraphQLSchema
|
||||
from pytest import raises
|
||||
|
||||
from graphql.type import GraphQLObjectType, GraphQLSchema
|
||||
from graphql.pyutils import dedent
|
||||
from graphene.tests.utils import dedent
|
||||
|
||||
from ..field import Field
|
||||
from ..objecttype import ObjectType
|
||||
|
|
|
@ -41,3 +41,10 @@ def get_type(_type):
|
|||
if inspect.isfunction(_type) or isinstance(_type, partial):
|
||||
return _type()
|
||||
return _type
|
||||
|
||||
|
||||
def get_underlying_type(_type):
|
||||
"""Get the underlying type even if it is wrapped in structures like NonNull"""
|
||||
while hasattr(_type, "of_type"):
|
||||
_type = _type.of_type
|
||||
return _type
|
||||
|
|
Loading…
Reference in New Issue
Block a user