Use dict comprehension suggested by @svlandeg

This commit is contained in:
Daniël de Kok 2022-08-04 17:13:52 +02:00
parent 51f72e41ec
commit 6e7b958f70
5 changed files with 26 additions and 19 deletions

View File

@ -202,9 +202,10 @@ class EditTreeLemmatizer(TrainablePipe):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
batch_tree_ids = activations["guesses"]
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tree_ids = batch_tree_ids[i]
if hasattr(doc_tree_ids, "get"):
doc_tree_ids = doc_tree_ids.get()

View File

@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe):
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
doc_scores,
doc_scores_lens,
doc_ents,
[1.0],
[candidates[0].entity_],
)
else:
random.shuffle(candidates)
@ -541,12 +545,11 @@ class EntityLinker(TrainablePipe):
i = 0
overwrite = self.cfg["overwrite"]
for j, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
# We only copy activations that are Ragged.
doc.activations[self.name][activation] = cast(
Ragged, activations[activation][j]
)
# We only copy activations that are Ragged.
stored_activations = {
key: cast(Ragged, activations[key][i]) for key in self.store_activations
}
doc.activations[self.name] = stored_activations
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1

View File

@ -249,9 +249,10 @@ class Morphologizer(Tagger):
# to allocate a compatible container out of the iterable.
labels = tuple(self.labels)
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()

View File

@ -135,9 +135,10 @@ class SentenceRecognizer(Tagger):
cdef Doc doc
cdef bint overwrite = self.cfg["overwrite"]
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()

View File

@ -183,9 +183,10 @@ class Tagger(TrainablePipe):
cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()