mirror of https://github.com/explosion/spaCy.git
Add tests for the trainable lemmatizer's lowercasing option
parent 74dbb65bca
commit 941270e707
spacy/pipeline/edit_tree_lemmatizer.py
@@ -424,6 +424,8 @@ class EditTreeLemmatizer(TrainablePipe):
         tree is unknown and "add_label" is set, the edit tree will be added to
         the labels.
         """
+        if self.lowercasing and _should_lowercased(form, lemma):
+            form = form.lower()
         tree_id = self.trees.add(form, lemma)
         if tree_id not in self.tree2label:
             if not add_label:
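The helper _should_lowercased(form, lemma) is called above but not defined in this diff. A minimal sketch of the check it presumably performs, inferred from the tests below where uppercase forms with lowercase lemmas collapse into a single identity tree (hypothetical reconstruction, not code from this commit):

    def _should_lowercased(form: str, lemma: str) -> bool:
        # Hypothetical: lowercase the form only when the case difference
        # carries no information the edit tree must encode, i.e. the
        # lemma is already lowercase while the form is not.
        return form != form.lower() and lemma == lemma.lower()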
spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -2,6 +2,7 @@ import pickle
 import pytest
 from hypothesis import given
 import hypothesis.strategies as st
+from thinc.api import fix_random_seed
 from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
@@ -29,10 +30,22 @@ PARTIAL_DATA = [
     ),
 ]

+LOWERCASING_DATA = [
+    ("A B C D", {"lemmas": ["a", "b", "c", "d"]}),
+    ("E F G H", {"lemmas": ["e", "f", "g", "h"]}),
+    ("I J K L", {"lemmas": ["i", "j", "k", "l"]}),
+    ("M N O P", {"lemmas": ["m", "n", "o", "p"]}),
+    ("Q R S T", {"lemmas": ["q", "r", "s", "t"]}),
+]
+

-def test_initialize_examples():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_initialize_examples(lowercasing: bool):
     nlp = Language()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     train_examples = []
     for t in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
@@ -48,9 +61,13 @@ def test_initialize_examples():
         nlp.initialize(get_examples=train_examples)


-def test_initialize_from_labels():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_initialize_from_labels(lowercasing: bool):
     nlp = Language()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer.min_tree_freq = 1
     train_examples = []
     for t in TRAIN_DATA:
@@ -58,7 +75,10 @@ def test_initialize_from_labels():
     nlp.initialize(get_examples=lambda: train_examples)

     nlp2 = Language()
-    lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
+    lemmatizer2 = nlp2.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer2.initialize(
         # We want to check that the strings in replacement nodes are
         # added to the string store. Avoid that they get added through
@@ -66,49 +86,76 @@ def test_initialize_from_labels():
         get_examples=lambda: train_examples[:1],
         labels=lemmatizer.label_data,
     )
-    assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3}
-    assert lemmatizer2.label_data == {
-        "trees": [
-            {"orig": "S", "subst": "s"},
-            {
-                "prefix_len": 1,
-                "suffix_len": 0,
-                "prefix_tree": 0,
-                "suffix_tree": 4294967295,
-            },
-            {"orig": "s", "subst": ""},
-            {
-                "prefix_len": 0,
-                "suffix_len": 1,
-                "prefix_tree": 4294967295,
-                "suffix_tree": 2,
-            },
-            {
-                "prefix_len": 0,
-                "suffix_len": 0,
-                "prefix_tree": 4294967295,
-                "suffix_tree": 4294967295,
-            },
-            {"orig": "E", "subst": "e"},
-            {
-                "prefix_len": 1,
-                "suffix_len": 0,
-                "prefix_tree": 5,
-                "suffix_tree": 4294967295,
-            },
-        ],
-        "labels": (1, 3, 4, 6),
-    }
+
+    if lowercasing:
+        assert lemmatizer2.tree2label == {2: 1, 0: 0}
+        assert lemmatizer2.label_data == {
+            "trees": [
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 0,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 4294967295,
+                },
+                {"orig": "s", "subst": ""},
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 1,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 1,
+                },
+            ],
+            "labels": (0, 2),
+        }
+
+    else:
+        assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3}
+        assert lemmatizer2.label_data == {
+            "trees": [
+                {"orig": "S", "subst": "s"},
+                {
+                    "prefix_len": 1,
+                    "suffix_len": 0,
+                    "prefix_tree": 0,
+                    "suffix_tree": 4294967295,
+                },
+                {"orig": "s", "subst": ""},
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 1,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 2,
+                },
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 0,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 4294967295,
+                },
+                {"orig": "E", "subst": "e"},
+                {
+                    "prefix_len": 1,
+                    "suffix_len": 0,
+                    "prefix_tree": 5,
+                    "suffix_tree": 4294967295,
+                },
+            ],
+            "labels": (1, 3, 4, 6),
+        }


-def test_no_data():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_no_data(lowercasing: bool):
     # Test that the lemmatizer provides a nice error when there's no tagging data / labels
     TEXTCAT_DATA = [
         ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
         ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
     ]
     nlp = English()
-    nlp.add_pipe("trainable_lemmatizer")
+    nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     nlp.add_pipe("textcat")

     train_examples = []
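In the label data above, 4294967295 is the maximum unsigned 32-bit integer (2**32 - 1); the edit-tree serialization appears to use it as a sentinel for an absent prefix or suffix subtree.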
@@ -119,11 +166,16 @@ def test_no_data():
         nlp.initialize(get_examples=lambda: train_examples)


-def test_incomplete_data():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_incomplete_data(lowercasing: bool):
     # Test that the lemmatizer works with incomplete information
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer.min_tree_freq = 1
+    fix_random_seed(0)
     train_examples = []
     for t in PARTIAL_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
@@ -131,7 +183,6 @@ def test_incomplete_data():
     for i in range(50):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
-    assert losses["trainable_lemmatizer"] < 0.00001

     # test the trained model
     test_text = "She likes blue eggs"
@@ -140,10 +191,15 @@ def test_incomplete_data():
     assert doc[2].lemma_ == "blue"


-def test_overfitting_IO():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_overfitting_IO(lowercasing: bool):
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer.min_tree_freq = 1
+    fix_random_seed(0)
     train_examples = []
     for t in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
@@ -153,7 +209,6 @@ def test_overfitting_IO():
     for i in range(50):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
-    assert losses["trainable_lemmatizer"] < 0.00001

     test_text = "She likes blue eggs"
     doc = nlp(test_text)
@@ -175,7 +230,10 @@ def test_overfitting_IO():
     # Check model after a {to,from}_bytes roundtrip
     nlp_bytes = nlp.to_bytes()
     nlp3 = English()
-    nlp3.add_pipe("trainable_lemmatizer")
+    nlp3.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     nlp3.from_bytes(nlp_bytes)
     doc3 = nlp3(test_text)
     assert doc3[0].lemma_ == "she"
@@ -200,15 +258,19 @@ def test_lemmatizer_requires_labels():
         nlp.initialize()


-def test_lemmatizer_label_data():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_lemmatizer_label_data(lowercasing: bool):
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
+    )
     lemmatizer.min_tree_freq = 1
     train_examples = []
     for t in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))

     nlp.initialize(get_examples=lambda: train_examples)
+    assert len(lemmatizer.trees) == (3 if lowercasing else 7)

     nlp2 = English()
     lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
@@ -286,7 +348,7 @@ def test_roundtrip(form, lemma):
     assert trees.apply(tree, form) == lemma


-@given(st.text(alphabet="ab"), st.text(alphabet="ab"))
+@given(st.text(alphabet="aB"), st.text(alphabet="aB"))
 def test_roundtrip_small_alphabet(form, lemma):
     # Test with small alphabets to have more overlap.
     strings = StringStore()
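Changing the Hypothesis alphabet from "ab" to "aB" keeps the deliberately small alphabet while mixing cases, presumably so the property-based roundtrip test also exercises forms whose casing differs from their lemmas.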
@@ -313,3 +375,31 @@ def test_empty_strings():
     no_change = trees.add("xyz", "xyz")
     empty = trees.add("", "")
     assert no_change == empty
+
+
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_lowercasing(lowercasing: bool):
+    nlp = English()
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}}
+    )
+    lemmatizer.min_tree_freq = 1
+    fix_random_seed(0)
+    train_examples = []
+    for t in LOWERCASING_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for _ in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    # test the trained model
+    test_text = "U V W X"
+    doc = nlp(test_text)
+    assert doc[0].lemma_ == ("u" if lowercasing else "U")
+    assert doc[1].lemma_ == ("v" if lowercasing else "V")
+    assert doc[2].lemma_ == ("w" if lowercasing else "W")
+    assert doc[3].lemma_ == ("x" if lowercasing else "X")
+    assert len(lemmatizer.trees) == (1 if lowercasing else 20)
+    if lowercasing:
+        assert lemmatizer.trees.tree_to_str(0) == "(m 0 0 () ())"
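The expected string "(m 0 0 () ())" reads as a match node with prefix length 0, suffix length 0 and no subtrees, i.e. the identity edit tree: with lowercasing enabled, all twenty uppercase-to-lowercase pairs in LOWERCASING_DATA collapse into this single tree, while without it each letter needs its own substitution tree (hence 20). A minimal usage sketch, assuming this branch's config schema (the "lowercasing" key is introduced by this commit and is not part of the released trainable_lemmatizer config):

    import spacy

    nlp = spacy.blank("en")
    # "lowercasing" is the model config flag added on this branch.
    nlp.add_pipe("trainable_lemmatizer", config={"model": {"lowercasing": True}})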