Fix mypy issues

richardpaulhudson 2022-12-09 15:55:31 +01:00
parent edc483c379
commit 7cd9a99854
2 changed files with 44 additions and 23 deletions

spacy/pipeline/edit_tree_lemmatizer.py

@@ -163,7 +163,7 @@ class EditTreeLemmatizer(TrainablePipe):
         for i, doc_d_tree_scores in enumerate(d_tree_scores):
             eg_lowercasing_flags = lowercasing_flags[i]
             eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func(
-                eg_lowercasing_flags, self.model.ops.asarray(lowercasing_truths[i])
+                eg_lowercasing_flags, self.model.ops.asarray2f(lowercasing_truths[i])
             )
             doc_d_scores = self.model.ops.xp.hstack(
                 [
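
Note: the hunk above swaps the generic Ops.asarray for the shape- and dtype-specific Ops.asarray2f, which thinc types as returning Floats2d, presumably so the argument type matches what lowercasing_loss_func is declared to take. A minimal sketch of the difference, outside the commit and using thinc directly:

# Minimal sketch (not part of the commit): asarray2f is typed to return
# Floats2d, while the generic asarray has a much looser return type, which
# is presumably what mypy flagged here.
from thinc.api import NumpyOps
from thinc.types import Floats2d

ops = NumpyOps()
truths = [[0.0, 1.0], [1.0, 0.0]]
flags: Floats2d = ops.asarray2f(truths)  # float32 2d array, accepted by mypy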

spacy/tests/pipeline/test_edit_tree_lemmatizer.py

@@ -1,8 +1,11 @@
+from typing import cast, Dict
 import pickle
 import pytest
 from hypothesis import given
 import hypothesis.strategies as st
+from spacy.pipeline.edit_tree_lemmatizer import EditTreeLemmatizer
 from thinc.api import fix_random_seed
+from thinc.types import Floats2d
 from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
@@ -52,21 +55,24 @@ def test_initialize_examples(lowercasing: bool):
     # you shouldn't really call this more than once, but for testing it should be fine
     nlp.initialize(get_examples=lambda: train_examples)
     with pytest.raises(TypeError):
-        nlp.initialize(get_examples=lambda: None)
+        nlp.initialize(get_examples=lambda: None)  # type:ignore[arg-type, return-value]
     with pytest.raises(TypeError):
         nlp.initialize(get_examples=lambda: train_examples[0])
     with pytest.raises(TypeError):
         nlp.initialize(get_examples=lambda: [])
     with pytest.raises(TypeError):
-        nlp.initialize(get_examples=train_examples)
+        nlp.initialize(get_examples=train_examples)  # type:ignore[arg-type]


 @pytest.mark.parametrize("lowercasing", [True, False])
 def test_initialize_from_labels(lowercasing: bool):
     nlp = Language()
-    lemmatizer = nlp.add_pipe(
-        "trainable_lemmatizer",
-        config={"model": {"lowercasing": lowercasing}},
+    lemmatizer = cast(
+        EditTreeLemmatizer,
+        nlp.add_pipe(
+            "trainable_lemmatizer",
+            config={"model": {"lowercasing": lowercasing}},
+        ),
     )
     lemmatizer.min_tree_freq = 1
     train_examples = []
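
Note: two patterns recur through the rest of this test module. Language.add_pipe is annotated to return a generic pipeline callable rather than the concrete component class, so typing.cast is used to tell mypy which class actually comes back, with no runtime effect; and calls that are intentionally mistyped (to exercise the error handling) get error-code-scoped # type:ignore comments. A minimal sketch of both, assuming a plain pipe without the lowercasing config used on this branch:

# Minimal sketch, not the commit itself: cast() changes only the static type
# mypy sees; at runtime it returns its second argument unchanged.
from typing import cast

import pytest
from spacy.lang.en import English
from spacy.pipeline.edit_tree_lemmatizer import EditTreeLemmatizer

nlp = English()
lemmatizer = cast(EditTreeLemmatizer, nlp.add_pipe("trainable_lemmatizer"))
lemmatizer.min_tree_freq = 1  # attribute access on the concrete class now type-checks

# A deliberately ill-typed call carries an ignore scoped to the listed error
# codes, so mypy stays quiet here but still reports unrelated mistakes.
with pytest.raises(TypeError):
    nlp.initialize(get_examples=lambda: None)  # type:ignore[arg-type, return-value]

Using cast rather than a blanket ignore keeps the later attribute accesses on lemmatizer checked instead of suppressing them.
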
@@ -75,9 +81,12 @@ def test_initialize_from_labels(lowercasing: bool):
     nlp.initialize(get_examples=lambda: train_examples)

     nlp2 = Language()
-    lemmatizer2 = nlp2.add_pipe(
-        "trainable_lemmatizer",
-        config={"model": {"lowercasing": lowercasing}},
+    lemmatizer2 = cast(
+        EditTreeLemmatizer,
+        nlp2.add_pipe(
+            "trainable_lemmatizer",
+            config={"model": {"lowercasing": lowercasing}},
+        ),
     )
     lemmatizer2.initialize(
         # We want to check that the strings in replacement nodes are
@@ -170,9 +179,12 @@ def test_no_data(lowercasing: bool):
 def test_incomplete_data(lowercasing: bool):
     # Test that the lemmatizer works with incomplete information
     nlp = English()
-    lemmatizer = nlp.add_pipe(
-        "trainable_lemmatizer",
-        config={"model": {"lowercasing": lowercasing}},
+    lemmatizer = cast(
+        EditTreeLemmatizer,
+        nlp.add_pipe(
+            "trainable_lemmatizer",
+            config={"model": {"lowercasing": lowercasing}},
+        ),
     )
     lemmatizer.min_tree_freq = 1
     fix_random_seed(0)
@@ -181,7 +193,7 @@ def test_incomplete_data(lowercasing: bool):
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
     optimizer = nlp.initialize(get_examples=lambda: train_examples)
     for i in range(50):
-        losses = {}
+        losses: Dict[Floats2d, Floats2d] = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)

     # test the trained model
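
Note: the losses dicts start out empty, and mypy cannot infer element types for an empty literal, hence the explicit annotations added in this and the later training loops. A generic sketch of the underlying mypy behaviour (the concrete key/value types chosen for losses are the ones shown in the hunks, not the ones below):

# Generic sketch, not from the commit: an empty collection literal gives mypy
# nothing to infer from, so it asks for an annotation up front.
from typing import Dict

# totals = {}  # mypy: Need type annotation for "totals"
totals: Dict[str, float] = {}
totals["trainable_lemmatizer"] = 0.0  # filled in later, as the training loop does
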
@@ -194,9 +206,12 @@ def test_incomplete_data(lowercasing: bool):
 @pytest.mark.parametrize("lowercasing", [True, False])
 def test_overfitting_IO(lowercasing: bool):
     nlp = English()
-    lemmatizer = nlp.add_pipe(
-        "trainable_lemmatizer",
-        config={"model": {"lowercasing": lowercasing}},
+    lemmatizer = cast(
+        EditTreeLemmatizer,
+        nlp.add_pipe(
+            "trainable_lemmatizer",
+            config={"model": {"lowercasing": lowercasing}},
+        ),
     )
     lemmatizer.min_tree_freq = 1
     fix_random_seed(0)
@@ -207,7 +222,7 @@ def test_overfitting_IO(lowercasing: bool):
     optimizer = nlp.initialize(get_examples=lambda: train_examples)

     for i in range(50):
-        losses = {}
+        losses: Dict[Floats2d, Floats2d] = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)

     test_text = "She likes blue eggs"
@@ -261,8 +276,11 @@ def test_lemmatizer_requires_labels():
 @pytest.mark.parametrize("lowercasing", [True, False])
 def test_lemmatizer_label_data(lowercasing: bool):
     nlp = English()
-    lemmatizer = nlp.add_pipe(
-        "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
+    lemmatizer = cast(
+        EditTreeLemmatizer,
+        nlp.add_pipe(
+            "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
+        ),
     )
     lemmatizer.min_tree_freq = 1
     train_examples = []
@@ -273,7 +291,7 @@ def test_lemmatizer_label_data(lowercasing: bool):
     assert len(lemmatizer.trees) == 3 if lowercasing else 7

     nlp2 = English()
-    lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
+    lemmatizer2 = cast(EditTreeLemmatizer, nlp2.add_pipe("trainable_lemmatizer"))
     lemmatizer2.initialize(
         get_examples=lambda: train_examples, labels=lemmatizer.label_data
     )
@@ -380,8 +398,11 @@ def test_empty_strings():
 @pytest.mark.parametrize("lowercasing", [True, False])
 def test_lowercasing(lowercasing: bool):
     nlp = English()
-    lemmatizer = nlp.add_pipe(
-        "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
+    lemmatizer = cast(
+        EditTreeLemmatizer,
+        nlp.add_pipe(
+            "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
+        ),
     )
     lemmatizer.min_tree_freq = 1
     fix_random_seed(0)
@@ -390,7 +411,7 @@ def test_lowercasing(lowercasing: bool):
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
     optimizer = nlp.initialize(get_examples=lambda: train_examples)
     for _ in range(50):
-        losses = {}
+        losses: Dict[Floats2d, Floats2d] = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)

     # test the trained model