Merge pull request #140 from graphql-python/features/django-fields

Improved support for Django fields
This commit is contained in:
Syrus Akbary 2016-04-02 21:05:03 -07:00
commit 332d7b0227
22 changed files with 269 additions and 85 deletions

View File

@ -24,6 +24,7 @@ install:
- | - |
if [ "$TEST_TYPE" = build ]; then if [ "$TEST_TYPE" = build ]; then
pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter sqlalchemy_utils pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter sqlalchemy_utils
pip install --download-cache $HOME/.cache/pip psycopg2 > /dev/null 2>&1
pip install --download-cache $HOME/.cache/pip/ -e .[django] pip install --download-cache $HOME/.cache/pip/ -e .[django]
pip install --download-cache $HOME/.cache/pip/ -e .[sqlalchemy] pip install --download-cache $HOME/.cache/pip/ -e .[sqlalchemy]
pip install django==$DJANGO_VERSION pip install django==$DJANGO_VERSION

View File

@ -16,6 +16,10 @@ Also the following Types are available:
- `graphene.List` - `graphene.List`
- `graphene.NonNull` - `graphene.NonNull`
Graphene also provides custom scalars for Dates and JSON:
- `graphene.core.types.custom_scalars.DateTime`
- `graphene.core.types.custom_scalars.JSONString`
## Shortcuts ## Shortcuts
There are some shortcuts for building schemas more easily. There are some shortcuts for building schemas more easily.

View File

@ -1,8 +1,8 @@
from flask import Flask from flask import Flask
from database import db_session, init_db
from schema import schema from database import db_session, init_db
from flask_graphql import GraphQL from flask_graphql import GraphQL
from schema import schema
app = Flask(__name__) app = Flask(__name__)
app.debug = True app.debug = True

View File

@ -1,6 +1,6 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True)
db_session = scoped_session(sessionmaker(autocommit=False, db_session = scoped_session(sessionmaker(autocommit=False,

View File

@ -1,6 +1,7 @@
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import backref, relationship
from database import Base from database import Base
from sqlalchemy import Column, DateTime, String, Integer, ForeignKey, func
from sqlalchemy.orm import relationship, backref
class Department(Base): class Department(Base):

View File

@ -1,19 +1,23 @@
import graphene import graphene
from graphene import relay from graphene import relay
from graphene.contrib.sqlalchemy import SQLAlchemyNode, SQLAlchemyConnectionField from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField,
from models import Department as DepartmentModel, Employee as EmployeeModel SQLAlchemyNode)
from models import Department as DepartmentModel
from models import Employee as EmployeeModel
schema = graphene.Schema() schema = graphene.Schema()
@schema.register @schema.register
class Department(SQLAlchemyNode): class Department(SQLAlchemyNode):
class Meta: class Meta:
model = DepartmentModel model = DepartmentModel
@schema.register @schema.register
class Employee(SQLAlchemyNode): class Employee(SQLAlchemyNode):
class Meta: class Meta:
model = EmployeeModel model = EmployeeModel

View File

@ -1,15 +1,24 @@
from django.db import models from django.db import models
class MissingType(object):
pass
try: try:
UUIDField = models.UUIDField UUIDField = models.UUIDField
except AttributeError: except AttributeError:
# Improved compatibility for Django 1.6 # Improved compatibility for Django 1.6
class UUIDField(object): UUIDField = MissingType
pass
try: try:
from django.db.models.related import RelatedObject from django.db.models.related import RelatedObject
except: except:
# Improved compatibility for Django 1.6 # Improved compatibility for Django 1.6
class RelatedObject(object): RelatedObject = MissingType
pass
try:
# Postgres fields are only available in Django 1.8+
from django.contrib.postgres.fields import ArrayField, HStoreField, JSONField, RangeField
except ImportError:
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4

View File

@ -1,9 +1,12 @@
from django.db import models from django.db import models
from ...core.types.scalars import ID, Boolean, Float, Int, String
from ...core.classtypes.enum import Enum from ...core.classtypes.enum import Enum
from ...core.types.custom_scalars import DateTime, JSONString
from ...core.types.definitions import List
from ...core.types.scalars import ID, Boolean, Float, Int, String
from ...utils import to_const from ...utils import to_const
from .compat import RelatedObject, UUIDField from .compat import (ArrayField, HStoreField, JSONField, RangeField,
RelatedObject, UUIDField)
from .utils import get_related_model, import_single_dispatch from .utils import get_related_model, import_single_dispatch
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
@ -30,13 +33,13 @@ def convert_django_field(field):
(field, field.__class__)) (field, field.__class__))
@convert_django_field.register(models.DateField)
@convert_django_field.register(models.CharField) @convert_django_field.register(models.CharField)
@convert_django_field.register(models.TextField) @convert_django_field.register(models.TextField)
@convert_django_field.register(models.EmailField) @convert_django_field.register(models.EmailField)
@convert_django_field.register(models.SlugField) @convert_django_field.register(models.SlugField)
@convert_django_field.register(models.URLField) @convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField) @convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField)
@convert_django_field.register(UUIDField) @convert_django_field.register(UUIDField)
def convert_field_to_string(field): def convert_field_to_string(field):
return String(description=field.help_text) return String(description=field.help_text)
@ -72,6 +75,11 @@ def convert_field_to_float(field):
return Float(description=field.help_text) return Float(description=field.help_text)
@convert_django_field.register(models.DateField)
def convert_date_to_string(field):
return DateTime(description=field.help_text)
@convert_django_field.register(models.ManyToManyField) @convert_django_field.register(models.ManyToManyField)
@convert_django_field.register(models.ManyToOneRel) @convert_django_field.register(models.ManyToOneRel)
@convert_django_field.register(models.ManyToManyRel) @convert_django_field.register(models.ManyToManyRel)
@ -94,3 +102,21 @@ def convert_relatedfield_to_djangomodel(field):
def convert_field_to_djangomodel(field): def convert_field_to_djangomodel(field):
from .fields import DjangoModelField from .fields import DjangoModelField
return DjangoModelField(get_related_model(field), description=field.help_text) return DjangoModelField(get_related_model(field), description=field.help_text)
@convert_django_field.register(ArrayField)
def convert_postgres_array_to_list(field):
base_type = convert_django_field(field.base_field)
return List(base_type, description=field.help_text)
@convert_django_field.register(HStoreField)
@convert_django_field.register(JSONField)
def convert_posgres_field_to_string(field):
return JSONString(description=field.help_text)
@convert_django_field.register(RangeField)
def convert_posgres_range_to_string(field):
inner_type = convert_django_field(field.base_field)
return List(inner_type, description=field.help_text)

