Use different _scores2guesses depending on top_k

richardpaulhudson 2023-01-10 09:42:31 +01:00
parent 4daf5e9b81
commit 8d487bf35c
2 changed files with 60 additions and 5 deletions

spacy/pipeline/edit_tree_lemmatizer.py

@@ -20,6 +20,10 @@ from ..vocab import Vocab
 from .. import util
 
 
+# The cutoff value of *top_k* above which an alternative method is used to process guesses.
+TOP_K_GUARDRAIL = 20
+
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -145,6 +149,18 @@ class EditTreeLemmatizer(TrainablePipe):
         return float(loss), d_scores
 
     def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+        if self.top_k == 1:
+            scores2guesses = self._scores2guesses_top_k_equals_1
+        elif self.top_k <= TOP_K_GUARDRAIL:
+            scores2guesses = self._scores2guesses_top_k_greater_1
+        else:
+            scores2guesses = self._scores2guesses_top_k_guardrail
+        # The behaviour of *_scores2guesses_top_k_greater_1()* is efficient for values
+        # of *top_k>1* that are likely to be useful when the edit tree lemmatizer is used
+        # for its principal purpose of lemmatizing tokens. However, the code could also
+        # be used for other purposes, and with very large values of *top_k* the method
+        # becomes inefficient. In such cases, *_scores2guesses_top_k_guardrail()* is used
+        # instead.
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
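
The dispatch above trades per-token cost against top_k: a single argmax is
O(n_labels) per token, while the guardrail path's full argsort is
O(n_labels log n_labels). A minimal standalone sketch (hypothetical array
sizes, not code from this commit) showing that both strategies agree on the
best candidate:

import numpy as np

# Hypothetical scores: 1000 tokens x 500 edit-tree labels.
rng = np.random.default_rng(0)
scores = rng.random((1000, 500))

# top_k == 1 path: one argmax per row.
best = scores.argmax(axis=1)

# Guardrail path: ascending argsort, then the top_k indices in reverse order.
top_k = 50
ranked = np.argsort(scores)[..., : -top_k - 1 : -1]

# Both agree on the highest-scoring label for every token.
assert np.array_equal(ranked[:, 0], best)
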
@@ -154,11 +170,28 @@ class EditTreeLemmatizer(TrainablePipe):
             return guesses
 
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
-        guesses = self._scores2guesses(docs, scores)
+        guesses = scores2guesses(docs, scores)
         assert len(guesses) == n_docs
         return guesses
 
-    def _scores2guesses(self, docs, scores):
+    def _scores2guesses_top_k_equals_1(self, docs, scores):
+        guesses = []
+        for doc, doc_scores in zip(docs, scores):
+            doc_guesses = doc_scores.argmax(axis=1)
+            doc_guesses = self.numpy_ops.asarray(doc_guesses)
+
+            doc_compat_guesses = []
+            for i, token in enumerate(doc):
+                tree_id = self.cfg["labels"][doc_guesses[i]]
+                if self.trees.apply(tree_id, token.text) is not None:
+                    doc_compat_guesses.append(tree_id)
+                else:
+                    doc_compat_guesses.append(-1)
+            guesses.append(np.array(doc_compat_guesses))
+
+        return guesses
+
+    def _scores2guesses_top_k_greater_1(self, docs, scores):
         guesses = []
         predictions_to_consider = min(self.top_k, len(self.labels))
         for doc, doc_scores in zip(docs, scores):
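
The -1 above is a sentinel for tokens where even the best-scoring edit tree
cannot be applied to the surface form. A self-contained sketch of that
argmax-plus-applicability check (apply_tree is a stub invented here to stand
in for EditTrees.apply(), which returns None for inapplicable trees):

import numpy as np

def apply_tree(tree_id, form):
    # Stub: returns a lemma string, or None when the tree does not apply.
    return form[:-1] if tree_id == 0 and len(form) > 1 else None

labels = [0, 1]                          # tree ids, as in self.cfg["labels"]
tokens = ["dogs", "a"]
scores = np.array([[0.9, 0.1], [0.8, 0.2]])

guesses = []
for i, token in enumerate(tokens):
    tree_id = labels[int(scores[i].argmax())]
    guesses.append(tree_id if apply_tree(tree_id, token) is not None else -1)

print(guesses)  # [0, -1]: no applicable tree for "a", hence the sentinel
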
@@ -177,6 +210,27 @@ class EditTreeLemmatizer(TrainablePipe):
             guesses.append(np.array(doc_compat_guesses))
         return guesses
 
+    def _scores2guesses_top_k_guardrail(self, docs, scores):
+        guesses = []
+        for doc, doc_scores in zip(docs, scores):
+            doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
+            doc_guesses = self.numpy_ops.asarray(doc_guesses)
+
+            doc_compat_guesses = []
+            for token, candidates in zip(doc, doc_guesses):
+                tree_id = -1
+                for candidate in candidates:
+                    candidate_tree_id = self.cfg["labels"][candidate]
+                    if self.trees.apply(candidate_tree_id, token.text) is not None:
+                        tree_id = candidate_tree_id
+                        break
+                doc_compat_guesses.append(tree_id)
+
+            guesses.append(np.array(doc_compat_guesses))
+
+        return guesses
+
     def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
         for i, doc in enumerate(docs):
             doc_tree_ids = batch_tree_ids[i]
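
The slice [..., : -self.top_k - 1 : -1] turns NumPy's ascending argsort into
the top_k column indices in descending score order. For moderate top_k the
same selection can be done without sorting every label, e.g. with
np.argpartition; a standalone sketch of that trade-off (illustration only,
not code from this commit):

import numpy as np

rng = np.random.default_rng(0)
scores = rng.random((4, 10))  # 4 tokens x 10 labels, hypothetical
top_k = 3

# Full sort, as in _scores2guesses_top_k_guardrail above.
full = np.argsort(scores)[..., : -top_k - 1 : -1]

# Partial selection: argpartition isolates the top_k in O(n_labels),
# then only those top_k candidates are sorted.
part = np.argpartition(scores, -top_k, axis=1)[:, -top_k:]
order = np.argsort(np.take_along_axis(scores, part, axis=1))[:, ::-1]
partial = np.take_along_axis(part, order, axis=1)

assert np.array_equal(full, partial)
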

spacy/tests/pipeline/test_edit_tree_lemmatizer.py

@@ -140,9 +140,10 @@ def test_incomplete_data():
     assert doc[2].lemma_ == "blue"
 
 
-def test_overfitting_IO():
+@pytest.mark.parametrize("top_k", (1, 5, 30))
+def test_overfitting_IO(top_k):
     nlp = English()
-    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
     lemmatizer.min_tree_freq = 1
     train_examples = []
     for t in TRAIN_DATA:
@@ -175,7 +176,7 @@ def test_overfitting_IO():
     # Check model after a {to,from}_bytes roundtrip
     nlp_bytes = nlp.to_bytes()
     nlp3 = English()
-    nlp3.add_pipe("trainable_lemmatizer")
+    nlp3.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
     nlp3.from_bytes(nlp_bytes)
     doc3 = nlp3(test_text)
     assert doc3[0].lemma_ == "she"
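
The three parameter values exercise one dispatch branch each: 1 takes the
argmax path, 5 the top-k path, and 30 exceeds TOP_K_GUARDRAIL. For reference,
a minimal usage sketch of the setting under test:

from spacy.lang.en import English

nlp = English()
# top_k sets how many candidate trees are tried per token; the pipe picks
# the guess strategy from this value automatically.
lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": 5})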