reverted simple solution back (without attempt to support duplicate Enum class names)

This commit is contained in:
sgrekov 2023-02-14 21:43:38 +02:00
parent dac4ba4e12
commit fa20db0054
2 changed files with 5 additions and 41 deletions

View File

@ -3,14 +3,14 @@ import pickle
from ...types.enum import Enum
class MyEnum(Enum):
class PickleEnum(Enum):
# is defined outside of test because pickle unable to dump class inside ot pytest function
A = "a"
B = 1
def test_enums_pickling():
a = MyEnum.A
a = PickleEnum.A
pickled = pickle.dumps(a)
restored = pickle.loads(pickled)
assert type(a) is type(restored)
@ -18,7 +18,7 @@ def test_enums_pickling():
assert a.value == restored.value
assert a.name == restored.name
b = MyEnum.B
b = PickleEnum.B
pickled = pickle.dumps(b)
restored = pickle.loads(pickled)
assert type(a) is type(restored)

View File

@ -1,40 +1,11 @@
import inspect
import sys
from enum import Enum as PyEnum
from typing import Any, Dict
from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta
from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
class _ModuleItemHelper:
_enum_metas: Dict[str, Any] = {}
def __getattr__(self, name: str) -> Any:
try:
return globals()[name]
except KeyError:
if name == "__path__":
return None
if len(_ModuleItemHelper._enum_metas[name]) == 1:
return next(iter(_ModuleItemHelper._enum_metas[name].values()))
else:
# if there are more than 1 class with the name - take first by stack
for fr in inspect.stack():
f_name = inspect.getmodulename(fr.filename)
if f_name in _ModuleItemHelper._enum_metas[name]:
return _ModuleItemHelper._enum_metas[name][f_name]
raise
def __setattr__(self, name: str, value: Any) -> None:
cls, path = value
if not _ModuleItemHelper._enum_metas.get(name):
_ModuleItemHelper._enum_metas[name] = {}
_ModuleItemHelper._enum_metas[name][path.split(".")[-1]] = cls
def eq_enum(self, other):
if isinstance(other, self.__class__):
return self is other
@ -63,11 +34,7 @@ class EnumMeta(SubclassWithMeta_Meta):
obj = SubclassWithMeta_Meta.__new__(
cls, name_, bases, dict(classdict, __enum__=enum), **options
)
# globals()[name_] = obj.__enum__
if enum_members.get("__module__"):
setattr(
sys.modules[__name__], name_, (obj.__enum__, enum_members["__module__"])
)
globals()[name_] = obj.__enum__
return obj
def get(cls, value):
@ -151,6 +118,3 @@ class Enum(UnmountedType, BaseType, metaclass=EnumMeta):
is mounted (as a Field, InputField or Argument)
"""
return cls
sys.modules[__name__] = _ModuleItemHelper() # type: ignore