diff --git a/.travis.yml b/.travis.yml index 544966db..c935ec0d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,6 +24,7 @@ install: - | if [ "$TEST_TYPE" = build ]; then pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter sqlalchemy_utils + pip install --download-cache $HOME/.cache/pip psycopg2 > /dev/null 2>&1 pip install --download-cache $HOME/.cache/pip/ -e .[django] pip install --download-cache $HOME/.cache/pip/ -e .[sqlalchemy] pip install django==$DJANGO_VERSION diff --git a/docs/pages/docs/basic-types.md b/docs/pages/docs/basic-types.md index d63867e0..f8e251cd 100644 --- a/docs/pages/docs/basic-types.md +++ b/docs/pages/docs/basic-types.md @@ -16,6 +16,10 @@ Also the following Types are available: - `graphene.List` - `graphene.NonNull` +Graphene also provides custom scalars for Dates and JSON: +- `graphene.core.types.custom_scalars.DateTime` +- `graphene.core.types.custom_scalars.JSONString` + ## Shortcuts There are some shortcuts for building schemas more easily. diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index 0008ffa4..0bf2700f 100644 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -1,8 +1,8 @@ from flask import Flask -from database import db_session, init_db -from schema import schema +from database import db_session, init_db from flask_graphql import GraphQL +from schema import schema app = Flask(__name__) app.debug = True diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index db9a83b4..b2a51789 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -1,6 +1,6 @@ from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import scoped_session, sessionmaker engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) db_session = scoped_session(sessionmaker(autocommit=False, diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index 7561bb29..0fffb51d 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -1,6 +1,7 @@ +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func +from sqlalchemy.orm import backref, relationship + from database import Base -from sqlalchemy import Column, DateTime, String, Integer, ForeignKey, func -from sqlalchemy.orm import relationship, backref class Department(Base): diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index 4010f5ca..d0de90f6 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,19 +1,23 @@ import graphene from graphene import relay -from graphene.contrib.sqlalchemy import SQLAlchemyNode, SQLAlchemyConnectionField -from models import Department as DepartmentModel, Employee as EmployeeModel +from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, + SQLAlchemyNode) +from models import Department as DepartmentModel +from models import Employee as EmployeeModel schema = graphene.Schema() @schema.register class Department(SQLAlchemyNode): + class Meta: model = DepartmentModel @schema.register class Employee(SQLAlchemyNode): + class Meta: model = EmployeeModel diff --git a/graphene/contrib/django/compat.py b/graphene/contrib/django/compat.py index a5b444c7..4b1f55a6 100644 --- a/graphene/contrib/django/compat.py +++ b/graphene/contrib/django/compat.py @@ -1,15 +1,24 @@ from django.db import models + +class MissingType(object): + pass + try: UUIDField = models.UUIDField except AttributeError: # Improved compatibility for Django 1.6 - class UUIDField(object): - pass + UUIDField = MissingType try: from django.db.models.related import RelatedObject except: # Improved compatibility for Django 1.6 - class RelatedObject(object): - pass + RelatedObject = MissingType + + +try: + # Postgres fields are only available in Django 1.8+ + from django.contrib.postgres.fields import ArrayField, HStoreField, JSONField, RangeField +except ImportError: + ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4 diff --git a/graphene/contrib/django/converter.py b/graphene/contrib/django/converter.py index 754b38a6..589817af 100644 --- a/graphene/contrib/django/converter.py +++ b/graphene/contrib/django/converter.py @@ -1,9 +1,12 @@ from django.db import models -from ...core.types.scalars import ID, Boolean, Float, Int, String from ...core.classtypes.enum import Enum +from ...core.types.custom_scalars import DateTime, JSONString +from ...core.types.definitions import List +from ...core.types.scalars import ID, Boolean, Float, Int, String from ...utils import to_const -from .compat import RelatedObject, UUIDField +from .compat import (ArrayField, HStoreField, JSONField, RangeField, + RelatedObject, UUIDField) from .utils import get_related_model, import_single_dispatch singledispatch = import_single_dispatch() @@ -30,13 +33,13 @@ def convert_django_field(field): (field, field.__class__)) -@convert_django_field.register(models.DateField) @convert_django_field.register(models.CharField) @convert_django_field.register(models.TextField) @convert_django_field.register(models.EmailField) @convert_django_field.register(models.SlugField) @convert_django_field.register(models.URLField) @convert_django_field.register(models.GenericIPAddressField) +@convert_django_field.register(models.FileField) @convert_django_field.register(UUIDField) def convert_field_to_string(field): return String(description=field.help_text) @@ -72,6 +75,11 @@ def convert_field_to_float(field): return Float(description=field.help_text) +@convert_django_field.register(models.DateField) +def convert_date_to_string(field): + return DateTime(description=field.help_text) + + @convert_django_field.register(models.ManyToManyField) @convert_django_field.register(models.ManyToOneRel) @convert_django_field.register(models.ManyToManyRel) @@ -94,3 +102,21 @@ def convert_relatedfield_to_djangomodel(field): def convert_field_to_djangomodel(field): from .fields import DjangoModelField return DjangoModelField(get_related_model(field), description=field.help_text) + + +@convert_django_field.register(ArrayField) +def convert_postgres_array_to_list(field): + base_type = convert_django_field(field.base_field) + return List(base_type, description=field.help_text) + + +@convert_django_field.register(HStoreField) +@convert_django_field.register(JSONField) +def convert_posgres_field_to_string(field): + return JSONString(description=field.help_text) + + +@convert_django_field.register(RangeField) +def convert_posgres_range_to_string(field): + inner_type = convert_django_field(field.base_field) + return List(inner_type, description=field.help_text) diff --git a/graphene/contrib/django/debug/tests/test_query.py b/graphene/contrib/django/debug/tests/test_query.py index 328831b8..c512d6ad 100644 --- a/graphene/contrib/django/debug/tests/test_query.py +++ b/graphene/contrib/django/debug/tests/test_query.py @@ -1,7 +1,7 @@ import pytest import graphene -from graphene.contrib.django import DjangoNode, DjangoConnectionField +from graphene.contrib.django import DjangoConnectionField, DjangoNode from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from ...tests.models import Reporter @@ -19,6 +19,7 @@ def test_should_query_field(): r2.save() class ReporterType(DjangoNode): + class Meta: model = Reporter diff --git a/graphene/contrib/django/filter/filterset.py b/graphene/contrib/django/filter/filterset.py index b618893d..70f776be 100644 --- a/graphene/contrib/django/filter/filterset.py +++ b/graphene/contrib/django/filter/filterset.py @@ -2,10 +2,10 @@ import six from django.conf import settings from django.db import models from django.utils.text import capfirst -from graphql_relay.node.node import from_global_id - from django_filters import Filter, MultipleChoiceFilter from django_filters.filterset import FilterSet, FilterSetMetaclass +from graphql_relay.node.node import from_global_id + from graphene.contrib.django.forms import (GlobalIDFormField, GlobalIDMultipleChoiceField) diff --git a/graphene/contrib/django/filter/tests/filters.py b/graphene/contrib/django/filter/tests/filters.py index 94c0dffe..bccd72d5 100644 --- a/graphene/contrib/django/filter/tests/filters.py +++ b/graphene/contrib/django/filter/tests/filters.py @@ -1,4 +1,5 @@ import django_filters + from graphene.contrib.django.tests.models import Article, Pet, Reporter diff --git a/graphene/contrib/django/tests/test_converter.py b/graphene/contrib/django/tests/test_converter.py index ade56390..d3ac6baa 100644 --- a/graphene/contrib/django/tests/test_converter.py +++ b/graphene/contrib/django/tests/test_converter.py @@ -1,12 +1,14 @@ +import pytest from django.db import models from py.test import raises import graphene -from graphene.contrib.django.converter import ( - convert_django_field, convert_django_field_with_choices) -from graphene.contrib.django.fields import (ConnectionOrListField, - DjangoModelField) +from graphene.core.types.custom_scalars import DateTime, JSONString +from ..compat import (ArrayField, HStoreField, JSONField, MissingType, + RangeField) +from ..converter import convert_django_field, convert_django_field_with_choices +from ..fields import ConnectionOrListField, DjangoModelField from .models import Article, Reporter @@ -26,7 +28,7 @@ def test_should_unknown_django_field_raise_exception(): def test_should_date_convert_string(): - assert_conversion(models.DateField, graphene.String) + assert_conversion(models.DateField, DateTime) def test_should_char_convert_string(): @@ -53,6 +55,14 @@ def test_should_ipaddress_convert_string(): assert_conversion(models.GenericIPAddressField, graphene.String) +def test_should_file_convert_string(): + assert_conversion(models.FileField, graphene.String) + + +def test_should_image_convert_string(): + assert_conversion(models.ImageField, graphene.String) + + def test_should_auto_convert_id(): assert_conversion(models.AutoField, graphene.ID, primary_key=True) @@ -136,3 +146,40 @@ def test_should_onetoone_convert_model(): def test_should_foreignkey_convert_model(): field = assert_conversion(models.ForeignKey, DjangoModelField, Article) assert field.type.model == Article + + +@pytest.mark.skipif(ArrayField is MissingType, + reason="ArrayField should exist") +def test_should_postgres_array_convert_list(): + field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100)) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.String) + + +@pytest.mark.skipif(ArrayField is MissingType, + reason="ArrayField should exist") +def test_should_postgres_array_multiple_convert_list(): + field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type, graphene.String) + + +@pytest.mark.skipif(HStoreField is MissingType, + reason="HStoreField should exist") +def test_should_postgres_hstore_convert_string(): + assert_conversion(HStoreField, JSONString) + + +@pytest.mark.skipif(JSONField is MissingType, + reason="JSONField should exist") +def test_should_postgres_json_convert_string(): + assert_conversion(JSONField, JSONString) + + +@pytest.mark.skipif(RangeField is MissingType, + reason="RangeField should exist") +def test_should_postgres_range_convert_list(): + from django.contrib.postgres.fields import IntegerRangeField + field = assert_conversion(IntegerRangeField, graphene.List) + assert isinstance(field.type.of_type, graphene.Int) diff --git a/graphene/contrib/django/tests/test_query.py b/graphene/contrib/django/tests/test_query.py index 460c8e22..6d6b7540 100644 --- a/graphene/contrib/django/tests/test_query.py +++ b/graphene/contrib/django/tests/test_query.py @@ -1,10 +1,14 @@ +import datetime + import pytest +from django.db import models from py.test import raises import graphene from graphene import relay -from graphene.contrib.django import DjangoNode, DjangoObjectType +from ..compat import MissingType, RangeField +from ..types import DjangoNode, DjangoObjectType from .models import Article, Reporter pytestmark = pytest.mark.django_db @@ -62,6 +66,57 @@ def test_should_query_well(): assert result.data == expected +@pytest.mark.skipif(RangeField is MissingType, + reason="RangeField should exist") +def test_should_query_postgres_fields(): + from django.contrib.postgres.fields import IntegerRangeField, ArrayField, JSONField, HStoreField + + class Event(models.Model): + ages = IntegerRangeField(help_text='The age ranges') + data = JSONField(help_text='Data') + store = HStoreField() + tags = ArrayField(models.CharField(max_length=50)) + + class EventType(DjangoObjectType): + + class Meta: + model = Event + + class Query(graphene.ObjectType): + event = graphene.Field(EventType) + + def resolve_event(self, *args, **kwargs): + return Event( + ages=(0, 10), + data={'angry_babies': True}, + store={'h': 'store'}, + tags=['child', 'angry', 'babies'] + ) + + schema = graphene.Schema(query=Query) + query = ''' + query myQuery { + event { + ages + tags + data + store + } + } + ''' + expected = { + 'event': { + 'ages': [0, 10], + 'tags': ['child', 'angry', 'babies'], + 'data': '{"angry_babies": true}', + 'store': '{"h": "store"}', + }, + } + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + def test_should_node(): class ReporterNode(DjangoNode): @@ -82,7 +137,7 @@ def test_should_node(): @classmethod def get_node(cls, id, info): - return ArticleNode(Article(id=1, headline='Article node')) + return ArticleNode(Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11))) class Query(graphene.ObjectType): node = relay.NodeField() @@ -115,6 +170,7 @@ def test_should_node(): } ... on ArticleNode { headline + pubDate } } } @@ -135,7 +191,8 @@ def test_should_node(): }, 'myArticle': { 'id': 'QXJ0aWNsZU5vZGU6MQ==', - 'headline': 'Article node' + 'headline': 'Article node', + 'pubDate': '2002-03-11', } } schema = graphene.Schema(query=Query) diff --git a/graphene/contrib/sqlalchemy/converter.py b/graphene/contrib/sqlalchemy/converter.py index 8a2b0cbe..540fcdd0 100644 --- a/graphene/contrib/sqlalchemy/converter.py +++ b/graphene/contrib/sqlalchemy/converter.py @@ -1,17 +1,17 @@ from singledispatch import singledispatch - from sqlalchemy import types from sqlalchemy.orm import interfaces + +from ...core.classtypes.enum import Enum +from ...core.types.scalars import ID, Boolean, Float, Int, String +from .fields import ConnectionOrListField, SQLAlchemyModelField + try: from sqlalchemy_utils.types.choice import ChoiceType except ImportError: class ChoiceType(object): pass -from ...core.classtypes.enum import Enum -from ...core.types.scalars import ID, Boolean, Float, Int, String -from .fields import ConnectionOrListField, SQLAlchemyModelField - def convert_sqlalchemy_relationship(relationship): direction = relationship.direction diff --git a/graphene/contrib/sqlalchemy/fields.py b/graphene/contrib/sqlalchemy/fields.py index b5f4e974..dc3eb66b 100644 --- a/graphene/contrib/sqlalchemy/fields.py +++ b/graphene/contrib/sqlalchemy/fields.py @@ -4,7 +4,7 @@ from ...core.types.base import FieldType from ...core.types.definitions import List from ...relay import ConnectionField from ...relay.utils import is_node -from .utils import get_type_for_model, maybe_query, get_query +from .utils import get_query, get_type_for_model, maybe_query class DefaultQuery(object): diff --git a/graphene/contrib/sqlalchemy/tests/test_converter.py b/graphene/contrib/sqlalchemy/tests/test_converter.py index 53cbcab8..7658ed79 100644 --- a/graphene/contrib/sqlalchemy/tests/test_converter.py +++ b/graphene/contrib/sqlalchemy/tests/test_converter.py @@ -1,13 +1,13 @@ from py.test import raises +from sqlalchemy import Column, Table, types +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy_utils.types.choice import ChoiceType import graphene from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column, convert_sqlalchemy_relationship) from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField, SQLAlchemyModelField) -from sqlalchemy import Table, Column, types -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_utils.types.choice import ChoiceType from .models import Article, Pet, Reporter diff --git a/graphene/contrib/sqlalchemy/tests/test_query.py b/graphene/contrib/sqlalchemy/tests/test_query.py index 611ab2f6..5f970488 100644 --- a/graphene/contrib/sqlalchemy/tests/test_query.py +++ b/graphene/contrib/sqlalchemy/tests/test_query.py @@ -1,12 +1,13 @@ import pytest - -import graphene -from graphene import relay -from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode, SQLAlchemyConnectionField from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker -from .models import Base, Reporter, Article +import graphene +from graphene import relay +from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, + SQLAlchemyNode, SQLAlchemyObjectType) + +from .models import Article, Base, Reporter db = create_engine('sqlite:///test_sqlalchemy.sqlite3') diff --git a/graphene/contrib/sqlalchemy/tests/test_utils.py b/graphene/contrib/sqlalchemy/tests/test_utils.py index 2874ffaa..2925f016 100644 --- a/graphene/contrib/sqlalchemy/tests/test_utils.py +++ b/graphene/contrib/sqlalchemy/tests/test_utils.py @@ -1,4 +1,4 @@ -from graphene import Schema, ObjectType, String +from graphene import ObjectType, Schema, String from ..utils import get_session diff --git a/graphene/contrib/sqlalchemy/types.py b/graphene/contrib/sqlalchemy/types.py index 8f70d245..64b09afd 100644 --- a/graphene/contrib/sqlalchemy/types.py +++ b/graphene/contrib/sqlalchemy/types.py @@ -1,7 +1,6 @@ import inspect import six - from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.orm.exc import NoResultFound @@ -10,7 +9,7 @@ from ...relay.types import Connection, Node, NodeMeta from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_relationship) from .options import SQLAlchemyOptions -from .utils import is_mapped, get_query +from .utils import get_query, is_mapped class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): diff --git a/graphene/core/types/custom_scalars.py b/graphene/core/types/custom_scalars.py new file mode 100644 index 00000000..72f0f8b0 --- /dev/null +++ b/graphene/core/types/custom_scalars.py @@ -0,0 +1,41 @@ +import datetime +import json + +from graphql.core.language import ast + +from ...core.classtypes.scalar import Scalar + + +class JSONString(Scalar): + '''JSON String''' + + @staticmethod + def serialize(dt): + return json.dumps(dt) + + @staticmethod + def parse_literal(node): + if isinstance(node, ast.StringValue): + return json.dumps(node.value) + + @staticmethod + def parse_value(value): + return json.dumps(value) + + +class DateTime(Scalar): + '''DateTime in ISO 8601 format''' + + @staticmethod + def serialize(dt): + return dt.isoformat() + + @staticmethod + def parse_literal(node): + if isinstance(node, ast.StringValue): + return datetime.datetime.strptime( + node.value, "%Y-%m-%dT%H:%M:%S.%f") + + @staticmethod + def parse_value(value): + return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f") diff --git a/graphene/utils/enum.py b/graphene/utils/enum.py index 3a5a5ad9..27a3043c 100644 --- a/graphene/utils/enum.py +++ b/graphene/utils/enum.py @@ -49,6 +49,7 @@ except ImportError: class's __getattr__ method; this is done by raising AttributeError. """ + def __init__(self, fget=None): self.fget = fget @@ -63,14 +64,12 @@ except ImportError: def __delete__(self, instance): raise AttributeError("can't delete attribute") - def _is_descriptor(obj): """Returns True if obj is a descriptor, False otherwise.""" return ( - hasattr(obj, '__get__') or - hasattr(obj, '__set__') or - hasattr(obj, '__delete__')) - + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) def _is_dunder(name): """Returns True if a __dunder__ name, False otherwise.""" @@ -79,7 +78,6 @@ except ImportError: name[-3:-2] != '_' and len(name) > 4) - def _is_sunder(name): """Returns True if a _sunder_ name, False otherwise.""" return (name[0] == name[-1] == '_' and @@ -87,15 +85,14 @@ except ImportError: name[-2:-1] != '_' and len(name) > 2) - def _make_class_unpicklable(cls): """Make the given class un-picklable.""" + def _break_on_call_reduce(self, protocol=None): raise TypeError('%r cannot be pickled' % self) cls.__reduce_ex__ = _break_on_call_reduce cls.__module__ = '' - class _EnumDict(dict): """Track enum member order and ensure member names are not reused. @@ -103,6 +100,7 @@ except ImportError: enumeration member names. """ + def __init__(self): super(_EnumDict, self).__init__() self._member_names = [] @@ -124,7 +122,7 @@ except ImportError: """ if pyver >= 3.0 and key == '__order__': - return + return if _is_sunder(key): raise ValueError('_names_ are reserved for future Enum use') elif _is_dunder(key): @@ -139,13 +137,11 @@ except ImportError: self._member_names.append(key) super(_EnumDict, self).__setitem__(key, value) - # Dummy value for Enum as EnumMeta explicity checks for it, but of course until # EnumMeta finishes running the first time the Enum class doesn't exist. This # is also why there are checks in EnumMeta like `if Enum is not None` Enum = None - class EnumMeta(type): """Metaclass for Enum""" @classmethod @@ -157,7 +153,7 @@ except ImportError: # cannot be mixed with other types (int, float, etc.) if it has an # inherited __new__ unless a new __new__ is defined (or the resulting # class will fail). - if type(classdict) is dict: + if isinstance(classdict, dict): original_dict = classdict classdict = _EnumDict() for k, v in original_dict.items(): @@ -165,7 +161,7 @@ except ImportError: member_type, first_enum = metacls._get_mixins_(bases) __new__, save_new, use_args = metacls._find_new_(classdict, member_type, - first_enum) + first_enum) # save enum items into separate mapping so they don't get baked into # the new class members = dict((k, classdict[k]) for k in classdict._member_names) @@ -259,7 +255,6 @@ except ImportError: except TypeError: pass - # If a custom type is mixed into the Enum, and it does not know how # to pickle itself, pickle.dumps will succeed but pickle.loads will # fail. Rather than have the error show up later and possibly far @@ -274,17 +269,16 @@ except ImportError: if '__reduce_ex__' not in classdict: if member_type is not object: methods = ('__getnewargs_ex__', '__getnewargs__', - '__reduce_ex__', '__reduce__') + '__reduce_ex__', '__reduce__') if not any(m in member_type.__dict__ for m in methods): _make_class_unpicklable(enum_class) unpicklable = True - # double check that repr and friends are not the mixin's or various # things break (such as pickle) for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): class_method = getattr(enum_class, name) - obj_method = getattr(member_type, name, None) + getattr(member_type, name, None) enum_method = getattr(first_enum, name, None) if name not in classdict and class_method is not enum_method: if name == '__reduce_ex__' and unpicklable: @@ -310,7 +304,7 @@ except ImportError: '__eq__', '__ne__', '__hash__', - ): + ): setattr(enum_class, method, getattr(int, method)) # replace any other __new__ with our own (as long as Enum is not None, @@ -352,7 +346,7 @@ except ImportError: # (see issue19025). if attr in cls._member_map_: raise AttributeError( - "%s: cannot delete Enum member." % cls.__name__) + "%s: cannot delete Enum member." % cls.__name__) super(EnumMeta, cls).__delattr__(attr) def __dir__(self): @@ -444,7 +438,7 @@ except ImportError: if isinstance(names, basestring): names = names.replace(',', ' ').split() if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): - names = [(e, i+start) for (i, e) in enumerate(names)] + names = [(e, i + start) for (i, e) in enumerate(names)] # Here, names is either an iterable of (name, value) or a mapping. item = None # in case names is empty @@ -485,20 +479,19 @@ except ImportError: if not bases or Enum is None: return object, Enum - # double check that we are not subclassing a class with existing # enumeration members; while we're at it, see if any other data # type has been mixed in so we can use the correct __new__ member_type = first_enum = None for base in bases: - if (base is not Enum and + if (base is not Enum and issubclass(base, Enum) and base._member_names_): raise TypeError("Cannot extend enumerations") # base is now the last base in bases if not issubclass(base, Enum): raise TypeError("new enumerations must be created as " - "`ClassName([mixin_type,] enum_type)`") + "`ClassName([mixin_type,] enum_type)`") # get correct mix-in type (either mix-in type of Enum subclass, or # first base if last base is Enum) @@ -556,7 +549,7 @@ except ImportError: N__new__, O__new__, E__new__, - ]: + ]: if method == '__member_new__': classdict['__new__'] = target return None, False, True @@ -607,7 +600,7 @@ except ImportError: None.__new__, object.__new__, Enum.__new__, - ): + ): __new__ = target break if __new__ is not None: @@ -625,7 +618,6 @@ except ImportError: return __new__, save_new, use_args - ######################################################## # In order to support Python 2 and 3 with a single # codebase we have to create the Enum methods separately @@ -639,10 +631,10 @@ except ImportError: # all enum instances are actually created during class construction # without calling this method; this method is called by the metaclass' # __call__ (i.e. Color(3) ), and by pickle - if type(value) is cls: + if isinstance(value, cls): # For lookups like Color(Color.red) value = value.value - #return value + # return value # by-value search for a matching enum member # see if it's in the reverse mapping (for hashable values) try: @@ -659,7 +651,7 @@ except ImportError: def __repr__(self): return "<%s.%s: %r>" % ( - self.__class__.__name__, self._name_, self._value_) + self.__class__.__name__, self._name_, self._value_) temp_enum_dict['__repr__'] = __repr__ del __repr__ @@ -671,11 +663,11 @@ except ImportError: if pyver >= 3.0: def __dir__(self): added_behavior = [ - m - for cls in self.__class__.mro() - for m in cls.__dict__ - if m[0] != '_' and m not in self._member_map_ - ] + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' and m not in self._member_map_ + ] return (['__class__', '__doc__', '__module__', ] + added_behavior) temp_enum_dict['__dir__'] = __dir__ del __dir__ @@ -697,14 +689,13 @@ except ImportError: temp_enum_dict['__format__'] = __format__ del __format__ - #################################### # Python's less than 2.6 use __cmp__ if pyver < 2.6: def __cmp__(self, other): - if type(other) is self.__class__: + if isinstance(other, self.__class__): if self is other: return 0 return -1 @@ -735,16 +726,15 @@ except ImportError: temp_enum_dict['__gt__'] = __gt__ del __gt__ - def __eq__(self, other): - if type(other) is self.__class__: + if isinstance(other, self.__class__): return self is other return NotImplemented temp_enum_dict['__eq__'] = __eq__ del __eq__ def __ne__(self, other): - if type(other) is self.__class__: + if isinstance(other, self.__class__): return self is not other return NotImplemented temp_enum_dict['__ne__'] = __ne__ @@ -832,9 +822,9 @@ except ImportError: duplicates.append((name, member.name)) if duplicates: duplicate_names = ', '.join( - ["%s -> %s" % (alias, name) for (alias, name) in duplicates] - ) + ["%s -> %s" % (alias, name) for (alias, name) in duplicates] + ) raise ValueError('duplicate names found in %r: %s' % - (enumeration, duplicate_names) - ) + (enumeration, duplicate_names) + ) return enumeration diff --git a/setup.py b/setup.py index 88af1f62..a4923e8d 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,8 @@ setup( 'sqlalchemy', 'sqlalchemy_utils', 'mock', + # Required for Django postgres fields testing + 'psycopg2', ], extras_require={ 'django': [