diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index e83fe63ba..332badd8c 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -5,8 +5,8 @@ from itertools import islice import numpy as np import srsly -from thinc.api import Config, Model, SequenceCategoricalCrossentropy -from thinc.types import Floats2d, Ints1d, Ints2d +from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps +from thinc.types import Floats2d, Ints2d from ._edit_tree_internals.edit_trees import EditTrees from ._edit_tree_internals.schemas import validate_edit_tree @@ -20,6 +20,10 @@ from ..vocab import Vocab from .. import util +# The cutoff value of *top_k* above which an alternative method is used to process guesses. +TOP_K_GUARDRAIL = 20 + + default_model_config = """ [model] @architectures = "spacy.Tagger.v2" @@ -115,6 +119,7 @@ class EditTreeLemmatizer(TrainablePipe): self.cfg: Dict[str, Any] = {"labels": []} self.scorer = scorer + self.numpy_ops = NumpyOps() def get_loss( self, examples: Iterable[Example], scores: List[Floats2d] @@ -144,6 +149,18 @@ class EditTreeLemmatizer(TrainablePipe): return float(loss), d_scores def predict(self, docs: Iterable[Doc]) -> List[Ints2d]: + if self.top_k == 1: + scores2guesses = self._scores2guesses_top_k_equals_1 + elif self.top_k <= TOP_K_GUARDRAIL: + scores2guesses = self._scores2guesses_top_k_greater_1 + else: + scores2guesses = self._scores2guesses_top_k_guardrail + # The behaviour of *_scores2guesses_top_k_greater_1()* is efficient for values + # of *top_k>1* that are likely to be useful when the edit tree lemmatizer is used + # for its principal purpose of lemmatizing tokens. However, the code could also + # be used for other purposes, and with very large values of *top_k* the method + # becomes inefficient. In such cases, *_scores2guesses_top_k_guardrail()* is used + # instead. n_docs = len(list(docs)) if not any(len(doc) for doc in docs): # Handle cases where there are no tokens in any docs. @@ -153,20 +170,52 @@ class EditTreeLemmatizer(TrainablePipe): return guesses scores = self.model.predict(docs) assert len(scores) == n_docs - guesses = self._scores2guesses(docs, scores) + guesses = scores2guesses(docs, scores) assert len(guesses) == n_docs return guesses - def _scores2guesses(self, docs, scores): + def _scores2guesses_top_k_equals_1(self, docs, scores): guesses = [] for doc, doc_scores in zip(docs, scores): - if self.top_k == 1: - doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1) - else: - doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1] + doc_guesses = doc_scores.argmax(axis=1) + doc_guesses = self.numpy_ops.asarray(doc_guesses) - if not isinstance(doc_guesses, np.ndarray): - doc_guesses = doc_guesses.get() + doc_compat_guesses = [] + for i, token in enumerate(doc): + tree_id = self.cfg["labels"][doc_guesses[i]] + if self.trees.apply(tree_id, token.text) is not None: + doc_compat_guesses.append(tree_id) + else: + doc_compat_guesses.append(-1) + guesses.append(np.array(doc_compat_guesses)) + + return guesses + + def _scores2guesses_top_k_greater_1(self, docs, scores): + guesses = [] + top_k = min(self.top_k, len(self.labels)) + for doc, doc_scores in zip(docs, scores): + doc_scores = self.numpy_ops.asarray(doc_scores) + doc_compat_guesses = [] + for i, token in enumerate(doc): + for _ in range(top_k): + candidate = int(doc_scores[i].argmax()) + candidate_tree_id = self.cfg["labels"][candidate] + if self.trees.apply(candidate_tree_id, token.text) is not None: + doc_compat_guesses.append(candidate_tree_id) + break + doc_scores[i, candidate] = np.finfo(np.float32).min + else: + doc_compat_guesses.append(-1) + guesses.append(np.array(doc_compat_guesses)) + + return guesses + + def _scores2guesses_top_k_guardrail(self, docs, scores): + guesses = [] + for doc, doc_scores in zip(docs, scores): + doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1] + doc_guesses = self.numpy_ops.asarray(doc_guesses) doc_compat_guesses = [] for token, candidates in zip(doc, doc_guesses): diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py index c4f9b09f3..128d75680 100644 --- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -101,14 +101,15 @@ def test_initialize_from_labels(): } -def test_no_data(): +@pytest.mark.parametrize("top_k", (1, 5, 30)) +def test_no_data(top_k): # Test that the lemmatizer provides a nice error when there's no tagging data / labels TEXTCAT_DATA = [ ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}), ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}), ] nlp = English() - nlp.add_pipe("trainable_lemmatizer") + nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k}) nlp.add_pipe("textcat") train_examples = [] @@ -119,10 +120,11 @@ def test_no_data(): nlp.initialize(get_examples=lambda: train_examples) -def test_incomplete_data(): +@pytest.mark.parametrize("top_k", (1, 5, 30)) +def test_incomplete_data(top_k): # Test that the lemmatizer works with incomplete information nlp = English() - lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k}) lemmatizer.min_tree_freq = 1 train_examples = [] for t in PARTIAL_DATA: @@ -154,9 +156,10 @@ def test_incomplete_data(): assert xp.count_nonzero(dX[1][1]) == 0 -def test_overfitting_IO(): +@pytest.mark.parametrize("top_k", (1, 5, 30)) +def test_overfitting_IO(top_k): nlp = English() - lemmatizer = nlp.add_pipe("trainable_lemmatizer") + lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k}) lemmatizer.min_tree_freq = 1 train_examples = [] for t in TRAIN_DATA: @@ -189,7 +192,7 @@ def test_overfitting_IO(): # Check model after a {to,from}_bytes roundtrip nlp_bytes = nlp.to_bytes() nlp3 = English() - nlp3.add_pipe("trainable_lemmatizer") + nlp3.add_pipe("trainable_lemmatizer", config={"top_k": top_k}) nlp3.from_bytes(nlp_bytes) doc3 = nlp3(test_text) assert doc3[0].lemma_ == "she"