diff --git a/graphene/tests/issues/test_881.py b/graphene/tests/issues/test_881.py new file mode 100644 index 00000000..f97b5917 --- /dev/null +++ b/graphene/tests/issues/test_881.py @@ -0,0 +1,27 @@ +import pickle + +from ...types.enum import Enum + + +class PickleEnum(Enum): + # is defined outside of test because pickle unable to dump class inside ot pytest function + A = "a" + B = 1 + + +def test_enums_pickling(): + a = PickleEnum.A + pickled = pickle.dumps(a) + restored = pickle.loads(pickled) + assert type(a) is type(restored) + assert a == restored + assert a.value == restored.value + assert a.name == restored.name + + b = PickleEnum.B + pickled = pickle.dumps(b) + restored = pickle.loads(pickled) + assert type(a) is type(restored) + assert b == restored + assert b.value == restored.value + assert b.name == restored.name diff --git a/graphene/types/enum.py b/graphene/types/enum.py index 58e65c69..d3469a15 100644 --- a/graphene/types/enum.py +++ b/graphene/types/enum.py @@ -31,9 +31,11 @@ class EnumMeta(SubclassWithMeta_Meta): # with the enum values. enum_members.pop("Meta", None) enum = PyEnum(cls.__name__, enum_members) - return SubclassWithMeta_Meta.__new__( + obj = SubclassWithMeta_Meta.__new__( cls, name_, bases, dict(classdict, __enum__=enum), **options ) + globals()[name_] = obj.__enum__ + return obj def get(cls, value): return cls._meta.enum(value) @@ -63,7 +65,7 @@ class EnumMeta(SubclassWithMeta_Meta): cls, enum, name=None, description=None, deprecation_reason=None ): # noqa: N805 name = name or enum.__name__ - description = description or enum.__doc__ + description = description or enum.__doc__ or "An enumeration." meta_dict = { "enum": enum, "description": description, diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py index c0454aba..257f48be 100644 --- a/graphene/types/inputobjecttype.py +++ b/graphene/types/inputobjecttype.py @@ -15,6 +15,31 @@ class InputObjectTypeOptions(BaseOptions): container = None # type: InputObjectTypeContainer +# Currently in Graphene, we get a `None` whenever we access an (optional) field that was not set in an InputObjectType +# using the InputObjectType. dot access syntax. This is ambiguous, because in this current (Graphene +# historical) arrangement, we cannot distinguish between a field not being set and a field being set to None. +# At the same time, we shouldn't break existing code that expects a `None` when accessing a field that was not set. +_INPUT_OBJECT_TYPE_DEFAULT_VALUE = None + +# To mitigate this, we provide the function `set_input_object_type_default_value` to allow users to change the default +# value returned in non-specified fields in InputObjectType to another meaningful sentinel value (e.g. Undefined) +# if they want to. This way, we can keep code that expects a `None` working while we figure out a better solution (or +# a well-documented breaking change) for this issue. + + +def set_input_object_type_default_value(default_value): + """ + Change the sentinel value returned by non-specified fields in an InputObjectType + Useful to differentiate between a field not being set and a field being set to None by using a sentinel value + (e.g. Undefined is a good sentinel value for this purpose) + + This function should be called at the beginning of the app or in some other place where it is guaranteed to + be called before any InputObjectType is defined. + """ + global _INPUT_OBJECT_TYPE_DEFAULT_VALUE + _INPUT_OBJECT_TYPE_DEFAULT_VALUE = default_value + + class InputObjectTypeContainer(dict, BaseType): # type: ignore class Meta: abstract = True @@ -22,7 +47,7 @@ class InputObjectTypeContainer(dict, BaseType): # type: ignore def __init__(self, *args, **kwargs): dict.__init__(self, *args, **kwargs) for key in self._meta.fields: - setattr(self, key, self.get(key, None)) + setattr(self, key, self.get(key, _INPUT_OBJECT_TYPE_DEFAULT_VALUE)) def __init_subclass__(cls, *args, **kwargs): pass diff --git a/graphene/types/tests/conftest.py b/graphene/types/tests/conftest.py new file mode 100644 index 00000000..43f7d726 --- /dev/null +++ b/graphene/types/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest +from graphql import Undefined + +from graphene.types.inputobjecttype import set_input_object_type_default_value + + +@pytest.fixture() +def set_default_input_object_type_to_undefined(): + """This fixture is used to change the default value of optional inputs in InputObjectTypes for specific tests""" + set_input_object_type_default_value(Undefined) + yield + set_input_object_type_default_value(None) diff --git a/graphene/types/tests/test_enum.py b/graphene/types/tests/test_enum.py index 9b3082df..e6fce66c 100644 --- a/graphene/types/tests/test_enum.py +++ b/graphene/types/tests/test_enum.py @@ -65,6 +65,21 @@ def test_enum_from_builtin_enum(): assert RGB.BLUE +def test_enum_custom_description_in_constructor(): + description = "An enumeration, but with a custom description" + RGB = Enum( + "RGB", + "RED,GREEN,BLUE", + description=description, + ) + assert RGB._meta.description == description + + +def test_enum_from_python3_enum_uses_default_builtin_doc(): + RGB = Enum("RGB", "RED,GREEN,BLUE") + assert RGB._meta.description == "An enumeration." + + def test_enum_from_builtin_enum_accepts_lambda_description(): def custom_description(value): if not value: diff --git a/graphene/types/tests/test_inputobjecttype.py b/graphene/types/tests/test_inputobjecttype.py index 0fb7e394..0d7bcf80 100644 --- a/graphene/types/tests/test_inputobjecttype.py +++ b/graphene/types/tests/test_inputobjecttype.py @@ -1,3 +1,5 @@ +from graphql import Undefined + from ..argument import Argument from ..field import Field from ..inputfield import InputField @@ -6,6 +8,7 @@ from ..objecttype import ObjectType from ..scalars import Boolean, String from ..schema import Schema from ..unmountedtype import UnmountedType +from ... import NonNull class MyType: @@ -136,3 +139,31 @@ def test_inputobjecttype_of_input(): assert not result.errors assert result.data == {"isChild": True} + + +def test_inputobjecttype_default_input_as_undefined( + set_default_input_object_type_to_undefined, +): + class TestUndefinedInput(InputObjectType): + required_field = String(required=True) + optional_field = String() + + class Query(ObjectType): + undefined_optionals_work = Field(NonNull(Boolean), input=TestUndefinedInput()) + + def resolve_undefined_optionals_work(self, info, input: TestUndefinedInput): + # Confirm that optional_field comes as Undefined + return ( + input.required_field == "required" and input.optional_field is Undefined + ) + + schema = Schema(query=Query) + result = schema.execute( + """query basequery { + undefinedOptionalsWork(input: {requiredField: "required"}) + } + """ + ) + + assert not result.errors + assert result.data == {"undefinedOptionalsWork": True} diff --git a/graphene/types/tests/test_type_map.py b/graphene/types/tests/test_type_map.py index 55b1706e..55665b6b 100644 --- a/graphene/types/tests/test_type_map.py +++ b/graphene/types/tests/test_type_map.py @@ -20,8 +20,8 @@ from ..inputobjecttype import InputObjectType from ..interface import Interface from ..objecttype import ObjectType from ..scalars import Int, String -from ..structures import List, NonNull from ..schema import Schema +from ..structures import List, NonNull def create_type_map(types, auto_camelcase=True): @@ -227,6 +227,18 @@ def test_inputobject(): assert foo_field.description == "Field description" +def test_inputobject_undefined(set_default_input_object_type_to_undefined): + class OtherObjectType(InputObjectType): + optional_field = String() + + type_map = create_type_map([OtherObjectType]) + assert "OtherObjectType" in type_map + graphql_type = type_map["OtherObjectType"] + + container = graphql_type.out_type({}) + assert container.optional_field is Undefined + + def test_objecttype_camelcase(): class MyObjectType(ObjectType): """Description""" diff --git a/graphene/validation/depth_limit.py b/graphene/validation/depth_limit.py index b4599e66..e0f28663 100644 --- a/graphene/validation/depth_limit.py +++ b/graphene/validation/depth_limit.py @@ -30,7 +30,7 @@ try: except ImportError: # backwards compatibility for v3.6 from typing import Pattern -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union, Tuple from graphql import GraphQLError from graphql.validation import ValidationContext, ValidationRule @@ -82,7 +82,7 @@ def depth_limit_validator( def get_fragments( - definitions: List[DefinitionNode], + definitions: Tuple[DefinitionNode, ...], ) -> Dict[str, FragmentDefinitionNode]: fragments = {} for definition in definitions: @@ -94,7 +94,7 @@ def get_fragments( # This will actually get both queries and mutations. # We can basically treat those the same def get_queries_and_mutations( - definitions: List[DefinitionNode], + definitions: Tuple[DefinitionNode, ...], ) -> Dict[str, OperationDefinitionNode]: operations = {}