diff --git a/graphene/tests/issues/test_881.py b/graphene/tests/issues/test_881.py index 960c5839..a08f79b6 100644 --- a/graphene/tests/issues/test_881.py +++ b/graphene/tests/issues/test_881.py @@ -5,7 +5,7 @@ from ...types.enum import Enum class MyEnum(Enum): # is defined outside of test because pickle unable to dump class inside ot pytest function - A = 'a' + A = "a" B = 1 diff --git a/graphene/types/enum.py b/graphene/types/enum.py index 79bafe00..b16a63b8 100644 --- a/graphene/types/enum.py +++ b/graphene/types/enum.py @@ -1,11 +1,40 @@ +import inspect +import sys from enum import Enum as PyEnum - +from typing import Any, Dict from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta from .base import BaseOptions, BaseType from .unmountedtype import UnmountedType +class _ModuleItemHelper: + _enum_metas: Dict[str, Any] = {} + + def __getattr__(self, name: str) -> Any: + try: + return globals()[name] + except KeyError: + if name == "__path__": + return None + if len(_ModuleItemHelper._enum_metas[name]) == 1: + return next(iter(_ModuleItemHelper._enum_metas[name].values())) + else: + # if there are more than 1 class with the name - take first by stack + for fr in inspect.stack(): + f_name = inspect.getmodulename(fr.filename) + if f_name in _ModuleItemHelper._enum_metas[name]: + return _ModuleItemHelper._enum_metas[name][f_name] + raise + + def __setattr__(self, name: str, value: Any) -> None: + cls, path = value + if not _ModuleItemHelper._enum_metas.get(name): + _ModuleItemHelper._enum_metas[name] = {} + + _ModuleItemHelper._enum_metas[name][path.split(".")[-1]] = cls + + def eq_enum(self, other): if isinstance(other, self.__class__): return self is other @@ -34,7 +63,11 @@ class EnumMeta(SubclassWithMeta_Meta): obj = SubclassWithMeta_Meta.__new__( cls, name_, bases, dict(classdict, __enum__=enum), **options ) - globals()[name_] = obj.__enum__ + # globals()[name_] = obj.__enum__ + if enum_members.get("__module__"): + setattr( + sys.modules[__name__], name_, (obj.__enum__, enum_members["__module__"]) + ) return obj def get(cls, value): @@ -118,3 +151,6 @@ class Enum(UnmountedType, BaseType, metaclass=EnumMeta): is mounted (as a Field, InputField or Argument) """ return cls + + +sys.modules[__name__] = _ModuleItemHelper() # type: ignore