Revert "Use dict comprehension suggested by @svlandeg"

This reverts commit 6e7b958f70.
This commit is contained in:
Daniël de Kok 2022-08-05 10:33:20 +02:00
parent 6e7b958f70
commit 230264daa0
5 changed files with 19 additions and 26 deletions

View File

@ -202,10 +202,9 @@ class EditTreeLemmatizer(TrainablePipe):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
batch_tree_ids = activations["guesses"] batch_tree_ids = activations["guesses"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
stored_activations = { doc.activations[self.name] = {}
key: activations[key][i] for key in self.store_activations for activation in self.store_activations:
} doc.activations[self.name][activation] = activations[activation][i]
doc.activations[self.name] = stored_activations
doc_tree_ids = batch_tree_ids[i] doc_tree_ids = batch_tree_ids[i]
if hasattr(doc_tree_ids, "get"): if hasattr(doc_tree_ids, "get"):
doc_tree_ids = doc_tree_ids.get() doc_tree_ids = doc_tree_ids.get()

View File

@ -474,11 +474,7 @@ class EntityLinker(TrainablePipe):
# shortcut for efficiency reasons: take the 1 candidate # shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_) final_kb_ids.append(candidates[0].entity_)
self._add_activations( self._add_activations(
doc_scores, doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
doc_scores_lens,
doc_ents,
[1.0],
[candidates[0].entity_],
) )
else: else:
random.shuffle(candidates) random.shuffle(candidates)
@ -545,11 +541,12 @@ class EntityLinker(TrainablePipe):
i = 0 i = 0
overwrite = self.cfg["overwrite"] overwrite = self.cfg["overwrite"]
for j, doc in enumerate(docs): for j, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
# We only copy activations that are Ragged. # We only copy activations that are Ragged.
stored_activations = { doc.activations[self.name][activation] = cast(
key: cast(Ragged, activations[key][i]) for key in self.store_activations Ragged, activations[activation][j]
} )
doc.activations[self.name] = stored_activations
for ent in doc.ents: for ent in doc.ents:
kb_id = kb_ids[i] kb_id = kb_ids[i]
i += 1 i += 1

View File

@ -249,10 +249,9 @@ class Morphologizer(Tagger):
# to allocate a compatible container out of the iterable. # to allocate a compatible container out of the iterable.
labels = tuple(self.labels) labels = tuple(self.labels)
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
stored_activations = { doc.activations[self.name] = {}
key: activations[key][i] for key in self.store_activations for activation in self.store_activations:
} doc.activations[self.name][activation] = activations[activation][i]
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -135,10 +135,9 @@ class SentenceRecognizer(Tagger):
cdef Doc doc cdef Doc doc
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
stored_activations = { doc.activations[self.name] = {}
key: activations[key][i] for key in self.store_activations for activation in self.store_activations:
} doc.activations[self.name][activation] = activations[activation][i]
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -183,10 +183,9 @@ class Tagger(TrainablePipe):
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels labels = self.labels
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
stored_activations = { doc.activations[self.name] = {}
key: activations[key][i] for key in self.store_activations for activation in self.store_activations:
} doc.activations[self.name][activation] = activations[activation][i]
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()