View File

@ -1,7 +1,7 @@
import pytest import pytest
import graphene import graphene
from graphene.contrib.django import DjangoNode, DjangoConnectionField from graphene.contrib.django import DjangoConnectionField, DjangoNode
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
from ...tests.models import Reporter from ...tests.models import Reporter
@ -19,6 +19,7 @@ def test_should_query_field():
r2.save() r2.save()
class ReporterType(DjangoNode): class ReporterType(DjangoNode):
class Meta: class Meta:
model = Reporter model = Reporter

View File

@ -2,10 +2,10 @@ import six
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
from django.utils.text import capfirst from django.utils.text import capfirst
from graphql_relay.node.node import from_global_id
from django_filters import Filter, MultipleChoiceFilter from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import FilterSet, FilterSetMetaclass from django_filters.filterset import FilterSet, FilterSetMetaclass
from graphql_relay.node.node import from_global_id
from graphene.contrib.django.forms import (GlobalIDFormField, from graphene.contrib.django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField) GlobalIDMultipleChoiceField)

View File

@ -1,4 +1,5 @@
import django_filters import django_filters
from graphene.contrib.django.tests.models import Article, Pet, Reporter from graphene.contrib.django.tests.models import Article, Pet, Reporter

View File

@ -1,12 +1,14 @@
import pytest
from django.db import models from django.db import models
from py.test import raises from py.test import raises
import graphene import graphene
from graphene.contrib.django.converter import ( from graphene.core.types.custom_scalars import DateTime, JSONString
convert_django_field, convert_django_field_with_choices)
from graphene.contrib.django.fields import (ConnectionOrListField,
DjangoModelField)
from ..compat import (ArrayField, HStoreField, JSONField, MissingType,
RangeField)
from ..converter import convert_django_field, convert_django_field_with_choices
from ..fields import ConnectionOrListField, DjangoModelField
from .models import Article, Reporter from .models import Article, Reporter
@ -26,7 +28,7 @@ def test_should_unknown_django_field_raise_exception():
def test_should_date_convert_string(): def test_should_date_convert_string():
assert_conversion(models.DateField, graphene.String) assert_conversion(models.DateField, DateTime)
def test_should_char_convert_string(): def test_should_char_convert_string():
@ -53,6 +55,14 @@ def test_should_ipaddress_convert_string():
assert_conversion(models.GenericIPAddressField, graphene.String) assert_conversion(models.GenericIPAddressField, graphene.String)
def test_should_file_convert_string():
assert_conversion(models.FileField, graphene.String)
def test_should_image_convert_string():
assert_conversion(models.ImageField, graphene.String)
def test_should_auto_convert_id(): def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID, primary_key=True) assert_conversion(models.AutoField, graphene.ID, primary_key=True)
@ -136,3 +146,40 @@ def test_should_onetoone_convert_model():
def test_should_foreignkey_convert_model(): def test_should_foreignkey_convert_model():
field = assert_conversion(models.ForeignKey, DjangoModelField, Article) field = assert_conversion(models.ForeignKey, DjangoModelField, Article)
assert field.type.model == Article assert field.type.model == Article
@pytest.mark.skipif(ArrayField is MissingType,
reason="ArrayField should exist")
def test_should_postgres_array_convert_list():
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100))
assert isinstance(field.type, graphene.List)
assert isinstance(field.type.of_type, graphene.String)
@pytest.mark.skipif(ArrayField is MissingType,
reason="ArrayField should exist")
def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)))
assert isinstance(field.type, graphene.List)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.String)
@pytest.mark.skipif(HStoreField is MissingType,
reason="HStoreField should exist")
def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(JSONField is MissingType,
reason="JSONField should exist")
def test_should_postgres_json_convert_string():
assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType,
reason="RangeField should exist")
def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type.of_type, graphene.Int)

