From 8d487bf35c9a1c766fbc6a4b2a66754f5ee794ba Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Tue, 10 Jan 2023 09:42:31 +0100
Subject: [PATCH] Use different _scores2guesses depending on top_k

---
 spacy/pipeline/edit_tree_lemmatizer.py        | 58 ++++++++++++++++++-
 .../pipeline/test_edit_tree_lemmatizer.py     |  7 ++-
 2 files changed, 60 insertions(+), 5 deletions(-)

diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py
index 1483096d1..c88c9466c 100644
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -20,6 +20,10 @@ from ..vocab import Vocab
 from .. import util
 
 
+# The cutoff value of *top_k* above which an alternative method is used to process guesses.
+TOP_K_GUARDRAIL = 20
+
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -145,6 +149,18 @@ class EditTreeLemmatizer(TrainablePipe):
         return float(loss), d_scores
 
     def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+        if self.top_k == 1:
+            scores2guesses = self._scores2guesses_top_k_equals_1
+        elif self.top_k <= TOP_K_GUARDRAIL:
+            scores2guesses = self._scores2guesses_top_k_greater_1
+        else:
+            scores2guesses = self._scores2guesses_top_k_guardrail
+        # The behaviour of *_scores2guesses_top_k_greater_1()* is efficient for values
+        # of *top_k>1* that are likely to be useful when the edit tree lemmatizer is used
+        # for its principal purpose of lemmatizing tokens. However, the code could also
+        # be used for other purposes, and with very large values of *top_k* the method
+        # becomes inefficient. In such cases, *_scores2guesses_top_k_guardrail()* is used
+        # instead.
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
@@ -154,11 +170,28 @@
             return guesses
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
-        guesses = self._scores2guesses(docs, scores)
+        guesses = scores2guesses(docs, scores)
         assert len(guesses) == n_docs
         return guesses
 
-    def _scores2guesses(self, docs, scores):
+    def _scores2guesses_top_k_equals_1(self, docs, scores):
+        guesses = []
+        for doc, doc_scores in zip(docs, scores):
+            doc_guesses = doc_scores.argmax(axis=1)
+            doc_guesses = self.numpy_ops.asarray(doc_guesses)
+
+            doc_compat_guesses = []
+            for i, token in enumerate(doc):
+                tree_id = self.cfg["labels"][doc_guesses[i]]
+                if self.trees.apply(tree_id, token.text) is not None:
+                    doc_compat_guesses.append(tree_id)
+                else:
+                    doc_compat_guesses.append(-1)
+            guesses.append(np.array(doc_compat_guesses))
+
+        return guesses
+
+    def _scores2guesses_top_k_greater_1(self, docs, scores):
         guesses = []
         predictions_to_consider = min(self.top_k, len(self.labels))
         for doc, doc_scores in zip(docs, scores):
@@ -177,6 +210,27 @@
             guesses.append(np.array(doc_compat_guesses))
         return guesses
 
+    def _scores2guesses_top_k_guardrail(self, docs, scores):
+        guesses = []
+        for doc, doc_scores in zip(docs, scores):
+            doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
+            doc_guesses = self.numpy_ops.asarray(doc_guesses)
+
+            doc_compat_guesses = []
+            for token, candidates in zip(doc, doc_guesses):
+                tree_id = -1
+                for candidate in candidates:
+                    candidate_tree_id = self.cfg["labels"][candidate]
+
+                    if self.trees.apply(candidate_tree_id, token.text) is not None:
+                        tree_id = candidate_tree_id
+                        break
+                doc_compat_guesses.append(tree_id)
+
+            guesses.append(np.array(doc_compat_guesses))
+
+        return guesses
+
     def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
         for i, doc in enumerate(docs):
             doc_tree_ids = batch_tree_ids[i]
diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
index b12ca5dd4..ae9444462 100644
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -140,9 +140,10 @@ def test_incomplete_data():
     assert doc[2].lemma_ == "blue"
 
 
-def test_overfitting_IO():
+@pytest.mark.parametrize("top_k", (1, 5, 30))
+def test_overfitting_IO(top_k):
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
     lemmatizer.min_tree_freq = 1
     train_examples = []
    for t in TRAIN_DATA:
@@ -175,7 +176,7 @@
     # Check model after a {to,from}_bytes roundtrip
     nlp_bytes = nlp.to_bytes()
     nlp3 = English()
-    nlp3.add_pipe("trainable_lemmatizer")
+    nlp3.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
     nlp3.from_bytes(nlp_bytes)
     doc3 = nlp3(test_text)
     assert doc3[0].lemma_ == "she"
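
The slicing idiom in _scores2guesses_top_k_guardrail is easier to see in isolation. The standalone sketch below (toy scores, not part of the patch) shows how np.argsort(scores)[..., : -top_k - 1 : -1] selects the top_k label indices per token in descending score order, alongside the single argmax used on the top_k == 1 path:

    import numpy as np

    # Toy score matrix: 3 tokens x 6 candidate edit trees.
    scores = np.array(
        [
            [0.10, 0.70, 0.05, 0.05, 0.05, 0.05],
            [0.30, 0.20, 0.25, 0.10, 0.10, 0.05],
            [0.02, 0.08, 0.10, 0.20, 0.30, 0.30],
        ]
    )
    top_k = 3

    # top_k == 1 path: one argmax per token, no sorting needed.
    best = scores.argmax(axis=1)
    print(best)  # [1 0 4]

    # Guardrail path: argsort ascending along the last axis, then
    # slice the last top_k entries in reverse, i.e. the top_k column
    # indices per token by descending score.
    candidates = np.argsort(scores)[..., : -top_k - 1 : -1]
    print(candidates)
    # [[1 0 5]
    #  [0 2 1]
    #  [5 4 3]]

Per the patch, the lemmatizer then walks each candidate row in order and keeps the first tree that actually applies to the token's text, falling back to -1 if none does. The dispatch between the three paths happens once per predict() call, and users reach it through the pipe config, e.g. nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k}), as the updated test shows.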