From d8201c44fa36c87153f3cb8239cd902ed9eaa0d0 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 13 Jun 2016 19:14:49 -0700 Subject: [PATCH] Added better enum --- graphene/types/enum.py | 85 ++++++++++++++++--------------- graphene/types/tests/test_enum.py | 2 +- 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/graphene/types/enum.py b/graphene/types/enum.py index 9eff8c1a..a60d26ff 100644 --- a/graphene/types/enum.py +++ b/graphene/types/enum.py @@ -1,7 +1,10 @@ import six from graphql.type import GraphQLEnumType, GraphQLEnumValue +from collections import OrderedDict -from .definitions import ClassTypeMeta, GrapheneGraphQLType +from .definitions import GrapheneGraphQLType +from .options import Options +from ..utils.is_base_type import is_base_type try: from enum import Enum as PyEnum except ImportError: @@ -11,61 +14,59 @@ from .unmountedtype import UnmountedType class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType): - - def __init__(self, *args, **kwargs): - graphene_type = kwargs.pop('graphene_type') - self.graphene_type = graphene_type - self.name = None - self.description = None - self._values = None - self._value_lookup = None - self._name_lookup = None - - def get_values(self): - # list of values GraphQLEnumValue - enum = self.graphene_type._meta.enum - values = [] - for name, value in enum.__members__.items(): - values.append(GraphQLEnumValue(name=name, value=value.value)) - return values + pass -class EnumTypeMeta(ClassTypeMeta): +def values_from_enum(enum): + _values = OrderedDict() + for name, value in enum.__members__.items(): + _values[name] = GraphQLEnumValue(name=name, value=value.value) + return _values - def get_options(cls, meta): - return cls.options_class( - meta, + +class EnumTypeMeta(type): + + def __new__(cls, name, bases, attrs): + super_new = super(EnumTypeMeta, cls).__new__ + + # Also ensure initialization is only performed for subclasses of Model + # (excluding Model class itself). + if not is_base_type(bases, EnumTypeMeta): + return super_new(cls, name, bases, attrs) + + options = Options( + attrs.pop('Meta', None), name=None, description=None, enum=None, - graphql_type=None, - abstract=False + graphql_type=None ) - def construct(cls, bases, attrs): - if not cls._meta.graphql_type and not cls._meta.abstract: - cls._meta.graphql_type = GrapheneEnumType( + if not options.enum: + options.enum = type(cls.__name__, (PyEnum,), attrs) + + cls = super_new(cls, name, bases, dict(attrs, _meta=options, **options.enum.__members__)) + + if not options.graphql_type: + values = values_from_enum(options.enum) + options.graphql_type = GrapheneEnumType( graphene_type=cls, - name=cls._meta.name or cls.__name__, - description=cls._meta.description or cls.__doc__, + values=values, + name=options.name or cls.__name__, + description=options.description or cls.__doc__, ) - if not cls._meta.enum: - cls._meta.enum = type(cls.__name__, (PyEnum,), attrs) - - return super(EnumTypeMeta, cls).construct(bases, dict(attrs, **cls._meta.enum.__members__)) + return cls def __call__(cls, *args, **kwargs): - if cls._meta.abstract: - return cls.create(PyEnum(*args, **kwargs)) + if cls is Enum: + return cls.from_enum(PyEnum(*args, **kwargs)) return super(EnumTypeMeta, cls).__call__(*args, **kwargs) - def create(cls, python_enum): - class Meta: - enum = python_enum - return type(Meta.enum.__name__, (Enum,), {'Meta': Meta}) - class Enum(six.with_metaclass(EnumTypeMeta, UnmountedType)): - class Meta: - abstract = True + @classmethod + def from_enum(cls, python_enum): + class Meta: + enum = python_enum + return type(Meta.enum.__name__, (Enum,), {'Meta': Meta}) diff --git a/graphene/types/tests/test_enum.py b/graphene/types/tests/test_enum.py index 51b3b76c..8f07ccab 100644 --- a/graphene/types/tests/test_enum.py +++ b/graphene/types/tests/test_enum.py @@ -40,7 +40,7 @@ def test_enum_instance_construction(): def test_enum_from_builtin_enum(): PyRGB = PyEnum('RGB', 'RED,GREEN,BLUE') - RGB = Enum.create(PyRGB) + RGB = Enum.from_enum(PyRGB) assert isinstance(RGB._meta.graphql_type, GraphQLEnumType) values = RGB._meta.graphql_type.get_values() assert values == [