View File

@ -1,10 +1,14 @@
import datetime
import pytest import pytest
from django.db import models
from py.test import raises from py.test import raises
import graphene import graphene
from graphene import relay from graphene import relay
from graphene.contrib.django import DjangoNode, DjangoObjectType
from ..compat import MissingType, RangeField
from ..types import DjangoNode, DjangoObjectType
from .models import Article, Reporter from .models import Article, Reporter
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@ -62,6 +66,57 @@ def test_should_query_well():
assert result.data == expected assert result.data == expected
@pytest.mark.skipif(RangeField is MissingType,
reason="RangeField should exist")
def test_should_query_postgres_fields():
from django.contrib.postgres.fields import IntegerRangeField, ArrayField, JSONField, HStoreField
class Event(models.Model):
ages = IntegerRangeField(help_text='The age ranges')
data = JSONField(help_text='Data')
store = HStoreField()
tags = ArrayField(models.CharField(max_length=50))
class EventType(DjangoObjectType):
class Meta:
model = Event
class Query(graphene.ObjectType):
event = graphene.Field(EventType)
def resolve_event(self, *args, **kwargs):
return Event(
ages=(0, 10),
data={'angry_babies': True},
store={'h': 'store'},
tags=['child', 'angry', 'babies']
)
schema = graphene.Schema(query=Query)
query = '''
query myQuery {
event {
ages
tags
data
store
}
}
'''
expected = {
'event': {
'ages': [0, 10],
'tags': ['child', 'angry', 'babies'],
'data': '{"angry_babies": true}',
'store': '{"h": "store"}',
},
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_node(): def test_should_node():
class ReporterNode(DjangoNode): class ReporterNode(DjangoNode):
@ -82,7 +137,7 @@ def test_should_node():
@classmethod @classmethod
def get_node(cls, id, info): def get_node(cls, id, info):
return ArticleNode(Article(id=1, headline='Article node')) return ArticleNode(Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)))
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = relay.NodeField() node = relay.NodeField()
@ -115,6 +170,7 @@ def test_should_node():
} }
... on ArticleNode { ... on ArticleNode {
headline headline
pubDate
} }
} }
} }
@ -135,7 +191,8 @@ def test_should_node():
}, },
'myArticle': { 'myArticle': {
'id': 'QXJ0aWNsZU5vZGU6MQ==', 'id': 'QXJ0aWNsZU5vZGU6MQ==',
'headline': 'Article node' 'headline': 'Article node',
'pubDate': '2002-03-11',
} }
} }
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)

View File

