diff --git a/graphene/tests/issues/test_881.py b/graphene/tests/issues/test_881.py index a08f79b6..f97b5917 100644 --- a/graphene/tests/issues/test_881.py +++ b/graphene/tests/issues/test_881.py @@ -3,14 +3,14 @@ import pickle from ...types.enum import Enum -class MyEnum(Enum): +class PickleEnum(Enum): # is defined outside of test because pickle unable to dump class inside ot pytest function A = "a" B = 1 def test_enums_pickling(): - a = MyEnum.A + a = PickleEnum.A pickled = pickle.dumps(a) restored = pickle.loads(pickled) assert type(a) is type(restored) @@ -18,7 +18,7 @@ def test_enums_pickling(): assert a.value == restored.value assert a.name == restored.name - b = MyEnum.B + b = PickleEnum.B pickled = pickle.dumps(b) restored = pickle.loads(pickled) assert type(a) is type(restored) diff --git a/graphene/types/enum.py b/graphene/types/enum.py index b16a63b8..79bafe00 100644 --- a/graphene/types/enum.py +++ b/graphene/types/enum.py @@ -1,40 +1,11 @@ -import inspect -import sys from enum import Enum as PyEnum -from typing import Any, Dict + from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta from .base import BaseOptions, BaseType from .unmountedtype import UnmountedType -class _ModuleItemHelper: - _enum_metas: Dict[str, Any] = {} - - def __getattr__(self, name: str) -> Any: - try: - return globals()[name] - except KeyError: - if name == "__path__": - return None - if len(_ModuleItemHelper._enum_metas[name]) == 1: - return next(iter(_ModuleItemHelper._enum_metas[name].values())) - else: - # if there are more than 1 class with the name - take first by stack - for fr in inspect.stack(): - f_name = inspect.getmodulename(fr.filename) - if f_name in _ModuleItemHelper._enum_metas[name]: - return _ModuleItemHelper._enum_metas[name][f_name] - raise - - def __setattr__(self, name: str, value: Any) -> None: - cls, path = value - if not _ModuleItemHelper._enum_metas.get(name): - _ModuleItemHelper._enum_metas[name] = {} - - _ModuleItemHelper._enum_metas[name][path.split(".")[-1]] = cls - - def eq_enum(self, other): if isinstance(other, self.__class__): return self is other @@ -63,11 +34,7 @@ class EnumMeta(SubclassWithMeta_Meta): obj = SubclassWithMeta_Meta.__new__( cls, name_, bases, dict(classdict, __enum__=enum), **options ) - # globals()[name_] = obj.__enum__ - if enum_members.get("__module__"): - setattr( - sys.modules[__name__], name_, (obj.__enum__, enum_members["__module__"]) - ) + globals()[name_] = obj.__enum__ return obj def get(cls, value): @@ -151,6 +118,3 @@ class Enum(UnmountedType, BaseType, metaclass=EnumMeta): is mounted (as a Field, InputField or Argument) """ return cls - - -sys.modules[__name__] = _ModuleItemHelper() # type: ignore