diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index 12f9b73a3..a56c9975e 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -328,9 +328,9 @@ class EditTreeLemmatizer(TrainablePipe): tree = dict(tree) if "orig" in tree: - tree["orig"] = self.vocab.strings[tree["orig"]] + tree["orig"] = self.vocab.strings.add(tree["orig"]) if "orig" in tree: - tree["subst"] = self.vocab.strings[tree["subst"]] + tree["subst"] = self.vocab.strings.add(tree["subst"]) trees.append(tree) diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py index cf541e301..b12ca5dd4 100644 --- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -60,10 +60,45 @@ def test_initialize_from_labels(): nlp2 = Language() lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer") lemmatizer2.initialize( - get_examples=lambda: train_examples, + # We want to check that the strings in replacement nodes are + # added to the string store. Avoid that they get added through + # the examples. + get_examples=lambda: train_examples[:1], labels=lemmatizer.label_data, ) assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3} + assert lemmatizer2.label_data == { + "trees": [ + {"orig": "S", "subst": "s"}, + { + "prefix_len": 1, + "suffix_len": 0, + "prefix_tree": 0, + "suffix_tree": 4294967295, + }, + {"orig": "s", "subst": ""}, + { + "prefix_len": 0, + "suffix_len": 1, + "prefix_tree": 4294967295, + "suffix_tree": 2, + }, + { + "prefix_len": 0, + "suffix_len": 0, + "prefix_tree": 4294967295, + "suffix_tree": 4294967295, + }, + {"orig": "E", "subst": "e"}, + { + "prefix_len": 1, + "suffix_len": 0, + "prefix_tree": 5, + "suffix_tree": 4294967295, + }, + ], + "labels": (1, 3, 4, 6), + } def test_no_data():