@ -1,17 +1,17 @@
from singledispatch import singledispatch from singledispatch import singledispatch
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy.orm import interfaces from sqlalchemy.orm import interfaces
from ...core.classtypes.enum import Enum
from ...core.types.scalars import ID, Boolean, Float, Int, String
from .fields import ConnectionOrListField, SQLAlchemyModelField
try: try:
from sqlalchemy_utils.types.choice import ChoiceType from sqlalchemy_utils.types.choice import ChoiceType
except ImportError: except ImportError:
class ChoiceType(object): class ChoiceType(object):
pass pass
from ...core.classtypes.enum import Enum
from ...core.types.scalars import ID, Boolean, Float, Int, String
from .fields import ConnectionOrListField, SQLAlchemyModelField
def convert_sqlalchemy_relationship(relationship): def convert_sqlalchemy_relationship(relationship):
direction = relationship.direction direction = relationship.direction

View File

@ -4,7 +4,7 @@ from ...core.types.base import FieldType
from ...core.types.definitions import List from ...core.types.definitions import List
from ...relay import ConnectionField from ...relay import ConnectionField
from ...relay.utils import is_node from ...relay.utils import is_node
from .utils import get_type_for_model, maybe_query, get_query from .utils import get_query, get_type_for_model, maybe_query
class DefaultQuery(object): class DefaultQuery(object):

View File

@ -1,13 +1,13 @@
from py.test import raises from py.test import raises
from sqlalchemy import Column, Table, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils.types.choice import ChoiceType
import graphene import graphene
from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column, from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField, from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField,
SQLAlchemyModelField) SQLAlchemyModelField)
from sqlalchemy import Table, Column, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils.types.choice import ChoiceType
from .models import Article, Pet, Reporter from .models import Article, Pet, Reporter

View File

@ -1,12 +1,13 @@
import pytest import pytest
import graphene
from graphene import relay
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode, SQLAlchemyConnectionField
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from .models import Base, Reporter, Article import graphene
from graphene import relay
from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField,
SQLAlchemyNode, SQLAlchemyObjectType)
from .models import Article, Base, Reporter
db = create_engine('sqlite:///test_sqlalchemy.sqlite3') db = create_engine('sqlite:///test_sqlalchemy.sqlite3')

View File

@ -1,4 +1,4 @@
from graphene import Schema, ObjectType, String from graphene import ObjectType, Schema, String
from ..utils import get_session from ..utils import get_session

View File

@ -1,7 +1,6 @@
import inspect import inspect
import six import six
from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.inspection import inspect as sqlalchemyinspect
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
@ -10,7 +9,7 @@ from ...relay.types import Connection, Node, NodeMeta
from .converter import (convert_sqlalchemy_column, from .converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from .options import SQLAlchemyOptions from .options import SQLAlchemyOptions
from .utils import is_mapped, get_query from .utils import get_query, is_mapped
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):

View File

@ -0,0 +1,41 @@
import datetime
import json
from graphql.core.language import ast
from ...core.classtypes.scalar import Scalar
class JSONString(Scalar):
'''JSON String'''
@staticmethod
def serialize(dt):
return json.dumps(dt)
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
return json.dumps(node.value)
@staticmethod
def parse_value(value):
return json.dumps(value)
class DateTime(Scalar):
'''DateTime in ISO 8601 format'''
@staticmethod
def serialize(dt):
return dt.isoformat()
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
return datetime.datetime.strptime(
node.value, "%Y-%m-%dT%H:%M:%S.%f")
@staticmethod
def parse_value(value):
return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f")

View File

