Use black in dataclasses

This commit is contained in:
Syrus 2020-03-14 14:25:22 -07:00
parent 015bfa3452
commit c17103005c

View File

@ -9,21 +9,21 @@ import types
import inspect
import keyword
__all__ = ['dataclass',
'field',
'Field',
'FrozenInstanceError',
'InitVar',
'MISSING',
__all__ = [
"dataclass",
"field",
"Field",
"FrozenInstanceError",
"InitVar",
"MISSING",
# Helper functions.
'fields',
'asdict',
'astuple',
'make_dataclass',
'replace',
'is_dataclass',
]
"fields",
"asdict",
"astuple",
"make_dataclass",
"replace",
"is_dataclass",
]
# Conditions for adding methods. The boxes indicate what action the
# dataclass decorator takes. For all of these tables, when I talk
@ -152,20 +152,26 @@ __all__ = ['dataclass',
# Raised when an attempt is made to modify a frozen class.
class FrozenInstanceError(AttributeError): pass
class FrozenInstanceError(AttributeError):
pass
# A sentinel object for default values to signal that a default
# factory will be used. This is given a nice repr() which will appear
# in the function signature of dataclasses' constructors.
class _HAS_DEFAULT_FACTORY_CLASS:
def __repr__(self):
return '<factory>'
return "<factory>"
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
# A sentinel object to detect if a parameter is supplied or not. Use
# a class to give it a better repr.
class _MISSING_TYPE:
pass
MISSING = _MISSING_TYPE()
# Since most per-field metadata will be unused, create an empty
@ -176,33 +182,38 @@ _EMPTY_METADATA = types.MappingProxyType({})
class _FIELD_BASE:
def __init__(self, name):
self.name = name
def __repr__(self):
return self.name
_FIELD = _FIELD_BASE('_FIELD')
_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR')
_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR')
_FIELD = _FIELD_BASE("_FIELD")
_FIELD_CLASSVAR = _FIELD_BASE("_FIELD_CLASSVAR")
_FIELD_INITVAR = _FIELD_BASE("_FIELD_INITVAR")
# The name of an attribute on the class where we store the Field
# objects. Also used to check if a class is a Data Class.
_FIELDS = '__dataclass_fields__'
_FIELDS = "__dataclass_fields__"
# The name of an attribute on the class that stores the parameters to
# @dataclass.
_PARAMS = '__dataclass_params__'
_PARAMS = "__dataclass_params__"
# The name of the function, that if it exists, is called at the end of
# __init__.
_POST_INIT_NAME = '__post_init__'
_POST_INIT_NAME = "__post_init__"
# String regex that string annotations for ClassVar or InitVar must match.
# Allows "identifier.identifier[" or "identifier[".
# https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')
_MODULE_IDENTIFIER_RE = re.compile(r"^(?:\s*(\w+)\s*\.)?\s*(\w+)")
class _InitVarMeta(type):
def __getitem__(self, params):
return self
class InitVar(metaclass=_InitVarMeta):
pass
@ -218,20 +229,20 @@ class InitVar(metaclass=_InitVarMeta):
# When cls._FIELDS is filled in with a list of Field objects, the name
# and type fields will have been populated.
class Field:
__slots__ = ('name',
'type',
'default',
'default_factory',
'repr',
'hash',
'init',
'compare',
'metadata',
'_field_type', # Private: not to be used by user code.
__slots__ = (
"name",
"type",
"default",
"default_factory",
"repr",
"hash",
"init",
"compare",
"metadata",
"_field_type", # Private: not to be used by user code.
)
def __init__(self, default, default_factory, init, repr, hash, compare,
metadata):
def __init__(self, default, default_factory, init, repr, hash, compare, metadata):
self.name = None
self.type = None
self.default = default
@ -240,24 +251,28 @@ class Field:
self.repr = repr
self.hash = hash
self.compare = compare
self.metadata = (_EMPTY_METADATA
if metadata is None or len(metadata) == 0 else
types.MappingProxyType(metadata))
self.metadata = (
_EMPTY_METADATA
if metadata is None or len(metadata) == 0
else types.MappingProxyType(metadata)
)
self._field_type = None
def __repr__(self):
return ('Field('
f'name={self.name!r},'
f'type={self.type!r},'
f'default={self.default!r},'
f'default_factory={self.default_factory!r},'
f'init={self.init!r},'
f'repr={self.repr!r},'
f'hash={self.hash!r},'
f'compare={self.compare!r},'
f'metadata={self.metadata!r},'
f'_field_type={self._field_type}'
')')
return (
"Field("
f"name={self.name!r},"
f"type={self.type!r},"
f"default={self.default!r},"
f"default_factory={self.default_factory!r},"
f"init={self.init!r},"
f"repr={self.repr!r},"
f"hash={self.hash!r},"
f"compare={self.compare!r},"
f"metadata={self.metadata!r},"
f"_field_type={self._field_type}"
")"
)
# This is used to support the PEP 487 __set_name__ protocol in the
# case where we're using a field that contains a descriptor as a
@ -268,7 +283,7 @@ class Field:
# with the default value, so the end result is a descriptor that
# had __set_name__ called on it at the right time.
def __set_name__(self, owner, name):
func = getattr(type(self.default), '__set_name__', None)
func = getattr(type(self.default), "__set_name__", None)
if func:
# There is a __set_name__ method on the descriptor, call
# it.
@ -276,12 +291,13 @@ class Field:
class _DataclassParams:
__slots__ = ('init',
'repr',
'eq',
'order',
'unsafe_hash',
'frozen',
__slots__ = (
"init",
"repr",
"eq",
"order",
"unsafe_hash",
"frozen",
)
def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
@ -293,21 +309,31 @@ class _DataclassParams:
self.frozen = frozen
def __repr__(self):
return ('_DataclassParams('
f'init={self.init!r},'
f'repr={self.repr!r},'
f'eq={self.eq!r},'
f'order={self.order!r},'
f'unsafe_hash={self.unsafe_hash!r},'
f'frozen={self.frozen!r}'
')')
return (
"_DataclassParams("
f"init={self.init!r},"
f"repr={self.repr!r},"
f"eq={self.eq!r},"
f"order={self.order!r},"
f"unsafe_hash={self.unsafe_hash!r},"
f"frozen={self.frozen!r}"
")"
)
# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
hash=None, compare=True, metadata=None):
def field(
*,
default=MISSING,
default_factory=MISSING,
init=True,
repr=True,
hash=None,
compare=True,
metadata=None,
):
"""Return an object to identify dataclass fields.
default is the default value of the field. default_factory is a
@ -323,9 +349,8 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
"""
if default is not MISSING and default_factory is not MISSING:
raise ValueError('cannot specify both default and default_factory')
return Field(default, default_factory, init, repr, hash, compare,
metadata)
raise ValueError("cannot specify both default and default_factory")
return Field(default, default_factory, init, repr, hash, compare, metadata)
def _tuple_str(obj_name, fields):
@ -335,27 +360,26 @@ def _tuple_str(obj_name, fields):
# Special case for the 0-tuple.
if not fields:
return '()'
return "()"
# Note the trailing comma, needed if this turns out to be a 1-tuple.
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
def _create_fn(name, args, body, *, globals=None, locals=None,
return_type=MISSING):
def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING):
# Note that we mutate locals when exec() is called. Caller
# beware! The only callers are internal to this module, so no
# worries about external callers.
if locals is None:
locals = {}
return_annotation = ''
return_annotation = ""
if return_type is not MISSING:
locals['_return_type'] = return_type
return_annotation = '->_return_type'
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
locals["_return_type"] = return_type
return_annotation = "->_return_type"
args = ",".join(args)
body = "\n".join(f" {b}" for b in body)
# Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'
txt = f"def {name}({args}){return_annotation}:\n{body}"
exec(txt, globals, locals)
return locals[name]
@ -369,23 +393,25 @@ def _field_assign(frozen, name, value, self_name):
# self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name.
if frozen:
return f'object.__setattr__({self_name},{name!r},{value})'
return f'{self_name}.{name}={value}'
return f"object.__setattr__({self_name},{name!r},{value})"
return f"{self_name}.{name}={value}"
def _field_init(f, frozen, globals, self_name):
# Return the text of the line in the body of __init__ that will
# initialize this field.
default_name = f'_dflt_{f.name}'
default_name = f"_dflt_{f.name}"
if f.default_factory is not MISSING:
if f.init:
# This field has a default factory. If a parameter is
# given, use it. If not, call the factory.
globals[default_name] = f.default_factory
value = (f'{default_name}() '
f'if {f.name} is _HAS_DEFAULT_FACTORY '
f'else {f.name}')
value = (
f"{default_name}() "
f"if {f.name} is _HAS_DEFAULT_FACTORY "
f"else {f.name}"
)
else:
# This is a field that's not in the __init__ params, but
# has a default factory function. It needs to be
@ -402,7 +428,7 @@ def _field_init(f, frozen, globals, self_name):
# (which, after all, is why we have a factory function!).
globals[default_name] = f.default_factory
value = f'{default_name}()'
value = f"{default_name}()"
else:
# No default factory.
if f.init:
@ -435,15 +461,15 @@ def _init_param(f):
if f.default is MISSING and f.default_factory is MISSING:
# There's no default, and no default_factory, just output the
# variable name and type.
default = ''
default = ""
elif f.default is not MISSING:
# There's a default, this will be the name that's used to look
# it up.
default = f'=_dflt_{f.name}'
default = f"=_dflt_{f.name}"
elif f.default_factory is not MISSING:
# There's a factory function. Set a marker.
default = '=_HAS_DEFAULT_FACTORY'
return f'{f.name}:_type_{f.name}{default}'
default = "=_HAS_DEFAULT_FACTORY"
return f"{f.name}:_type_{f.name}{default}"
def _init_fn(fields, frozen, has_post_init, self_name):
@ -461,11 +487,11 @@ def _init_fn(fields, frozen, has_post_init, self_name):
if not (f.default is MISSING and f.default_factory is MISSING):
seen_default = True
elif seen_default:
raise TypeError(f'non-default argument {f.name!r} '
'follows default argument')
raise TypeError(
f"non-default argument {f.name!r} " "follows default argument"
)
globals = {'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
globals = {"MISSING": MISSING, "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY}
body_lines = []
for f in fields:
@ -477,54 +503,66 @@ def _init_fn(fields, frozen, has_post_init, self_name):
# Does this class have a post-init function?
if has_post_init:
params_str = ','.join(f.name for f in fields
if f._field_type is _FIELD_INITVAR)
body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')
params_str = ",".join(f.name for f in fields if f._field_type is _FIELD_INITVAR)
body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})")
# If no body lines, use 'pass'.
if not body_lines:
body_lines = ['pass']
body_lines = ["pass"]
locals = {f'_type_{f.name}': f.type for f in fields}
return _create_fn('__init__',
locals = {f"_type_{f.name}": f.type for f in fields}
return _create_fn(
"__init__",
[self_name] + [_init_param(f) for f in fields if f.init],
body_lines,
locals=locals,
globals=globals,
return_type=None)
return_type=None,
)
def _repr_fn(fields):
return _create_fn('__repr__',
('self',),
['return self.__class__.__qualname__ + f"(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'])
return _create_fn(
"__repr__",
("self",),
[
'return self.__class__.__qualname__ + f"('
+ ", ".join([f"{f.name}={{self.{f.name}!r}}" for f in fields])
+ ')"'
],
)
def _frozen_get_del_attr(cls, fields):
# XXX: globals is modified on the first call to _create_fn, then
# the modified version is used in the second call. Is this okay?
globals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
globals = {"cls": cls, "FrozenInstanceError": FrozenInstanceError}
if fields:
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
fields_str = "(" + ",".join(repr(f.name) for f in fields) + ",)"
else:
# Special case for the zero-length tuple.
fields_str = '()'
return (_create_fn('__setattr__',
('self', 'name', 'value'),
(f'if type(self) is cls or name in {fields_str}:',
fields_str = "()"
return (
_create_fn(
"__setattr__",
("self", "name", "value"),
(
f"if type(self) is cls or name in {fields_str}:",
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if type(self) is cls or name in {fields_str}:',
f"super(cls, self).__setattr__(name, value)",
),
globals=globals,
),
_create_fn(
"__delattr__",
("self", "name"),
(
f"if type(self) is cls or name in {fields_str}:",
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
globals=globals),
f"super(cls, self).__delattr__(name)",
),
globals=globals,
),
)
@ -534,18 +572,20 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
# '(self.x,self.y)' and other_tuple is the string
# '(other.x,other.y)'.
return _create_fn(name,
('self', 'other'),
[ 'if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
'return NotImplemented'])
return _create_fn(
name,
("self", "other"),
[
"if other.__class__ is self.__class__:",
f" return {self_tuple}{op}{other_tuple}",
"return NotImplemented",
],
)
def _hash_fn(fields):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
('self',),
[f'return hash({self_tuple})'])
self_tuple = _tuple_str("self", fields)
return _create_fn("__hash__", ("self",), [f"return hash({self_tuple})"])
def _is_classvar(a_type, typing):
@ -658,12 +698,12 @@ def _get_field(cls, a_name, a_type):
# annotation to be a ClassVar. So, only look for ClassVar if
# typing has been imported by any module (not necessarily cls's
# module).
typing = sys.modules.get('typing')
typing = sys.modules.get("typing")
if typing:
if (_is_classvar(a_type, typing)
or (isinstance(f.type, str)
and _is_type(f.type, cls, typing, typing.ClassVar,
_is_classvar))):
if _is_classvar(a_type, typing) or (
isinstance(f.type, str)
and _is_type(f.type, cls, typing, typing.ClassVar, _is_classvar)
):
f._field_type = _FIELD_CLASSVAR
# If the type is InitVar, or if it's a matching string annotation,
@ -672,10 +712,10 @@ def _get_field(cls, a_name, a_type):
# The module we're checking against is the module we're
# currently in (dataclasses.py).
dataclasses = sys.modules[__name__]
if (_is_initvar(a_type, dataclasses)
or (isinstance(f.type, str)
and _is_type(f.type, cls, dataclasses, dataclasses.InitVar,
_is_initvar))):
if _is_initvar(a_type, dataclasses) or (
isinstance(f.type, str)
and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, _is_initvar)
):
f._field_type = _FIELD_INITVAR
# Validations for individual fields. This is delayed until now,
@ -685,8 +725,7 @@ def _get_field(cls, a_name, a_type):
# Special restrictions for ClassVar and InitVar.
if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
if f.default_factory is not MISSING:
raise TypeError(f'field {f.name} cannot have a '
'default factory')
raise TypeError(f"field {f.name} cannot have a " "default factory")
# Should I check for other field settings? default_factory
# seems the most serious to check for. Maybe add others. For
# example, how about init=False (or really,
@ -695,8 +734,10 @@ def _get_field(cls, a_name, a_type):
# For real fields, disallow mutable defaults for known types.
if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
raise ValueError(f'mutable default {type(f.default)} for field '
f'{f.name} is not allowed: use default_factory')
raise ValueError(
f"mutable default {type(f.default)} for field "
f"{f.name} is not allowed: use default_factory"
)
return f
@ -715,17 +756,20 @@ def _set_new_attribute(cls, name, value):
# take. The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.
def _hash_set_none(cls, fields):
return None
def _hash_add(cls, fields):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
return _hash_fn(flds)
def _hash_exception(cls, fields):
# Raise an exception.
raise TypeError(f'Cannot overwrite attribute __hash__ '
f'in class {cls.__name__}')
raise TypeError(f"Cannot overwrite attribute __hash__ " f"in class {cls.__name__}")
#
# +-------------------------------------- unsafe_hash?
@ -736,23 +780,24 @@ def _hash_exception(cls, fields):
# | | | | +------- action
# | | | | |
# v v v v v
_hash_action = {(False, False, False, False): None,
(False, False, False, True ): None,
_hash_action = {
(False, False, False, False): None,
(False, False, False, True): None,
(False, False, True, False): None,
(False, False, True, True ): None,
(False, False, True, True): None,
(False, True, False, False): _hash_set_none,
(False, True, False, True ): None,
(False, True, False, True): None,
(False, True, True, False): _hash_add,
(False, True, True, True ): None,
(False, True, True, True): None,
(True, False, False, False): _hash_add,
(True, False, False, True ): _hash_exception,
(True, False, False, True): _hash_exception,
(True, False, True, False): _hash_add,
(True, False, True, True ): _hash_exception,
(True, False, True, True): _hash_exception,
(True, True, False, False): _hash_add,
(True, True, False, True ): _hash_exception,
(True, True, False, True): _hash_exception,
(True, True, True, False): _hash_add,
(True, True, True, True ): _hash_exception,
}
(True, True, True, True): _hash_exception,
}
# See https://bugs.python.org/issue32929#msg312829 for an if-statement
# version of this table.
@ -764,8 +809,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# is defined by the base class, which is found first.
fields = {}
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
unsafe_hash, frozen))
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, unsafe_hash, frozen))
# Find our base classes in reverse MRO order, and exclude
# ourselves. In reversed order so that more derived classes
@ -796,13 +840,12 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# actual default value. Pseudo-fields ClassVars and InitVars are
# included, despite the fact that they're not real fields. That's
# dealt with later.
cls_annotations = cls.__dict__.get('__annotations__', {})
cls_annotations = cls.__dict__.get("__annotations__", {})
# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes) where
# we can.
cls_fields = [_get_field(cls, name, type)
for name, type in cls_annotations.items()]
cls_fields = [_get_field(cls, name, type) for name, type in cls_annotations.items()]
for f in cls_fields:
fields[f.name] = f
@ -825,19 +868,17 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# Do we have any Field members that don't also have annotations?
for name, value in cls.__dict__.items():
if isinstance(value, Field) and not name in cls_annotations:
raise TypeError(f'{name!r} is a field but has no type annotation')
raise TypeError(f"{name!r} is a field but has no type annotation")
# Check rules that apply if we are derived from any dataclasses.
if has_dataclass_bases:
# Raise an exception if any of our bases are frozen, but we're not.
if any_frozen_base and not frozen:
raise TypeError('cannot inherit non-frozen dataclass from a '
'frozen one')
raise TypeError("cannot inherit non-frozen dataclass from a " "frozen one")
# Raise an exception if we're frozen, but none of our bases are.
if not any_frozen_base and frozen:
raise TypeError('cannot inherit frozen dataclass from a '
'non-frozen one')
raise TypeError("cannot inherit frozen dataclass from a " "non-frozen one")
# Remember all of the fields on our class (including bases). This
# also marks this class as being a dataclass.
@ -848,32 +889,35 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# set __hash__ to None. This is a heuristic, as it's possible
# that such a __hash__ == None was not auto-generated, but it
# close enough.
class_hash = cls.__dict__.get('__hash__', MISSING)
has_explicit_hash = not (class_hash is MISSING or
(class_hash is None and '__eq__' in cls.__dict__))
class_hash = cls.__dict__.get("__hash__", MISSING)
has_explicit_hash = not (
class_hash is MISSING or (class_hash is None and "__eq__" in cls.__dict__)
)
# If we're generating ordering methods, we must be generating the
# eq methods.
if order and not eq:
raise ValueError('eq must be true if order is true')
raise ValueError("eq must be true if order is true")
if init:
# Does this class have a post-init function?
has_post_init = hasattr(cls, _POST_INIT_NAME)
# Include InitVars and regular fields (so, not ClassVars).
flds = [f for f in fields.values()
if f._field_type in (_FIELD, _FIELD_INITVAR)]
_set_new_attribute(cls, '__init__',
_init_fn(flds,
flds = [f for f in fields.values() if f._field_type in (_FIELD, _FIELD_INITVAR)]
_set_new_attribute(
cls,
"__init__",
_init_fn(
flds,
frozen,
has_post_init,
# The name to use for the "self"
# param in __init__. Use "self"
# if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
))
"__dataclass_self__" if "self" in fields else "self",
),
)
# Get the fields as a list, and include only real fields. This is
# used in all of the following methods.
@ -881,54 +925,58 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
if repr:
flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
_set_new_attribute(cls, "__repr__", _repr_fn(flds))
if eq:
# Create _eq__ method. There's no need for a __ne__ method,
# since python will call __eq__ and negate it.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
_set_new_attribute(cls, '__eq__',
_cmp_fn('__eq__', '==',
self_tuple, other_tuple))
self_tuple = _tuple_str("self", flds)
other_tuple = _tuple_str("other", flds)
_set_new_attribute(
cls, "__eq__", _cmp_fn("__eq__", "==", self_tuple, other_tuple)
)
if order:
# Create and set the ordering methods.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
self_tuple = _tuple_str("self", flds)
other_tuple = _tuple_str("other", flds)
for name, op in [
("__lt__", "<"),
("__le__", "<="),
("__gt__", ">"),
("__ge__", ">="),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}. Consider using '
'functools.total_ordering')
if _set_new_attribute(
cls, name, _cmp_fn(name, op, self_tuple, other_tuple)
):
raise TypeError(
f"Cannot overwrite attribute {name} "
f"in class {cls.__name__}. Consider using "
"functools.total_ordering"
)
if frozen:
for fn in _frozen_get_del_attr(cls, field_list):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
raise TypeError(
f"Cannot overwrite attribute {fn.__name__} "
f"in class {cls.__name__}"
)
# Decide if/how we're going to create a hash function.
hash_action = _hash_action[bool(unsafe_hash),
bool(eq),
bool(frozen),
has_explicit_hash]
hash_action = _hash_action[
bool(unsafe_hash), bool(eq), bool(frozen), has_explicit_hash
]
if hash_action:
# No need to call _set_new_attribute here, since by the time
# we're here the overwriting is unconditional.
cls.__hash__ = hash_action(cls, field_list)
if not getattr(cls, '__doc__'):
if not getattr(cls, "__doc__"):
# Create a class doc-string.
cls.__doc__ = (cls.__name__ +
str(inspect.signature(cls)).replace(' -> None', ''))
cls.__doc__ = cls.__name__ + str(inspect.signature(cls)).replace(" -> None", "")
return cls
@ -936,8 +984,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# _cls should never be specified by keyword, so start it with an
# underscore. The presence of _cls is used to detect if this
# decorator is being called with parameters or not.
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
unsafe_hash=False, frozen=False):
def dataclass(
_cls=None,
*,
init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False,
):
"""Returns the same class as was passed in, with dunder methods
added based on the fields defined in the class.
@ -973,7 +1029,7 @@ def fields(class_or_instance):
try:
fields = getattr(class_or_instance, _FIELDS)
except AttributeError:
raise TypeError('must be called with a dataclass type or instance')
raise TypeError("must be called with a dataclass type or instance")
# Exclude pseudo-fields. Note that fields is sorted by insertion
# order, so the order of the tuple is as the fields were defined.
@ -1025,8 +1081,10 @@ def _asdict_inner(obj, dict_factory):
elif isinstance(obj, (list, tuple)):
return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
elif isinstance(obj, dict):
return type(obj)((_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
for k, v in obj.items())
return type(obj)(
(_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
for k, v in obj.items()
)
else:
return copy.deepcopy(obj)
@ -1065,15 +1123,27 @@ def _astuple_inner(obj, tuple_factory):
elif isinstance(obj, (list, tuple)):
return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
elif isinstance(obj, dict):
return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
for k, v in obj.items())
return type(obj)(
(_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
for k, v in obj.items()
)
else:
return copy.deepcopy(obj)
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
repr=True, eq=True, order=False, unsafe_hash=False,
frozen=False):
def make_dataclass(
cls_name,
fields,
*,
bases=(),
namespace=None,
init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False,
):
"""Return a new dynamically created dataclass.
The dataclass name will be 'cls_name'. 'fields' is an iterable
@ -1110,31 +1180,38 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
for item in fields:
if isinstance(item, str):
name = item
tp = 'typing.Any'
tp = "typing.Any"
elif len(item) == 2:
name, tp, = item
elif len(item) == 3:
name, tp, spec = item
namespace[name] = spec
else:
raise TypeError(f'Invalid field: {item!r}')
raise TypeError(f"Invalid field: {item!r}")
if not isinstance(name, str) or not name.isidentifier():
raise TypeError(f'Field names must be valid identifers: {name!r}')
raise TypeError(f"Field names must be valid identifers: {name!r}")
if keyword.iskeyword(name):
raise TypeError(f'Field names must not be keywords: {name!r}')
raise TypeError(f"Field names must not be keywords: {name!r}")
if name in seen:
raise TypeError(f'Field name duplicated: {name!r}')
raise TypeError(f"Field name duplicated: {name!r}")
seen.add(name)
anns[name] = tp
namespace['__annotations__'] = anns
namespace["__annotations__"] = anns
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
# of generic dataclassses.
cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen)
return dataclass(
cls,
init=init,
repr=repr,
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen,
)
def replace(obj, **changes):
@ -1165,9 +1242,11 @@ def replace(obj, **changes):
if not f.init:
# Error if this field is specified in changes.
if f.name in changes:
raise ValueError(f'field {f.name} is declared with '
'init=False, it cannot be specified with '
'replace()')
raise ValueError(
f"field {f.name} is declared with "
"init=False, it cannot be specified with "
"replace()"
)
continue
if f.name not in changes: