Fix mypy issues

This commit is contained in:
richardpaulhudson 2022-12-09 15:55:31 +01:00
parent edc483c379
commit 7cd9a99854
2 changed files with 44 additions and 23 deletions

View File

@ -163,7 +163,7 @@ class EditTreeLemmatizer(TrainablePipe):
for i, doc_d_tree_scores in enumerate(d_tree_scores):
eg_lowercasing_flags = lowercasing_flags[i]
eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func(
eg_lowercasing_flags, self.model.ops.asarray(lowercasing_truths[i])
eg_lowercasing_flags, self.model.ops.asarray2f(lowercasing_truths[i])
)
doc_d_scores = self.model.ops.xp.hstack(
[

View File

@ -1,8 +1,11 @@
from typing import cast, Dict
import pickle
import pytest
from hypothesis import given
import hypothesis.strategies as st
from spacy.pipeline.edit_tree_lemmatizer import EditTreeLemmatizer
from thinc.api import fix_random_seed
from thinc.types import Floats2d
from spacy import util
from spacy.lang.en import English
from spacy.language import Language
@ -52,21 +55,24 @@ def test_initialize_examples(lowercasing: bool):
# you shouldn't really call this more than once, but for testing it should be fine
nlp.initialize(get_examples=lambda: train_examples)
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: None)
nlp.initialize(get_examples=lambda: None) # type:ignore[arg-type, return-value]
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: train_examples[0])
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: [])
with pytest.raises(TypeError):
nlp.initialize(get_examples=train_examples)
nlp.initialize(get_examples=train_examples) # type:ignore[arg-type]
@pytest.mark.parametrize("lowercasing", [True, False])
def test_initialize_from_labels(lowercasing: bool):
nlp = Language()
lemmatizer = nlp.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
lemmatizer = cast(
EditTreeLemmatizer,
nlp.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
),
)
lemmatizer.min_tree_freq = 1
train_examples = []
@ -75,9 +81,12 @@ def test_initialize_from_labels(lowercasing: bool):
nlp.initialize(get_examples=lambda: train_examples)
nlp2 = Language()
lemmatizer2 = nlp2.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
lemmatizer2 = cast(
EditTreeLemmatizer,
nlp2.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
),
)
lemmatizer2.initialize(
# We want to check that the strings in replacement nodes are
@ -170,9 +179,12 @@ def test_no_data(lowercasing: bool):
def test_incomplete_data(lowercasing: bool):
# Test that the lemmatizer works with incomplete information
nlp = English()
lemmatizer = nlp.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
lemmatizer = cast(
EditTreeLemmatizer,
nlp.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
),
)
lemmatizer.min_tree_freq = 1
fix_random_seed(0)
@ -181,7 +193,7 @@ def test_incomplete_data(lowercasing: bool):
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
losses: Dict[Floats2d, Floats2d] = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# test the trained model
@ -194,9 +206,12 @@ def test_incomplete_data(lowercasing: bool):
@pytest.mark.parametrize("lowercasing", [True, False])
def test_overfitting_IO(lowercasing: bool):
nlp = English()
lemmatizer = nlp.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
lemmatizer = cast(
EditTreeLemmatizer,
nlp.add_pipe(
"trainable_lemmatizer",
config={"model": {"lowercasing": lowercasing}},
),
)
lemmatizer.min_tree_freq = 1
fix_random_seed(0)
@ -207,7 +222,7 @@ def test_overfitting_IO(lowercasing: bool):
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
losses: Dict[Floats2d, Floats2d] = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
test_text = "She likes blue eggs"
@ -261,8 +276,11 @@ def test_lemmatizer_requires_labels():
@pytest.mark.parametrize("lowercasing", [True, False])
def test_lemmatizer_label_data(lowercasing: bool):
nlp = English()
lemmatizer = nlp.add_pipe(
"trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
lemmatizer = cast(
EditTreeLemmatizer,
nlp.add_pipe(
"trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
),
)
lemmatizer.min_tree_freq = 1
train_examples = []
@ -273,7 +291,7 @@ def test_lemmatizer_label_data(lowercasing: bool):
assert len(lemmatizer.trees) == 3 if lowercasing else 7
nlp2 = English()
lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
lemmatizer2 = cast(EditTreeLemmatizer, nlp2.add_pipe("trainable_lemmatizer"))
lemmatizer2.initialize(
get_examples=lambda: train_examples, labels=lemmatizer.label_data
)
@ -380,8 +398,11 @@ def test_empty_strings():
@pytest.mark.parametrize("lowercasing", [True, False])
def test_lowercasing(lowercasing: bool):
nlp = English()
lemmatizer = nlp.add_pipe(
"trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
lemmatizer = cast(
EditTreeLemmatizer,
nlp.add_pipe(
"trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
),
)
lemmatizer.min_tree_freq = 1
fix_random_seed(0)
@ -390,7 +411,7 @@ def test_lowercasing(lowercasing: bool):
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for _ in range(50):
losses = {}
losses: Dict[Floats2d, Floats2d] = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# test the trained model