@ -49,6 +49,7 @@ except ImportError:
class's __getattr__ method; this is done by raising AttributeError. class's __getattr__ method; this is done by raising AttributeError.
""" """
def __init__(self, fget=None): def __init__(self, fget=None):
self.fget = fget self.fget = fget
@ -63,14 +64,12 @@ except ImportError:
def __delete__(self, instance): def __delete__(self, instance):
raise AttributeError("can't delete attribute") raise AttributeError("can't delete attribute")
def _is_descriptor(obj): def _is_descriptor(obj):
"""Returns True if obj is a descriptor, False otherwise.""" """Returns True if obj is a descriptor, False otherwise."""
return ( return (
hasattr(obj, '__get__') or hasattr(obj, '__get__') or
hasattr(obj, '__set__') or hasattr(obj, '__set__') or
hasattr(obj, '__delete__')) hasattr(obj, '__delete__'))
def _is_dunder(name): def _is_dunder(name):
"""Returns True if a __dunder__ name, False otherwise.""" """Returns True if a __dunder__ name, False otherwise."""
@ -79,7 +78,6 @@ except ImportError:
name[-3:-2] != '_' and name[-3:-2] != '_' and
len(name) > 4) len(name) > 4)
def _is_sunder(name): def _is_sunder(name):
"""Returns True if a _sunder_ name, False otherwise.""" """Returns True if a _sunder_ name, False otherwise."""
return (name[0] == name[-1] == '_' and return (name[0] == name[-1] == '_' and
@ -87,15 +85,14 @@ except ImportError:
name[-2:-1] != '_' and name[-2:-1] != '_' and
len(name) > 2) len(name) > 2)
def _make_class_unpicklable(cls): def _make_class_unpicklable(cls):
"""Make the given class un-picklable.""" """Make the given class un-picklable."""
def _break_on_call_reduce(self, protocol=None): def _break_on_call_reduce(self, protocol=None):
raise TypeError('%r cannot be pickled' % self) raise TypeError('%r cannot be pickled' % self)
cls.__reduce_ex__ = _break_on_call_reduce cls.__reduce_ex__ = _break_on_call_reduce
cls.__module__ = '<unknown>' cls.__module__ = '<unknown>'
class _EnumDict(dict): class _EnumDict(dict):
"""Track enum member order and ensure member names are not reused. """Track enum member order and ensure member names are not reused.
@ -103,6 +100,7 @@ except ImportError:
enumeration member names. enumeration member names.
""" """
def __init__(self): def __init__(self):
super(_EnumDict, self).__init__() super(_EnumDict, self).__init__()
self._member_names = [] self._member_names = []
@ -124,7 +122,7 @@ except ImportError:
""" """
if pyver >= 3.0 and key == '__order__': if pyver >= 3.0 and key == '__order__':
return return
if _is_sunder(key): if _is_sunder(key):
raise ValueError('_names_ are reserved for future Enum use') raise ValueError('_names_ are reserved for future Enum use')
elif _is_dunder(key): elif _is_dunder(key):
@ -139,13 +137,11 @@ except ImportError:
self._member_names.append(key) self._member_names.append(key)
super(_EnumDict, self).__setitem__(key, value) super(_EnumDict, self).__setitem__(key, value)
# Dummy value for Enum as EnumMeta explicity checks for it, but of course until # Dummy value for Enum as EnumMeta explicity checks for it, but of course until
# EnumMeta finishes running the first time the Enum class doesn't exist. This # EnumMeta finishes running the first time the Enum class doesn't exist. This
# is also why there are checks in EnumMeta like `if Enum is not None` # is also why there are checks in EnumMeta like `if Enum is not None`
Enum = None Enum = None
class EnumMeta(type): class EnumMeta(type):
"""Metaclass for Enum""" """Metaclass for Enum"""
@classmethod @classmethod
@ -157,7 +153,7 @@ except ImportError:
# cannot be mixed with other types (int, float, etc.) if it has an # cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting # inherited __new__ unless a new __new__ is defined (or the resulting
# class will fail). # class will fail).
if type(classdict) is dict: if isinstance(classdict, dict):
original_dict = classdict original_dict = classdict
classdict = _EnumDict() classdict = _EnumDict()
for k, v in original_dict.items(): for k, v in original_dict.items():
@ -165,7 +161,7 @@ except ImportError:
member_type, first_enum = metacls._get_mixins_(bases) member_type, first_enum = metacls._get_mixins_(bases)
__new__, save_new, use_args = metacls._find_new_(classdict, member_type, __new__, save_new, use_args = metacls._find_new_(classdict, member_type,
first_enum) first_enum)
# save enum items into separate mapping so they don't get baked into # save enum items into separate mapping so they don't get baked into
# the new class # the new class
members = dict((k, classdict[k]) for k in classdict._member_names) members = dict((k, classdict[k]) for k in classdict._member_names)
@ -259,7 +255,6 @@ except ImportError:
except TypeError: except TypeError:
pass pass
# If a custom type is mixed into the Enum, and it does not know how # If a custom type is mixed into the Enum, and it does not know how
# to pickle itself, pickle.dumps will succeed but pickle.loads will # to pickle itself, pickle.dumps will succeed but pickle.loads will
# fail. Rather than have the error show up later and possibly far # fail. Rather than have the error show up later and possibly far
@ -274,17 +269,16 @@ except ImportError:
if '__reduce_ex__' not in classdict: if '__reduce_ex__' not in classdict:
if member_type is not object: if member_type is not object:
methods = ('__getnewargs_ex__', '__getnewargs__', methods = ('__getnewargs_ex__', '__getnewargs__',
'__reduce_ex__', '__reduce__') '__reduce_ex__', '__reduce__')
if not any(m in member_type.__dict__ for m in methods): if not any(m in member_type.__dict__ for m in methods):
_make_class_unpicklable(enum_class) _make_class_unpicklable(enum_class)
unpicklable = True unpicklable = True
# double check that repr and friends are not the mixin's or various # double check that repr and friends are not the mixin's or various
# things break (such as pickle) # things break (such as pickle)
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
class_method = getattr(enum_class, name) class_method = getattr(enum_class, name)
obj_method = getattr(member_type, name, None) getattr(member_type, name, None)
enum_method = getattr(first_enum, name, None) enum_method = getattr(first_enum, name, None)
if name not in classdict and class_method is not enum_method: if name not in classdict and class_method is not enum_method:
if name == '__reduce_ex__' and unpicklable: if name == '__reduce_ex__' and unpicklable:
@ -310,7 +304,7 @@ except ImportError:
'__eq__', '__eq__',
'__ne__', '__ne__',
'__hash__', '__hash__',
): ):
setattr(enum_class, method, getattr(int, method)) setattr(enum_class, method, getattr(int, method))
# replace any other __new__ with our own (as long as Enum is not None, # replace any other __new__ with our own (as long as Enum is not None,
@ -352,7 +346,7 @@ except ImportError:
# (see issue19025). # (see issue19025).
if attr in cls._member_map_: if attr in cls._member_map_:
raise AttributeError( raise AttributeError(
"%s: cannot delete Enum member." % cls.__name__) "%s: cannot delete Enum member." % cls.__name__)
super(EnumMeta, cls).__delattr__(attr) super(EnumMeta, cls).__delattr__(attr)
def __dir__(self): def __dir__(self):
@ -444,7 +438,7 @@ except ImportError:
if isinstance(names, basestring): if isinstance(names, basestring):
names = names.replace(',', ' ').split() names = names.replace(',', ' ').split()
if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): if isinstance(names, (tuple, list)) and isinstance(names[0], basestring):
names = [(e, i+start) for (i, e) in enumerate(names)] names = [(e, i + start) for (i, e) in enumerate(names)]
# Here, names is either an iterable of (name, value) or a mapping. # Here, names is either an iterable of (name, value) or a mapping.
item = None # in case names is empty item = None # in case names is empty
@ -485,20 +479,19 @@ except ImportError:
if not bases or Enum is None: if not bases or Enum is None:
return object, Enum return object, Enum
# double check that we are not subclassing a class with existing # double check that we are not subclassing a class with existing
# enumeration members; while we're at it, see if any other data # enumeration members; while we're at it, see if any other data
# type has been mixed in so we can use the correct __new__ # type has been mixed in so we can use the correct __new__
member_type = first_enum = None member_type = first_enum = None
for base in bases: for base in bases:
if (base is not Enum and if (base is not Enum and
issubclass(base, Enum) and issubclass(base, Enum) and
base._member_names_): base._member_names_):
raise TypeError("Cannot extend enumerations") raise TypeError("Cannot extend enumerations")
# base is now the last base in bases # base is now the last base in bases
if not issubclass(base, Enum): if not issubclass(base, Enum):
raise TypeError("new enumerations must be created as " raise TypeError("new enumerations must be created as "
"`ClassName([mixin_type,] enum_type)`") "`ClassName([mixin_type,] enum_type)`")
# get correct mix-in type (either mix-in type of Enum subclass, or # get correct mix-in type (either mix-in type of Enum subclass, or
# first base if last base is Enum) # first base if last base is Enum)
@ -556,7 +549,7 @@ except ImportError:
N__new__, N__new__,
O__new__, O__new__,
E__new__, E__new__,
]: ]:
if method == '__member_new__': if method == '__member_new__':
classdict['__new__'] = target classdict['__new__'] = target
return None, False, True return None, False, True
@ -607,7 +600,7 @@ except ImportError:
None.__new__, None.__new__,
object.__new__, object.__new__,
Enum.__new__, Enum.__new__,
): ):
__new__ = target __new__ = target
break break
if __new__ is not None: if __new__ is not None:
@ -625,7 +618,6 @@ except ImportError:
return __new__, save_new, use_args return __new__, save_new, use_args
######################################################## ########################################################
# In order to support Python 2 and 3 with a single # In order to support Python 2 and 3 with a single
# codebase we have to create the Enum methods separately # codebase we have to create the Enum methods separately
@ -639,10 +631,10 @@ except ImportError:
# all enum instances are actually created during class construction # all enum instances are actually created during class construction
# without calling this method; this method is called by the metaclass' # without calling this method; this method is called by the metaclass'
# __call__ (i.e. Color(3) ), and by pickle # __call__ (i.e. Color(3) ), and by pickle
if type(value) is cls: if isinstance(value, cls):
# For lookups like Color(Color.red) # For lookups like Color(Color.red)
value = value.value value = value.value
#return value # return value
# by-value search for a matching enum member # by-value search for a matching enum member
# see if it's in the reverse mapping (for hashable values) # see if it's in the reverse mapping (for hashable values)
try: try:
@ -659,7 +651,7 @@ except ImportError:
def __repr__(self): def __repr__(self):
return "<%s.%s: %r>" % ( return "<%s.%s: %r>" % (
self.__class__.__name__, self._name_, self._value_) self.__class__.__name__, self._name_, self._value_)
temp_enum_dict['__repr__'] = __repr__ temp_enum_dict['__repr__'] = __repr__
del __repr__ del __repr__
@ -671,11 +663,11 @@ except ImportError:
if pyver >= 3.0: if pyver >= 3.0:
def __dir__(self): def __dir__(self):
added_behavior = [ added_behavior = [
m m
for cls in self.__class__.mro() for cls in self.__class__.mro()
for m in cls.__dict__ for m in cls.__dict__
if m[0] != '_' and m not in self._member_map_ if m[0] != '_' and m not in self._member_map_
] ]
return (['__class__', '__doc__', '__module__', ] + added_behavior) return (['__class__', '__doc__', '__module__', ] + added_behavior)
temp_enum_dict['__dir__'] = __dir__ temp_enum_dict['__dir__'] = __dir__
del __dir__ del __dir__
@ -697,14 +689,13 @@ except ImportError:
temp_enum_dict['__format__'] = __format__ temp_enum_dict['__format__'] = __format__
del __format__ del __format__
#################################### ####################################
# Python's less than 2.6 use __cmp__ # Python's less than 2.6 use __cmp__
if pyver < 2.6: if pyver < 2.6:
def __cmp__(self, other): def __cmp__(self, other):
if type(other) is self.__class__: if isinstance(other, self.__class__):
if self is other: if self is other:
return 0 return 0
return -1 return -1
@ -735,16 +726,15 @@ except ImportError:
temp_enum_dict['__gt__'] = __gt__ temp_enum_dict['__gt__'] = __gt__
del __gt__ del __gt__
def __eq__(self, other): def __eq__(self, other):
if type(other) is self.__class__: if isinstance(other, self.__class__):
return self is other return self is other
return NotImplemented return NotImplemented
temp_enum_dict['__eq__'] = __eq__ temp_enum_dict['__eq__'] = __eq__
del __eq__ del __eq__
def __ne__(self, other): def __ne__(self, other):
if type(other) is self.__class__: if isinstance(other, self.__class__):
return self is not other return self is not other
return NotImplemented return NotImplemented
temp_enum_dict['__ne__'] = __ne__ temp_enum_dict['__ne__'] = __ne__
@ -832,9 +822,9 @@ except ImportError:
duplicates.append((name, member.name)) duplicates.append((name, member.name))
if duplicates: if duplicates:
duplicate_names = ', '.join( duplicate_names = ', '.join(
["%s -> %s" % (alias, name) for (alias, name) in duplicates] ["%s -> %s" % (alias, name) for (alias, name) in duplicates]
) )
raise ValueError('duplicate names found in %r: %s' % raise ValueError('duplicate names found in %r: %s' %
(enumeration, duplicate_names) (enumeration, duplicate_names)
) )
return enumeration return enumeration

View File

@ -65,6 +65,8 @@ setup(
'sqlalchemy', 'sqlalchemy',
'sqlalchemy_utils', 'sqlalchemy_utils',
'mock', 'mock',
# Required for Django postgres fields testing
'psycopg2',
], ],
extras_require={ extras_require={
'django': [ 'django': [