mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-02 20:54:16 +03:00
Merge pull request #140 from graphql-python/features/django-fields
Improved support for Django fields
This commit is contained in:
commit
332d7b0227
|
@ -24,6 +24,7 @@ install:
|
|||
- |
|
||||
if [ "$TEST_TYPE" = build ]; then
|
||||
pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter sqlalchemy_utils
|
||||
pip install --download-cache $HOME/.cache/pip psycopg2 > /dev/null 2>&1
|
||||
pip install --download-cache $HOME/.cache/pip/ -e .[django]
|
||||
pip install --download-cache $HOME/.cache/pip/ -e .[sqlalchemy]
|
||||
pip install django==$DJANGO_VERSION
|
||||
|
|
|
@ -16,6 +16,10 @@ Also the following Types are available:
|
|||
- `graphene.List`
|
||||
- `graphene.NonNull`
|
||||
|
||||
Graphene also provides custom scalars for Dates and JSON:
|
||||
- `graphene.core.types.custom_scalars.DateTime`
|
||||
- `graphene.core.types.custom_scalars.JSONString`
|
||||
|
||||
## Shortcuts
|
||||
|
||||
There are some shortcuts for building schemas more easily.
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from flask import Flask
|
||||
from database import db_session, init_db
|
||||
|
||||
from schema import schema
|
||||
from database import db_session, init_db
|
||||
from flask_graphql import GraphQL
|
||||
from schema import schema
|
||||
|
||||
app = Flask(__name__)
|
||||
app.debug = True
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True)
|
||||
db_session = scoped_session(sessionmaker(autocommit=False,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
|
||||
from database import Base
|
||||
from sqlalchemy import Column, DateTime, String, Integer, ForeignKey, func
|
||||
from sqlalchemy.orm import relationship, backref
|
||||
|
||||
|
||||
class Department(Base):
|
||||
|
|
|
@ -1,19 +1,23 @@
|
|||
import graphene
|
||||
from graphene import relay
|
||||
from graphene.contrib.sqlalchemy import SQLAlchemyNode, SQLAlchemyConnectionField
|
||||
from models import Department as DepartmentModel, Employee as EmployeeModel
|
||||
from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField,
|
||||
SQLAlchemyNode)
|
||||
from models import Department as DepartmentModel
|
||||
from models import Employee as EmployeeModel
|
||||
|
||||
schema = graphene.Schema()
|
||||
|
||||
|
||||
@schema.register
|
||||
class Department(SQLAlchemyNode):
|
||||
|
||||
class Meta:
|
||||
model = DepartmentModel
|
||||
|
||||
|
||||
@schema.register
|
||||
class Employee(SQLAlchemyNode):
|
||||
|
||||
class Meta:
|
||||
model = EmployeeModel
|
||||
|
||||
|
|
|
@ -1,15 +1,24 @@
|
|||
from django.db import models
|
||||
|
||||
|
||||
class MissingType(object):
|
||||
pass
|
||||
|
||||
try:
|
||||
UUIDField = models.UUIDField
|
||||
except AttributeError:
|
||||
# Improved compatibility for Django 1.6
|
||||
class UUIDField(object):
|
||||
pass
|
||||
UUIDField = MissingType
|
||||
|
||||
try:
|
||||
from django.db.models.related import RelatedObject
|
||||
except:
|
||||
# Improved compatibility for Django 1.6
|
||||
class RelatedObject(object):
|
||||
pass
|
||||
RelatedObject = MissingType
|
||||
|
||||
|
||||
try:
|
||||
# Postgres fields are only available in Django 1.8+
|
||||
from django.contrib.postgres.fields import ArrayField, HStoreField, JSONField, RangeField
|
||||
except ImportError:
|
||||
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from django.db import models
|
||||
|
||||
from ...core.types.scalars import ID, Boolean, Float, Int, String
|
||||
from ...core.classtypes.enum import Enum
|
||||
from ...core.types.custom_scalars import DateTime, JSONString
|
||||
from ...core.types.definitions import List
|
||||
from ...core.types.scalars import ID, Boolean, Float, Int, String
|
||||
from ...utils import to_const
|
||||
from .compat import RelatedObject, UUIDField
|
||||
from .compat import (ArrayField, HStoreField, JSONField, RangeField,
|
||||
RelatedObject, UUIDField)
|
||||
from .utils import get_related_model, import_single_dispatch
|
||||
|
||||
singledispatch = import_single_dispatch()
|
||||
|
@ -30,13 +33,13 @@ def convert_django_field(field):
|
|||
(field, field.__class__))
|
||||
|
||||
|
||||
@convert_django_field.register(models.DateField)
|
||||
@convert_django_field.register(models.CharField)
|
||||
@convert_django_field.register(models.TextField)
|
||||
@convert_django_field.register(models.EmailField)
|
||||
@convert_django_field.register(models.SlugField)
|
||||
@convert_django_field.register(models.URLField)
|
||||
@convert_django_field.register(models.GenericIPAddressField)
|
||||
@convert_django_field.register(models.FileField)
|
||||
@convert_django_field.register(UUIDField)
|
||||
def convert_field_to_string(field):
|
||||
return String(description=field.help_text)
|
||||
|
@ -72,6 +75,11 @@ def convert_field_to_float(field):
|
|||
return Float(description=field.help_text)
|
||||
|
||||
|
||||
@convert_django_field.register(models.DateField)
|
||||
def convert_date_to_string(field):
|
||||
return DateTime(description=field.help_text)
|
||||
|
||||
|
||||
@convert_django_field.register(models.ManyToManyField)
|
||||
@convert_django_field.register(models.ManyToOneRel)
|
||||
@convert_django_field.register(models.ManyToManyRel)
|
||||
|
@ -94,3 +102,21 @@ def convert_relatedfield_to_djangomodel(field):
|
|||
def convert_field_to_djangomodel(field):
|
||||
from .fields import DjangoModelField
|
||||
return DjangoModelField(get_related_model(field), description=field.help_text)
|
||||
|
||||
|
||||
@convert_django_field.register(ArrayField)
|
||||
def convert_postgres_array_to_list(field):
|
||||
base_type = convert_django_field(field.base_field)
|
||||
return List(base_type, description=field.help_text)
|
||||
|
||||
|
||||
@convert_django_field.register(HStoreField)
|
||||
@convert_django_field.register(JSONField)
|
||||
def convert_posgres_field_to_string(field):
|
||||
return JSONString(description=field.help_text)
|
||||
|
||||
|
||||
@convert_django_field.register(RangeField)
|
||||
def convert_posgres_range_to_string(field):
|
||||
inner_type = convert_django_field(field.base_field)
|
||||
return List(inner_type, description=field.help_text)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
|
||||
import graphene
|
||||
from graphene.contrib.django import DjangoNode, DjangoConnectionField
|
||||
from graphene.contrib.django import DjangoConnectionField, DjangoNode
|
||||
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
from ...tests.models import Reporter
|
||||
|
@ -19,6 +19,7 @@ def test_should_query_field():
|
|||
r2.save()
|
||||
|
||||
class ReporterType(DjangoNode):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@ import six
|
|||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.utils.text import capfirst
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from django_filters import Filter, MultipleChoiceFilter
|
||||
from django_filters.filterset import FilterSet, FilterSetMetaclass
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from graphene.contrib.django.forms import (GlobalIDFormField,
|
||||
GlobalIDMultipleChoiceField)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import django_filters
|
||||
|
||||
from graphene.contrib.django.tests.models import Article, Pet, Reporter
|
||||
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import pytest
|
||||
from django.db import models
|
||||
from py.test import raises
|
||||
|
||||
import graphene
|
||||
from graphene.contrib.django.converter import (
|
||||
convert_django_field, convert_django_field_with_choices)
|
||||
from graphene.contrib.django.fields import (ConnectionOrListField,
|
||||
DjangoModelField)
|
||||
from graphene.core.types.custom_scalars import DateTime, JSONString
|
||||
|
||||
from ..compat import (ArrayField, HStoreField, JSONField, MissingType,
|
||||
RangeField)
|
||||
from ..converter import convert_django_field, convert_django_field_with_choices
|
||||
from ..fields import ConnectionOrListField, DjangoModelField
|
||||
from .models import Article, Reporter
|
||||
|
||||
|
||||
|
@ -26,7 +28,7 @@ def test_should_unknown_django_field_raise_exception():
|
|||
|
||||
|
||||
def test_should_date_convert_string():
|
||||
assert_conversion(models.DateField, graphene.String)
|
||||
assert_conversion(models.DateField, DateTime)
|
||||
|
||||
|
||||
def test_should_char_convert_string():
|
||||
|
@ -53,6 +55,14 @@ def test_should_ipaddress_convert_string():
|
|||
assert_conversion(models.GenericIPAddressField, graphene.String)
|
||||
|
||||
|
||||
def test_should_file_convert_string():
|
||||
assert_conversion(models.FileField, graphene.String)
|
||||
|
||||
|
||||
def test_should_image_convert_string():
|
||||
assert_conversion(models.ImageField, graphene.String)
|
||||
|
||||
|
||||
def test_should_auto_convert_id():
|
||||
assert_conversion(models.AutoField, graphene.ID, primary_key=True)
|
||||
|
||||
|
@ -136,3 +146,40 @@ def test_should_onetoone_convert_model():
|
|||
def test_should_foreignkey_convert_model():
|
||||
field = assert_conversion(models.ForeignKey, DjangoModelField, Article)
|
||||
assert field.type.model == Article
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType,
|
||||
reason="ArrayField should exist")
|
||||
def test_should_postgres_array_convert_list():
|
||||
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100))
|
||||
assert isinstance(field.type, graphene.List)
|
||||
assert isinstance(field.type.of_type, graphene.String)
|
||||
|
||||
|
||||
@pytest.mark.skipif(ArrayField is MissingType,
|
||||
reason="ArrayField should exist")
|
||||
def test_should_postgres_array_multiple_convert_list():
|
||||
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)))
|
||||
assert isinstance(field.type, graphene.List)
|
||||
assert isinstance(field.type.of_type, graphene.List)
|
||||
assert isinstance(field.type.of_type.of_type, graphene.String)
|
||||
|
||||
|
||||
@pytest.mark.skipif(HStoreField is MissingType,
|
||||
reason="HStoreField should exist")
|
||||
def test_should_postgres_hstore_convert_string():
|
||||
assert_conversion(HStoreField, JSONString)
|
||||
|
||||
|
||||
@pytest.mark.skipif(JSONField is MissingType,
|
||||
reason="JSONField should exist")
|
||||
def test_should_postgres_json_convert_string():
|
||||
assert_conversion(JSONField, JSONString)
|
||||
|
||||
|
||||
@pytest.mark.skipif(RangeField is MissingType,
|
||||
reason="RangeField should exist")
|
||||
def test_should_postgres_range_convert_list():
|
||||
from django.contrib.postgres.fields import IntegerRangeField
|
||||
field = assert_conversion(IntegerRangeField, graphene.List)
|
||||
assert isinstance(field.type.of_type, graphene.Int)
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from django.db import models
|
||||
from py.test import raises
|
||||
|
||||
import graphene
|
||||
from graphene import relay
|
||||
from graphene.contrib.django import DjangoNode, DjangoObjectType
|
||||
|
||||
from ..compat import MissingType, RangeField
|
||||
from ..types import DjangoNode, DjangoObjectType
|
||||
from .models import Article, Reporter
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
@ -62,6 +66,57 @@ def test_should_query_well():
|
|||
assert result.data == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(RangeField is MissingType,
|
||||
reason="RangeField should exist")
|
||||
def test_should_query_postgres_fields():
|
||||
from django.contrib.postgres.fields import IntegerRangeField, ArrayField, JSONField, HStoreField
|
||||
|
||||
class Event(models.Model):
|
||||
ages = IntegerRangeField(help_text='The age ranges')
|
||||
data = JSONField(help_text='Data')
|
||||
store = HStoreField()
|
||||
tags = ArrayField(models.CharField(max_length=50))
|
||||
|
||||
class EventType(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Event
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
event = graphene.Field(EventType)
|
||||
|
||||
def resolve_event(self, *args, **kwargs):
|
||||
return Event(
|
||||
ages=(0, 10),
|
||||
data={'angry_babies': True},
|
||||
store={'h': 'store'},
|
||||
tags=['child', 'angry', 'babies']
|
||||
)
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
query = '''
|
||||
query myQuery {
|
||||
event {
|
||||
ages
|
||||
tags
|
||||
data
|
||||
store
|
||||
}
|
||||
}
|
||||
'''
|
||||
expected = {
|
||||
'event': {
|
||||
'ages': [0, 10],
|
||||
'tags': ['child', 'angry', 'babies'],
|
||||
'data': '{"angry_babies": true}',
|
||||
'store': '{"h": "store"}',
|
||||
},
|
||||
}
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_should_node():
|
||||
class ReporterNode(DjangoNode):
|
||||
|
||||
|
@ -82,7 +137,7 @@ def test_should_node():
|
|||
|
||||
@classmethod
|
||||
def get_node(cls, id, info):
|
||||
return ArticleNode(Article(id=1, headline='Article node'))
|
||||
return ArticleNode(Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)))
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
node = relay.NodeField()
|
||||
|
@ -115,6 +170,7 @@ def test_should_node():
|
|||
}
|
||||
... on ArticleNode {
|
||||
headline
|
||||
pubDate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -135,7 +191,8 @@ def test_should_node():
|
|||
},
|
||||
'myArticle': {
|
||||
'id': 'QXJ0aWNsZU5vZGU6MQ==',
|
||||
'headline': 'Article node'
|
||||
'headline': 'Article node',
|
||||
'pubDate': '2002-03-11',
|
||||
}
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
from singledispatch import singledispatch
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.orm import interfaces
|
||||
|
||||
from ...core.classtypes.enum import Enum
|
||||
from ...core.types.scalars import ID, Boolean, Float, Int, String
|
||||
from .fields import ConnectionOrListField, SQLAlchemyModelField
|
||||
|
||||
try:
|
||||
from sqlalchemy_utils.types.choice import ChoiceType
|
||||
except ImportError:
|
||||
class ChoiceType(object):
|
||||
pass
|
||||
|
||||
from ...core.classtypes.enum import Enum
|
||||
from ...core.types.scalars import ID, Boolean, Float, Int, String
|
||||
from .fields import ConnectionOrListField, SQLAlchemyModelField
|
||||
|
||||
|
||||
def convert_sqlalchemy_relationship(relationship):
|
||||
direction = relationship.direction
|
||||
|
|
|
@ -4,7 +4,7 @@ from ...core.types.base import FieldType
|
|||
from ...core.types.definitions import List
|
||||
from ...relay import ConnectionField
|
||||
from ...relay.utils import is_node
|
||||
from .utils import get_type_for_model, maybe_query, get_query
|
||||
from .utils import get_query, get_type_for_model, maybe_query
|
||||
|
||||
|
||||
class DefaultQuery(object):
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from py.test import raises
|
||||
from sqlalchemy import Column, Table, types
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy_utils.types.choice import ChoiceType
|
||||
|
||||
import graphene
|
||||
from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField,
|
||||
SQLAlchemyModelField)
|
||||
from sqlalchemy import Table, Column, types
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy_utils.types.choice import ChoiceType
|
||||
|
||||
from .models import Article, Pet, Reporter
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import pytest
|
||||
|
||||
import graphene
|
||||
from graphene import relay
|
||||
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode, SQLAlchemyConnectionField
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
from .models import Base, Reporter, Article
|
||||
import graphene
|
||||
from graphene import relay
|
||||
from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField,
|
||||
SQLAlchemyNode, SQLAlchemyObjectType)
|
||||
|
||||
from .models import Article, Base, Reporter
|
||||
|
||||
db = create_engine('sqlite:///test_sqlalchemy.sqlite3')
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from graphene import Schema, ObjectType, String
|
||||
from graphene import ObjectType, Schema, String
|
||||
|
||||
from ..utils import get_session
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import inspect
|
||||
|
||||
import six
|
||||
|
||||
from sqlalchemy.inspection import inspect as sqlalchemyinspect
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
|
@ -10,7 +9,7 @@ from ...relay.types import Connection, Node, NodeMeta
|
|||
from .converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from .options import SQLAlchemyOptions
|
||||
from .utils import is_mapped, get_query
|
||||
from .utils import get_query, is_mapped
|
||||
|
||||
|
||||
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
|
||||
|
|
41
graphene/core/types/custom_scalars.py
Normal file
41
graphene/core/types/custom_scalars.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from graphql.core.language import ast
|
||||
|
||||
from ...core.classtypes.scalar import Scalar
|
||||
|
||||
|
||||
class JSONString(Scalar):
|
||||
'''JSON String'''
|
||||
|
||||
@staticmethod
|
||||
def serialize(dt):
|
||||
return json.dumps(dt)
|
||||
|
||||
@staticmethod
|
||||
def parse_literal(node):
|
||||
if isinstance(node, ast.StringValue):
|
||||
return json.dumps(node.value)
|
||||
|
||||
@staticmethod
|
||||
def parse_value(value):
|
||||
return json.dumps(value)
|
||||
|
||||
|
||||
class DateTime(Scalar):
|
||||
'''DateTime in ISO 8601 format'''
|
||||
|
||||
@staticmethod
|
||||
def serialize(dt):
|
||||
return dt.isoformat()
|
||||
|
||||
@staticmethod
|
||||
def parse_literal(node):
|
||||
if isinstance(node, ast.StringValue):
|
||||
return datetime.datetime.strptime(
|
||||
node.value, "%Y-%m-%dT%H:%M:%S.%f")
|
||||
|
||||
@staticmethod
|
||||
def parse_value(value):
|
||||
return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f")
|
|
@ -49,6 +49,7 @@ except ImportError:
|
|||
class's __getattr__ method; this is done by raising AttributeError.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, fget=None):
|
||||
self.fget = fget
|
||||
|
||||
|
@ -63,7 +64,6 @@ except ImportError:
|
|||
def __delete__(self, instance):
|
||||
raise AttributeError("can't delete attribute")
|
||||
|
||||
|
||||
def _is_descriptor(obj):
|
||||
"""Returns True if obj is a descriptor, False otherwise."""
|
||||
return (
|
||||
|
@ -71,7 +71,6 @@ except ImportError:
|
|||
hasattr(obj, '__set__') or
|
||||
hasattr(obj, '__delete__'))
|
||||
|
||||
|
||||
def _is_dunder(name):
|
||||
"""Returns True if a __dunder__ name, False otherwise."""
|
||||
return (name[:2] == name[-2:] == '__' and
|
||||
|
@ -79,7 +78,6 @@ except ImportError:
|
|||
name[-3:-2] != '_' and
|
||||
len(name) > 4)
|
||||
|
||||
|
||||
def _is_sunder(name):
|
||||
"""Returns True if a _sunder_ name, False otherwise."""
|
||||
return (name[0] == name[-1] == '_' and
|
||||
|
@ -87,15 +85,14 @@ except ImportError:
|
|||
name[-2:-1] != '_' and
|
||||
len(name) > 2)
|
||||
|
||||
|
||||
def _make_class_unpicklable(cls):
|
||||
"""Make the given class un-picklable."""
|
||||
|
||||
def _break_on_call_reduce(self, protocol=None):
|
||||
raise TypeError('%r cannot be pickled' % self)
|
||||
cls.__reduce_ex__ = _break_on_call_reduce
|
||||
cls.__module__ = '<unknown>'
|
||||
|
||||
|
||||
class _EnumDict(dict):
|
||||
"""Track enum member order and ensure member names are not reused.
|
||||
|
||||
|
@ -103,6 +100,7 @@ except ImportError:
|
|||
enumeration member names.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(_EnumDict, self).__init__()
|
||||
self._member_names = []
|
||||
|
@ -139,13 +137,11 @@ except ImportError:
|
|||
self._member_names.append(key)
|
||||
super(_EnumDict, self).__setitem__(key, value)
|
||||
|
||||
|
||||
# Dummy value for Enum as EnumMeta explicity checks for it, but of course until
|
||||
# EnumMeta finishes running the first time the Enum class doesn't exist. This
|
||||
# is also why there are checks in EnumMeta like `if Enum is not None`
|
||||
Enum = None
|
||||
|
||||
|
||||
class EnumMeta(type):
|
||||
"""Metaclass for Enum"""
|
||||
@classmethod
|
||||
|
@ -157,7 +153,7 @@ except ImportError:
|
|||
# cannot be mixed with other types (int, float, etc.) if it has an
|
||||
# inherited __new__ unless a new __new__ is defined (or the resulting
|
||||
# class will fail).
|
||||
if type(classdict) is dict:
|
||||
if isinstance(classdict, dict):
|
||||
original_dict = classdict
|
||||
classdict = _EnumDict()
|
||||
for k, v in original_dict.items():
|
||||
|
@ -259,7 +255,6 @@ except ImportError:
|
|||
except TypeError:
|
||||
pass
|
||||
|
||||
|
||||
# If a custom type is mixed into the Enum, and it does not know how
|
||||
# to pickle itself, pickle.dumps will succeed but pickle.loads will
|
||||
# fail. Rather than have the error show up later and possibly far
|
||||
|
@ -279,12 +274,11 @@ except ImportError:
|
|||
_make_class_unpicklable(enum_class)
|
||||
unpicklable = True
|
||||
|
||||
|
||||
# double check that repr and friends are not the mixin's or various
|
||||
# things break (such as pickle)
|
||||
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
|
||||
class_method = getattr(enum_class, name)
|
||||
obj_method = getattr(member_type, name, None)
|
||||
getattr(member_type, name, None)
|
||||
enum_method = getattr(first_enum, name, None)
|
||||
if name not in classdict and class_method is not enum_method:
|
||||
if name == '__reduce_ex__' and unpicklable:
|
||||
|
@ -444,7 +438,7 @@ except ImportError:
|
|||
if isinstance(names, basestring):
|
||||
names = names.replace(',', ' ').split()
|
||||
if isinstance(names, (tuple, list)) and isinstance(names[0], basestring):
|
||||
names = [(e, i+start) for (i, e) in enumerate(names)]
|
||||
names = [(e, i + start) for (i, e) in enumerate(names)]
|
||||
|
||||
# Here, names is either an iterable of (name, value) or a mapping.
|
||||
item = None # in case names is empty
|
||||
|
@ -485,7 +479,6 @@ except ImportError:
|
|||
if not bases or Enum is None:
|
||||
return object, Enum
|
||||
|
||||
|
||||
# double check that we are not subclassing a class with existing
|
||||
# enumeration members; while we're at it, see if any other data
|
||||
# type has been mixed in so we can use the correct __new__
|
||||
|
@ -625,7 +618,6 @@ except ImportError:
|
|||
|
||||
return __new__, save_new, use_args
|
||||
|
||||
|
||||
########################################################
|
||||
# In order to support Python 2 and 3 with a single
|
||||
# codebase we have to create the Enum methods separately
|
||||
|
@ -639,10 +631,10 @@ except ImportError:
|
|||
# all enum instances are actually created during class construction
|
||||
# without calling this method; this method is called by the metaclass'
|
||||
# __call__ (i.e. Color(3) ), and by pickle
|
||||
if type(value) is cls:
|
||||
if isinstance(value, cls):
|
||||
# For lookups like Color(Color.red)
|
||||
value = value.value
|
||||
#return value
|
||||
# return value
|
||||
# by-value search for a matching enum member
|
||||
# see if it's in the reverse mapping (for hashable values)
|
||||
try:
|
||||
|
@ -697,14 +689,13 @@ except ImportError:
|
|||
temp_enum_dict['__format__'] = __format__
|
||||
del __format__
|
||||
|
||||
|
||||
####################################
|
||||
# Python's less than 2.6 use __cmp__
|
||||
|
||||
if pyver < 2.6:
|
||||
|
||||
def __cmp__(self, other):
|
||||
if type(other) is self.__class__:
|
||||
if isinstance(other, self.__class__):
|
||||
if self is other:
|
||||
return 0
|
||||
return -1
|
||||
|
@ -735,16 +726,15 @@ except ImportError:
|
|||
temp_enum_dict['__gt__'] = __gt__
|
||||
del __gt__
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(other) is self.__class__:
|
||||
if isinstance(other, self.__class__):
|
||||
return self is other
|
||||
return NotImplemented
|
||||
temp_enum_dict['__eq__'] = __eq__
|
||||
del __eq__
|
||||
|
||||
def __ne__(self, other):
|
||||
if type(other) is self.__class__:
|
||||
if isinstance(other, self.__class__):
|
||||
return self is not other
|
||||
return NotImplemented
|
||||
temp_enum_dict['__ne__'] = __ne__
|
||||
|
|
Loading…
Reference in New Issue
Block a user