From 230264daa0e6f17b2f733ffef48f35fed825758e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 5 Aug 2022 10:33:20 +0200 Subject: [PATCH] Revert "Use dict comprehension suggested by @svlandeg" This reverts commit 6e7b958f7060397965176c69649e5414f1f24988. --- spacy/pipeline/edit_tree_lemmatizer.py | 7 +++---- spacy/pipeline/entity_linker.py | 17 +++++++---------- spacy/pipeline/morphologizer.pyx | 7 +++---- spacy/pipeline/senter.pyx | 7 +++---- spacy/pipeline/tagger.pyx | 7 +++---- 5 files changed, 19 insertions(+), 26 deletions(-) diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index f457e1d7b..3af39b1d1 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -202,10 +202,9 @@ class EditTreeLemmatizer(TrainablePipe): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): batch_tree_ids = activations["guesses"] for i, doc in enumerate(docs): - stored_activations = { - key: activations[key][i] for key in self.store_activations - } - doc.activations[self.name] = stored_activations + doc.activations[self.name] = {} + for activation in self.store_activations: + doc.activations[self.name][activation] = activations[activation][i] doc_tree_ids = batch_tree_ids[i] if hasattr(doc_tree_ids, "get"): doc_tree_ids = doc_tree_ids.get() diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 02529183b..266c1f07f 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -474,11 +474,7 @@ class EntityLinker(TrainablePipe): # shortcut for efficiency reasons: take the 1 candidate final_kb_ids.append(candidates[0].entity_) self._add_activations( - doc_scores, - doc_scores_lens, - doc_ents, - [1.0], - [candidates[0].entity_], + doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_] ) else: random.shuffle(candidates) @@ -545,11 +541,12 @@ class EntityLinker(TrainablePipe): i = 0 overwrite = self.cfg["overwrite"] for j, doc in enumerate(docs): - # We only copy activations that are Ragged. - stored_activations = { - key: cast(Ragged, activations[key][i]) for key in self.store_activations - } - doc.activations[self.name] = stored_activations + doc.activations[self.name] = {} + for activation in self.store_activations: + # We only copy activations that are Ragged. + doc.activations[self.name][activation] = cast( + Ragged, activations[activation][j] + ) for ent in doc.ents: kb_id = kb_ids[i] i += 1 diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 6ecba4ccf..0c7eacd12 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -249,10 +249,9 @@ class Morphologizer(Tagger): # to allocate a compatible container out of the iterable. labels = tuple(self.labels) for i, doc in enumerate(docs): - stored_activations = { - key: activations[key][i] for key in self.store_activations - } - doc.activations[self.name] = stored_activations + doc.activations[self.name] = {} + for activation in self.store_activations: + doc.activations[self.name][activation] = activations[activation][i] doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index a789934d7..1cfd6c4b1 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -135,10 +135,9 @@ class SentenceRecognizer(Tagger): cdef Doc doc cdef bint overwrite = self.cfg["overwrite"] for i, doc in enumerate(docs): - stored_activations = { - key: activations[key][i] for key in self.store_activations - } - doc.activations[self.name] = stored_activations + doc.activations[self.name] = {} + for activation in self.store_activations: + doc.activations[self.name][activation] = activations[activation][i] doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 3a8bcb67c..498b3de08 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -183,10 +183,9 @@ class Tagger(TrainablePipe): cdef bint overwrite = self.cfg["overwrite"] labels = self.labels for i, doc in enumerate(docs): - stored_activations = { - key: activations[key][i] for key in self.store_activations - } - doc.activations[self.name] = stored_activations + doc.activations[self.name] = {} + for activation in self.store_activations: + doc.activations[self.name][activation] = activations[activation][i] doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get()