From 6e7b958f7060397965176c69649e5414f1f24988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 4 Aug 2022 17:13:52 +0200 Subject: [PATCH] Use dict comprehension suggested by @svlandeg --- spacy/pipeline/edit_tree_lemmatizer.py | 7 ++++--- spacy/pipeline/entity_linker.py | 17 ++++++++++------- spacy/pipeline/morphologizer.pyx | 7 ++++--- spacy/pipeline/senter.pyx | 7 ++++--- spacy/pipeline/tagger.pyx | 7 ++++--- 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index 3af39b1d1..f457e1d7b 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -202,9 +202,10 @@ class EditTreeLemmatizer(TrainablePipe): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): batch_tree_ids = activations["guesses"] for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + stored_activations = { + key: activations[key][i] for key in self.store_activations + } + doc.activations[self.name] = stored_activations doc_tree_ids = batch_tree_ids[i] if hasattr(doc_tree_ids, "get"): doc_tree_ids = doc_tree_ids.get() diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 266c1f07f..02529183b 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe): # shortcut for efficiency reasons: take the 1 candidate final_kb_ids.append(candidates[0].entity_) self._add_activations( - doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_] + doc_scores, + doc_scores_lens, + doc_ents, + [1.0], + [candidates[0].entity_], ) else: random.shuffle(candidates) @@ -541,12 +545,11 @@ class EntityLinker(TrainablePipe): i = 0 overwrite = self.cfg["overwrite"] for j, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - # We only copy activations that are Ragged. - doc.activations[self.name][activation] = cast( - Ragged, activations[activation][j] - ) + # We only copy activations that are Ragged. + stored_activations = { + key: cast(Ragged, activations[key][i]) for key in self.store_activations + } + doc.activations[self.name] = stored_activations for ent in doc.ents: kb_id = kb_ids[i] i += 1 diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 0c7eacd12..6ecba4ccf 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -249,9 +249,10 @@ class Morphologizer(Tagger): # to allocate a compatible container out of the iterable. labels = tuple(self.labels) for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + stored_activations = { + key: activations[key][i] for key in self.store_activations + } + doc.activations[self.name] = stored_activations doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 1cfd6c4b1..a789934d7 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -135,9 +135,10 @@ class SentenceRecognizer(Tagger): cdef Doc doc cdef bint overwrite = self.cfg["overwrite"] for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + stored_activations = { + key: activations[key][i] for key in self.store_activations + } + doc.activations[self.name] = stored_activations doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 498b3de08..3a8bcb67c 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -183,9 +183,10 @@ class Tagger(TrainablePipe): cdef bint overwrite = self.cfg["overwrite"] labels = self.labels for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + stored_activations = { + key: activations[key][i] for key in self.store_activations + } + doc.activations[self.name] = stored_activations doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get()