Fix issues for Mypy 0.950 and Pydantic 1.9.0 (#10786)

* Make changes to typing

* Correction

* Format with black

* Corrections based on review

* Bumped Thinc dependency version

* Bumped blis requirement

* Correction for older Python versions

* Update spacy/ml/models/textcat.py

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>

* Corrections based on review feedback

* Re-add deleted docstring line

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
Richard Hudson 2022-05-25 09:33:54 +02:00 committed by GitHub
parent 6be09bbd07
commit 32954c3bcb
16 changed files with 63 additions and 61 deletions

pyproject.toml

@@ -5,8 +5,8 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=8.0.14,<8.1.0",
-    "blis>=0.4.0,<0.8.0",
+    "thinc>=8.1.0.dev0,<8.2.0",
+    "blis>=0.9.0,<0.10.0",
     "pathy",
     "numpy>=1.15.0",
 ]

requirements.txt

@@ -3,8 +3,8 @@ spacy-legacy>=3.0.9,<3.1.0
 spacy-loggers>=1.0.0,<2.0.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=8.0.14,<8.1.0
-blis>=0.4.0,<0.8.0
+thinc>=8.1.0.dev0,<8.2.0
+blis>=0.9.0,<0.10.0
 ml_datasets>=0.2.0,<0.3.0
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.9.1,<1.1.0
@@ -16,7 +16,7 @@ pathy>=0.3.5
 numpy>=1.15.0
 requests>=2.13.0,<3.0.0
 tqdm>=4.38.0,<5.0.0
-pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
 jinja2
 langcodes>=3.2.0,<4.0.0
 # Official Python utilities
@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<3.10.0
 hypothesis>=3.27.0,<7.0.0
-mypy==0.910
+mypy>=0.910,<=0.960
 types-dataclasses>=0.1.3; python_version < "3.7"
 types-mock>=0.1.1
 types-requests
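Note: the thinc and blis pins changed in lockstep here, in pyproject.toml above, and in setup.cfg below; the three copies must stay in sync. A quick way to compare what is actually installed against the new pins (a sketch, not part of this diff; uses stdlib importlib.metadata, Python 3.8+):

    from importlib.metadata import version

    for pkg in ("thinc", "blis", "pydantic", "mypy"):
        # Compare each printed version against the pins above by eye.
        print(pkg, version(pkg))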

setup.cfg

@@ -38,7 +38,7 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc>=8.0.14,<8.1.0
+    thinc>=8.1.0.dev0,<8.2.0
 install_requires =
     # Our libraries
     spacy-legacy>=3.0.9,<3.1.0
@@ -46,8 +46,8 @@ install_requires =
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=8.0.14,<8.1.0
-    blis>=0.4.0,<0.8.0
+    thinc>=8.1.0.dev0,<8.2.0
+    blis>=0.9.0,<0.10.0
     wasabi>=0.9.1,<1.1.0
     srsly>=2.4.3,<3.0.0
     catalogue>=2.0.6,<2.1.0
@@ -57,7 +57,7 @@ install_requires =
     tqdm>=4.38.0,<5.0.0
     numpy>=1.15.0
     requests>=2.13.0,<3.0.0
-    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
     jinja2
     # Official Python utilities
     setuptools

spacy/errors.py

@@ -1,4 +1,5 @@
 import warnings
+from .compat import Literal


 class ErrorsWithCodes(type):
@@ -26,7 +27,10 @@ def setup_default_warnings():
     filter_warning("once", error_msg="[W114]")


-def filter_warning(action: str, error_msg: str):
+def filter_warning(
+    action: Literal["default", "error", "ignore", "always", "module", "once"],
+    error_msg: str,
+):
     """Customize how spaCy should handle a certain warning.

     error_msg (str): e.g. "W006", or a full error message
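Note: the Literal annotation (imported via spacy.compat so it also works on Python 3.7) narrows action to the six values warnings.filterwarnings actually accepts, letting mypy reject bad call sites. A standalone sketch of the effect; the function body here is illustrative, not spaCy's exact implementation:

    import warnings
    from typing import Literal  # spaCy goes through .compat for Python 3.7

    def filter_warning(
        action: Literal["default", "error", "ignore", "always", "module", "once"],
        error_msg: str,
    ) -> None:
        warnings.filterwarnings(action, message=error_msg)  # illustrative body

    filter_warning("once", "W036")          # OK
    # filter_warning("sometimes", "W036")   # rejected by mypy: invalid literal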

spacy/lookups.py

@@ -85,7 +85,7 @@ class Table(OrderedDict):
         value: The value to set.
         """
         key = get_string_id(key)
-        OrderedDict.__setitem__(self, key, value)
+        OrderedDict.__setitem__(self, key, value)  # type:ignore[assignment]
         self.bloom.add(key)

     def set(self, key: Union[str, int], value: Any) -> None:
@@ -104,7 +104,7 @@ class Table(OrderedDict):
         RETURNS: The value.
         """
         key = get_string_id(key)
-        return OrderedDict.__getitem__(self, key)
+        return OrderedDict.__getitem__(self, key)  # type:ignore[index]

     def get(self, key: Union[str, int], default: Optional[Any] = None) -> Any:
         """Get the value for a given key. String keys will be hashed.
@@ -114,7 +114,7 @@ class Table(OrderedDict):
         RETURNS: The value.
         """
         key = get_string_id(key)
-        return OrderedDict.get(self, key, default)
+        return OrderedDict.get(self, key, default)  # type:ignore[arg-type]

     def __contains__(self, key: Union[str, int]) -> bool:  # type: ignore[override]
         """Check whether a key is in the table. String keys will be hashed.

spacy/ml/models/entity_linker.py

@@ -23,7 +23,7 @@ def build_nel_encoder(
         ((tok2vec >> list2ragged()) & build_span_maker())
         >> extract_spans()
         >> reduce_mean()
-        >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))  # type: ignore[arg-type]
+        >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
         >> output_layer
     )
     model.set_ref("output_layer", output_layer)

spacy/ml/models/parser.py

@@ -72,7 +72,7 @@ def build_tb_parser_model(
     t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
     tok2vec = chain(
         tok2vec,
-        cast(Model[List["Floats2d"], Floats2d], list2array()),
+        list2array(),
         Linear(hidden_width, t2v_width),
     )
     tok2vec.set_dim("nO", hidden_width)
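Note: this hunk and the entity_linker.py hunk above drop casts and ignores that only papered over Thinc 8.0's looser combinator signatures; with the new thinc>=8.1.0.dev0 pin, layers such as list2array, residual, and Maxout compose cleanly. A minimal sketch, assuming thinc>=8.1:

    from thinc.api import Linear, chain, list2array

    # Composes without cast(...) under Thinc 8.1's tighter signatures.
    model = chain(list2array(), Linear(nO=64, nI=96))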

spacy/ml/models/textcat.py

@@ -1,5 +1,5 @@
+from typing import Optional, List, cast
 from functools import partial
-from typing import Optional, List

 from thinc.types import Floats2d
 from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
@@ -59,7 +59,8 @@ def build_simple_cnn_text_classifier(
         resizable_layer=resizable_layer,
     )
     model.set_ref("tok2vec", tok2vec)
-    model.set_dim("nO", nO)  # type: ignore  # TODO: remove type ignore once Thinc has been updated
+    if nO is not None:
+        model.set_dim("nO", cast(int, nO))
     model.attrs["multi_label"] = not exclusive_classes
     return model
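Note: Thinc's Model.set_dim takes an int while nO is Optional[int], so the blanket ignore is replaced with an explicit None guard; the same pattern recurs twice in the hunks below. The shape of the fix in isolation (set_dim here is a stand-in for the Thinc method):

    from typing import Optional, cast

    def set_dim(name: str, value: int) -> None:  # stand-in for thinc Model.set_dim
        ...

    nO: Optional[int] = 8
    if nO is not None:
        set_dim("nO", cast(int, nO))  # cast mirrors the diff; the guard already narrows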
@@ -85,7 +86,7 @@ def build_bow_text_classifier(
     if not no_output_layer:
         fill_defaults["b"] = NEG_VALUE
         output_layer = softmax_activation() if exclusive_classes else Logistic()
-    resizable_layer = resizable(  # type: ignore[var-annotated]
+    resizable_layer: Model[Floats2d, Floats2d] = resizable(
         sparse_linear,
         resize_layer=partial(resize_linear_weighted, fill_defaults=fill_defaults),
     )
@@ -93,7 +94,8 @@ def build_bow_text_classifier(
     model = with_cpu(model, model.ops)
     if output_layer:
         model = model >> with_cpu(output_layer, output_layer.ops)
-    model.set_dim("nO", nO)  # type: ignore[arg-type]
+    if nO is not None:
+        model.set_dim("nO", cast(int, nO))
     model.set_ref("output_layer", sparse_linear)
     model.attrs["multi_label"] = not exclusive_classes
     model.attrs["resize_output"] = partial(
@@ -129,8 +131,8 @@ def build_text_classifier_v2(
     output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
     model = (linear_model | cnn_model) >> output_layer
     model.set_ref("tok2vec", tok2vec)
-    if model.has_dim("nO") is not False:
-        model.set_dim("nO", nO)  # type: ignore[arg-type]
+    if model.has_dim("nO") is not False and nO is not None:
+        model.set_dim("nO", cast(int, nO))
     model.set_ref("output_layer", linear_model.get_ref("output_layer"))
     model.set_ref("attention_layer", attention_layer)
     model.set_ref("maxout_layer", maxout_layer)
@@ -164,7 +166,7 @@ def build_text_classifier_lowdata(
         >> list2ragged()
         >> ParametricAttention(width)
         >> reduce_sum()
-        >> residual(Relu(width, width)) ** 2  # type: ignore[arg-type]
+        >> residual(Relu(width, width)) ** 2
         >> Linear(nO, width)
     )
     if dropout:

spacy/ml/models/tok2vec.py

@@ -1,5 +1,5 @@
 from typing import Optional, List, Union, cast
-from thinc.types import Floats2d, Ints2d, Ragged
+from thinc.types import Floats2d, Ints2d, Ragged, Ints1d
 from thinc.api import chain, clone, concatenate, with_array, with_padded
 from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
 from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
@@ -159,7 +159,7 @@ def MultiHashEmbed(
     embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
     max_out: Model[Ragged, Ragged] = with_array(
-        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)  # type: ignore
+        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
     )
     if include_static_vectors:
         feature_extractor: Model[List[Doc], Ragged] = chain(
if include_static_vectors:
feature_extractor: Model[List[Doc], Ragged] = chain(
@@ -173,7 +173,7 @@ def MultiHashEmbed(
                 StaticVectors(width, dropout=0.0),
             ),
             max_out,
-            cast(Model[Ragged, List[Floats2d]], ragged2list()),
+            ragged2list(),
         )
     else:
         model = chain(
else:
model = chain(
@@ -181,7 +181,7 @@ def MultiHashEmbed(
             cast(Model[List[Ints2d], Ragged], list2ragged()),
             with_array(concatenate(*embeddings)),
             max_out,
-            cast(Model[Ragged, List[Floats2d]], ragged2list()),
+            ragged2list(),
         )
     return model
return model
@@ -232,12 +232,12 @@ def CharacterEmbed(
     feature_extractor: Model[List[Doc], Ragged] = chain(
         FeatureExtractor([feature]),
         cast(Model[List[Ints2d], Ragged], list2ragged()),
-        with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),  # type: ignore
+        with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),  # type: ignore[misc]
     )
     max_out: Model[Ragged, Ragged]
     if include_static_vectors:
         max_out = with_array(
-            Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)  # type: ignore
+            Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)
         )
         model = chain(
             concatenate(
@@ -246,11 +246,11 @@ def CharacterEmbed(
                 StaticVectors(width, dropout=0.0),
             ),
             max_out,
-            cast(Model[Ragged, List[Floats2d]], ragged2list()),
+            ragged2list(),
         )
     else:
         max_out = with_array(
-            Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)  # type: ignore
+            Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)
         )
         model = chain(
             concatenate(
@@ -258,7 +258,7 @@ def CharacterEmbed(
                 feature_extractor,
             ),
             max_out,
-            cast(Model[Ragged, List[Floats2d]], ragged2list()),
+            ragged2list(),
         )
     return model
@@ -289,10 +289,10 @@ def MaxoutWindowEncoder(
             normalize=True,
         ),
     )
-    model = clone(residual(cnn), depth)  # type: ignore[arg-type]
+    model = clone(residual(cnn), depth)
     model.set_dim("nO", width)
     receptive_field = window_size * depth
-    return with_array(model, pad=receptive_field)  # type: ignore[arg-type]
+    return with_array(model, pad=receptive_field)


 @registry.architectures("spacy.MishWindowEncoder.v2")
@@ -313,9 +313,9 @@ def MishWindowEncoder(
         expand_window(window_size=window_size),
         Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
     )
-    model = clone(residual(cnn), depth)  # type: ignore[arg-type]
+    model = clone(residual(cnn), depth)
     model.set_dim("nO", width)
-    return with_array(model)  # type: ignore[arg-type]
+    return with_array(model)


 @registry.architectures("spacy.TorchBiLSTMEncoder.v1")

spacy/ml/staticvectors.py

@@ -40,17 +40,15 @@ def forward(
     if not token_count:
         return _handle_empty(model.ops, model.get_dim("nO"))
     key_attr: int = model.attrs["key_attr"]
-    keys: Ints1d = model.ops.flatten(
-        cast(Sequence, [doc.to_array(key_attr) for doc in docs])
-    )
+    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
     vocab: Vocab = docs[0].vocab
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
     if vocab.vectors.mode == Mode.default:
-        V = cast(Floats2d, model.ops.asarray(vocab.vectors.data))
+        V = model.ops.asarray(vocab.vectors.data)
         rows = vocab.vectors.find(keys=keys)
         V = model.ops.as_contig(V[rows])
     elif vocab.vectors.mode == Mode.floret:
-        V = cast(Floats2d, vocab.vectors.get_batch(keys))
+        V = vocab.vectors.get_batch(keys)
         V = model.ops.as_contig(V)
     else:
         raise RuntimeError(Errors.E896)
@@ -62,9 +60,7 @@ def forward(
     # Convert negative indices to 0-vectors
     # TODO: more options for UNK tokens
     vectors_data[rows < 0] = 0
-    output = Ragged(
-        vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")  # type: ignore
-    )
+    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
     mask = None
     if is_train:
         mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
@@ -77,7 +73,9 @@ def forward(
         model.inc_grad(
             "W",
             model.ops.gemm(
-                cast(Floats2d, d_output.data), model.ops.as_contig(V), trans1=True
+                cast(Floats2d, d_output.data),
+                cast(Floats2d, model.ops.as_contig(V)),
+                trans1=True,
             ),
         )
         return []
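Note: ops.asarray1i, and ops.asarray2f in the legacy entity-linker hunk below, are Thinc's dtype-typed array constructors; they return Ints1d/Floats2d directly, so the dtype="i" call plus ignore can go. The spancat hunk below gets the same treatment. Minimal sketch, assuming thinc>=8.1:

    from thinc.api import NumpyOps

    ops = NumpyOps()
    lengths = ops.asarray1i([3, 5, 2])     # Ints1d, dtype int32
    weights = ops.asarray2f([[0.5, 1.5]])  # Floats2d, dtype float32
    print(lengths.dtype, weights.dtype)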

spacy/pipeline/edit_tree_lemmatizer.py

@@ -138,7 +138,7 @@ class EditTreeLemmatizer(TrainablePipe):
             truths.append(eg_truths)

-        d_scores, loss = loss_func(scores, truths)  # type: ignore
+        d_scores, loss = loss_func(scores, truths)
         if self.model.ops.xp.isnan(loss):
             raise ValueError(Errors.E910.format(name=self.name))

spacy/pipeline/entityruler.py

@@ -159,10 +159,8 @@ class EntityRuler(Pipe):
         self._require_patterns()
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", message="\\[W036")
-            matches = cast(
-                List[Tuple[int, int, int]],
-                list(self.matcher(doc)) + list(self.phrase_matcher(doc)),
-            )
+            matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
         final_matches = set(
             [(m_id, start, end) for m_id, start, end in matches if start != end]
         )
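Note: the cast is droppable because a Matcher or PhraseMatcher call already yields (match_id, start, end) triples by default, so the concatenated list needs no annotation help. A usage sketch showing the output shape (assumes spaCy with a blank "en" pipeline):

    import spacy
    from spacy.matcher import Matcher

    nlp = spacy.blank("en")
    matcher = Matcher(nlp.vocab)
    matcher.add("GREETING", [[{"LOWER": "hello"}]])
    doc = nlp("hello world")
    print(list(matcher(doc)))  # [(match_id, start, end)]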

spacy/pipeline/legacy/entity_linker.py

@@ -213,15 +213,14 @@ class EntityLinker_v1(TrainablePipe):
             if kb_id:
                 entity_encoding = self.kb.get_vector(kb_id)
                 entity_encodings.append(entity_encoding)

-        entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+        entity_encodings = self.model.ops.asarray2f(entity_encodings)
         if sentence_encodings.shape != entity_encodings.shape:
             err = Errors.E147.format(
                 method="get_loss", msg="gold entities do not match up"
             )
             raise RuntimeError(err)
-        # TODO: fix typing issue here
-        gradients = self.distance.get_grad(sentence_encodings, entity_encodings)  # type: ignore
-        loss = self.distance.get_loss(sentence_encodings, entity_encodings)  # type: ignore
+        gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
+        loss = self.distance.get_loss(sentence_encodings, entity_encodings)
         loss = loss / len(entity_encodings)
         return float(loss), gradients

spacy/pipeline/spancat.py

@@ -75,7 +75,7 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
             if spans:
                 assert spans[-1].ndim == 2, spans[-1].shape
             lengths.append(length)
-        lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
+        lengths_array = ops.asarray1i(lengths)
         if len(spans) > 0:
             output = Ragged(ops.xp.vstack(spans), lengths_array)
         else:

spacy/schemas.py

@@ -104,7 +104,7 @@ def get_arg_model(
             sig_args[param.name] = (annotation, default)
     is_strict = strict and not has_variable
     sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra  # type: ignore[assignment]
-    return create_model(name, **sig_args)  # type: ignore[arg-type, return-value]
+    return create_model(name, **sig_args)  # type: ignore[call-overload, arg-type, return-value]


 def validate_init_settings(
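Note: Pydantic 1.9 added overloads to create_model, so newer mypy reports call-overload for this dynamic **kwargs call; keeping the old codes alongside the new one lets the whole pinned range pydantic>=1.7.4,<1.10.0 type-check. The dynamic call in isolation, assuming pydantic 1.x:

    from pydantic import create_model

    fields = {"text": (str, ...), "count": (int, 0)}
    ArgModel = create_model("ArgModel", **fields)  # ignore[call-overload] on mypy>=0.950
    print(ArgModel(text="hi").count)  # -> 0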

spacy/util.py

@@ -1,4 +1,4 @@
-from typing import List, Mapping, NoReturn, Union, Dict, Any, Set
+from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
 from typing import Optional, Iterable, Callable, Tuple, Type
 from typing import Iterator, Type, Pattern, Generator, TYPE_CHECKING
 from types import ModuleType
@@ -294,7 +294,7 @@ def find_matching_language(lang: str) -> Optional[str]:
     # Find out which language modules we have
     possible_languages = []
-    for modinfo in pkgutil.iter_modules(spacy.lang.__path__):  # type: ignore
+    for modinfo in pkgutil.iter_modules(spacy.lang.__path__):  # type: ignore[attr-defined]
         code = modinfo.name
         if code == "xx":
             # Temporarily make 'xx' into a valid language code
@@ -391,7 +391,8 @@ def get_module_path(module: ModuleType) -> Path:
     """
     if not hasattr(module, "__module__"):
         raise ValueError(Errors.E169.format(module=repr(module)))
-    return Path(sys.modules[module.__module__].__file__).parent
+    file_path = Path(cast(os.PathLike, sys.modules[module.__module__].__file__))
+    return file_path.parent
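Note: recent typeshed types ModuleType.__file__ as Optional[str] (namespace packages have none), so Path(...) no longer accepts it directly; the cast here, and the one in get_package_path below, assert the file is set. An equivalent runtime-guard version, as a sketch (module_parent is a hypothetical helper, not spaCy's):

    from pathlib import Path
    from types import ModuleType

    def module_parent(module: ModuleType) -> Path:
        if module.__file__ is None:  # e.g. a namespace package
            raise ValueError(f"module {module.__name__} has no __file__")
        return Path(module.__file__).parent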
def load_model(
@@ -878,7 +879,7 @@ def get_package_path(name: str) -> Path:
     # Here we're importing the module just to find it. This is worryingly
     # indirect, but it's otherwise very difficult to find the package.
     pkg = importlib.import_module(name)
-    return Path(pkg.__file__).parent
+    return Path(cast(Union[str, os.PathLike], pkg.__file__)).parent
def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
@@ -1675,7 +1676,7 @@ def packages_distributions() -> Dict[str, List[str]]:
     it's not available in the builtin importlib.metadata.
     """
     pkg_to_dist = defaultdict(list)
-    for dist in importlib_metadata.distributions():  # type: ignore[attr-defined]
+    for dist in importlib_metadata.distributions():
         for pkg in (dist.read_text("top_level.txt") or "").split():
             pkg_to_dist[pkg].append(dist.metadata["Name"])
     return dict(pkg_to_dist)
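Note: the ignore is obsolete because current importlib_metadata stubs cover distributions(). This helper is a backport; the stdlib only gained an equivalent in Python 3.10. Usage sketch on a new enough interpreter:

    from importlib.metadata import packages_distributions  # stdlib in Python 3.10+

    mapping = packages_distributions()  # e.g. {"yaml": ["PyYAML"], ...}
    print(mapping.get("srsly"))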