From 941270e707da835b394708f6cb003886d9141002 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Thu, 8 Dec 2022 17:18:42 +0100
Subject: [PATCH] Add tests

---
 spacy/pipeline/edit_tree_lemmatizer.py    |   2 +
 .../pipeline/test_edit_tree_lemmatizer.py | 190 +++++++++++++-----
 2 files changed, 142 insertions(+), 50 deletions(-)

diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py
index 6feaeb465..5dbfa3152 100644
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -424,6 +424,8 @@ class EditTreeLemmatizer(TrainablePipe):
         tree is unknown and "add_label" is set, the edit tree will be added to
         the labels.
         """
+        if self.lowercasing and _should_lowercased(form, lemma):
+            form = form.lower()
         tree_id = self.trees.add(form, lemma)
         if tree_id not in self.tree2label:
             if not add_label:
diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
index b12ca5dd4..f256679ba 100644
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -2,6 +2,7 @@ import pickle
 import pytest
 from hypothesis import given
 import hypothesis.strategies as st
+from thinc.api import fix_random_seed
 from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
@@ -29,10 +30,22 @@ PARTIAL_DATA = [
     ),
 ]
 
+LOWERCASING_DATA = [
+    ("A B C D", {"lemmas": ["a", "b", "c", "d"]}),
+    ("E F G H", {"lemmas": ["e", "f", "g", "h"]}),
+    ("I J K L", {"lemmas": ["i", "j", "k", "l"]}),
+    ("M N O P", {"lemmas": ["m", "n", "o", "p"]}),
+    ("Q R S T", {"lemmas": ["q", "r", "s", "t"]}),
+]
 
-def test_initialize_examples():
+
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_initialize_examples(lowercasing: bool):
     nlp = Language()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     train_examples = []
     for t in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
@@ -48,9 +61,13 @@
         nlp.initialize(get_examples=train_examples)
 
 
-def test_initialize_from_labels():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_initialize_from_labels(lowercasing: bool):
     nlp = Language()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer.min_tree_freq = 1
     train_examples = []
     for t in TRAIN_DATA:
@@ -58,7 +75,10 @@
     nlp.initialize(get_examples=lambda: train_examples)
 
     nlp2 = Language()
-    lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
+    lemmatizer2 = nlp2.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer2.initialize(
         # We want to check that the strings in replacement nodes are
         # added to the string store. Avoid that they get added through
@@ -66,49 +86,76 @@
         get_examples=lambda: train_examples[:1],
         labels=lemmatizer.label_data,
     )
-    assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3}
-    assert lemmatizer2.label_data == {
-        "trees": [
-            {"orig": "S", "subst": "s"},
-            {
-                "prefix_len": 1,
-                "suffix_len": 0,
-                "prefix_tree": 0,
-                "suffix_tree": 4294967295,
-            },
-            {"orig": "s", "subst": ""},
-            {
-                "prefix_len": 0,
-                "suffix_len": 1,
-                "prefix_tree": 4294967295,
-                "suffix_tree": 2,
-            },
-            {
-                "prefix_len": 0,
-                "suffix_len": 0,
-                "prefix_tree": 4294967295,
-                "suffix_tree": 4294967295,
-            },
-            {"orig": "E", "subst": "e"},
-            {
-                "prefix_len": 1,
-                "suffix_len": 0,
-                "prefix_tree": 5,
-                "suffix_tree": 4294967295,
-            },
-        ],
-        "labels": (1, 3, 4, 6),
-    }
+
+    if lowercasing:
+        assert lemmatizer2.tree2label == {2: 1, 0: 0}
+        assert lemmatizer2.label_data == {
+            "trees": [
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 0,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 4294967295,
+                },
+                {"orig": "s", "subst": ""},
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 1,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 1,
+                },
+            ],
+            "labels": (0, 2),
+        }
+
+    else:
+        assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3}
+        assert lemmatizer2.label_data == {
+            "trees": [
+                {"orig": "S", "subst": "s"},
+                {
+                    "prefix_len": 1,
+                    "suffix_len": 0,
+                    "prefix_tree": 0,
+                    "suffix_tree": 4294967295,
+                },
+                {"orig": "s", "subst": ""},
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 1,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 2,
+                },
+                {
+                    "prefix_len": 0,
+                    "suffix_len": 0,
+                    "prefix_tree": 4294967295,
+                    "suffix_tree": 4294967295,
+                },
+                {"orig": "E", "subst": "e"},
+                {
+                    "prefix_len": 1,
+                    "suffix_len": 0,
+                    "prefix_tree": 5,
+                    "suffix_tree": 4294967295,
+                },
+            ],
+            "labels": (1, 3, 4, 6),
+        }
 
 
-def test_no_data():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_no_data(lowercasing: bool):
     # Test that the lemmatizer provides a nice error when there's no tagging data / labels
     TEXTCAT_DATA = [
         ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
         ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
     ]
     nlp = English()
-    nlp.add_pipe("trainable_lemmatizer")
+    nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     nlp.add_pipe("textcat")
 
     train_examples = []
@@ -119,11 +166,16 @@
         nlp.initialize(get_examples=lambda: train_examples)
 
 
-def test_incomplete_data():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_incomplete_data(lowercasing: bool):
     # Test that the lemmatizer works with incomplete information
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
+        "trainable_lemmatizer",
+        config={"model": {"lowercasing": lowercasing}},
+    )
     lemmatizer.min_tree_freq = 1
+    fix_random_seed(0)
     train_examples = []
     for t in PARTIAL_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
@@ -131,7 +183,6 @@
     for i in range(50):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
-        assert losses["trainable_lemmatizer"] < 0.00001
 
     # test the trained model
     test_text = "She likes blue eggs"
@@ -140,10 +191,15 @@
     doc = nlp(test_text)
     assert doc[2].lemma_ == "blue"
 
-def test_overfitting_IO():
+@pytest.mark.parametrize("lowercasing", [True, False])
+def test_overfitting_IO(lowercasing: bool):
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe(
"trainable_lemmatizer", + config={"model": {"lowercasing": lowercasing}}, + ) lemmatizer.min_tree_freq = 1 + fix_random_seed(0) train_examples = [] for t in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) @@ -153,7 +209,6 @@ def test_overfitting_IO(): for i in range(50): losses = {} nlp.update(train_examples, sgd=optimizer, losses=losses) - assert losses["trainable_lemmatizer"] < 0.00001 test_text = "She likes blue eggs" doc = nlp(test_text) @@ -175,7 +230,10 @@ def test_overfitting_IO(): # Check model after a {to,from}_bytes roundtrip nlp_bytes = nlp.to_bytes() nlp3 = English() - nlp3.add_pipe("trainable_lemmatizer") + nlp3.add_pipe( + "trainable_lemmatizer", + config={"model": {"lowercasing": lowercasing}}, + ) nlp3.from_bytes(nlp_bytes) doc3 = nlp3(test_text) assert doc3[0].lemma_ == "she" @@ -200,15 +258,19 @@ def test_lemmatizer_requires_labels(): nlp.initialize() -def test_lemmatizer_label_data(): +@pytest.mark.parametrize("lowercasing", [True, False]) +def test_lemmatizer_label_data(lowercasing: bool): nlp = English() - lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer = nlp.add_pipe( + "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}} + ) lemmatizer.min_tree_freq = 1 train_examples = [] for t in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) nlp.initialize(get_examples=lambda: train_examples) + assert len(lemmatizer.trees) == 3 if lowercasing else 7 nlp2 = English() lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer") @@ -286,7 +348,7 @@ def test_roundtrip(form, lemma): assert trees.apply(tree, form) == lemma -@given(st.text(alphabet="ab"), st.text(alphabet="ab")) +@given(st.text(alphabet="aB"), st.text(alphabet="aB")) def test_roundtrip_small_alphabet(form, lemma): # Test with small alphabets to have more overlap. strings = StringStore() @@ -313,3 +375,31 @@ def test_empty_strings(): no_change = trees.add("xyz", "xyz") empty = trees.add("", "") assert no_change == empty + + +@pytest.mark.parametrize("lowercasing", [True, False]) +def test_lowercasing(lowercasing: bool): + nlp = English() + lemmatizer = nlp.add_pipe( + "trainable_lemmatizer", config={"model": {"lowercasing": lowercasing}} + ) + lemmatizer.min_tree_freq = 1 + fix_random_seed(0) + train_examples = [] + for t in LOWERCASING_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for _ in range(50): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + # test the trained model + test_text = "U V W X" + doc = nlp(test_text) + assert doc[0].lemma_ == "u" if lowercasing else "U" + assert doc[1].lemma_ == "v" if lowercasing else "V" + assert doc[2].lemma_ == "w" if lowercasing else "W" + assert doc[3].lemma_ == "x" if lowercasing else "X" + assert len(lemmatizer.trees) == 1 if lowercasing else 20 + if lowercasing: + assert lemmatizer.trees.tree_to_str(0) == "(m 0 0 () ())"