considered case with colliding class names (try to distinguish by file name)

This commit is contained in:
sgrekov 2023-02-14 09:01:38 +02:00
parent f5b5a14f3f
commit dac4ba4e12
2 changed files with 39 additions and 3 deletions

View File

@ -5,7 +5,7 @@ from ...types.enum import Enum
class MyEnum(Enum): class MyEnum(Enum):
# is defined outside of test because pickle unable to dump class inside ot pytest function # is defined outside of test because pickle unable to dump class inside ot pytest function
A = 'a' A = "a"
B = 1 B = 1

View File

@ -1,11 +1,40 @@
import inspect
import sys
from enum import Enum as PyEnum from enum import Enum as PyEnum
from typing import Any, Dict
from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
class _ModuleItemHelper:
_enum_metas: Dict[str, Any] = {}
def __getattr__(self, name: str) -> Any:
try:
return globals()[name]
except KeyError:
if name == "__path__":
return None
if len(_ModuleItemHelper._enum_metas[name]) == 1:
return next(iter(_ModuleItemHelper._enum_metas[name].values()))
else:
# if there are more than 1 class with the name - take first by stack
for fr in inspect.stack():
f_name = inspect.getmodulename(fr.filename)
if f_name in _ModuleItemHelper._enum_metas[name]:
return _ModuleItemHelper._enum_metas[name][f_name]
raise
def __setattr__(self, name: str, value: Any) -> None:
cls, path = value
if not _ModuleItemHelper._enum_metas.get(name):
_ModuleItemHelper._enum_metas[name] = {}
_ModuleItemHelper._enum_metas[name][path.split(".")[-1]] = cls
def eq_enum(self, other): def eq_enum(self, other):
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return self is other return self is other
@ -34,7 +63,11 @@ class EnumMeta(SubclassWithMeta_Meta):
obj = SubclassWithMeta_Meta.__new__( obj = SubclassWithMeta_Meta.__new__(
cls, name_, bases, dict(classdict, __enum__=enum), **options cls, name_, bases, dict(classdict, __enum__=enum), **options
) )
globals()[name_] = obj.__enum__ # globals()[name_] = obj.__enum__
if enum_members.get("__module__"):
setattr(
sys.modules[__name__], name_, (obj.__enum__, enum_members["__module__"])
)
return obj return obj
def get(cls, value): def get(cls, value):
@ -118,3 +151,6 @@ class Enum(UnmountedType, BaseType, metaclass=EnumMeta):
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
""" """
return cls return cls
sys.modules[__name__] = _ModuleItemHelper() # type: ignore