mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-09 10:13:41 +03:00
Use dict comprehension suggested by @svlandeg
This commit is contained in:
parent
51f72e41ec
commit
6e7b958f70
|
@ -202,9 +202,10 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
|
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
|
||||||
batch_tree_ids = activations["guesses"]
|
batch_tree_ids = activations["guesses"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc.activations[self.name] = {}
|
stored_activations = {
|
||||||
for activation in self.store_activations:
|
key: activations[key][i] for key in self.store_activations
|
||||||
doc.activations[self.name][activation] = activations[activation][i]
|
}
|
||||||
|
doc.activations[self.name] = stored_activations
|
||||||
doc_tree_ids = batch_tree_ids[i]
|
doc_tree_ids = batch_tree_ids[i]
|
||||||
if hasattr(doc_tree_ids, "get"):
|
if hasattr(doc_tree_ids, "get"):
|
||||||
doc_tree_ids = doc_tree_ids.get()
|
doc_tree_ids = doc_tree_ids.get()
|
||||||
|
|
|
@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe):
|
||||||
# shortcut for efficiency reasons: take the 1 candidate
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
final_kb_ids.append(candidates[0].entity_)
|
final_kb_ids.append(candidates[0].entity_)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
|
doc_scores,
|
||||||
|
doc_scores_lens,
|
||||||
|
doc_ents,
|
||||||
|
[1.0],
|
||||||
|
[candidates[0].entity_],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
@ -541,12 +545,11 @@ class EntityLinker(TrainablePipe):
|
||||||
i = 0
|
i = 0
|
||||||
overwrite = self.cfg["overwrite"]
|
overwrite = self.cfg["overwrite"]
|
||||||
for j, doc in enumerate(docs):
|
for j, doc in enumerate(docs):
|
||||||
doc.activations[self.name] = {}
|
|
||||||
for activation in self.store_activations:
|
|
||||||
# We only copy activations that are Ragged.
|
# We only copy activations that are Ragged.
|
||||||
doc.activations[self.name][activation] = cast(
|
stored_activations = {
|
||||||
Ragged, activations[activation][j]
|
key: cast(Ragged, activations[key][i]) for key in self.store_activations
|
||||||
)
|
}
|
||||||
|
doc.activations[self.name] = stored_activations
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
kb_id = kb_ids[i]
|
kb_id = kb_ids[i]
|
||||||
i += 1
|
i += 1
|
||||||
|
|
|
@ -249,9 +249,10 @@ class Morphologizer(Tagger):
|
||||||
# to allocate a compatible container out of the iterable.
|
# to allocate a compatible container out of the iterable.
|
||||||
labels = tuple(self.labels)
|
labels = tuple(self.labels)
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc.activations[self.name] = {}
|
stored_activations = {
|
||||||
for activation in self.store_activations:
|
key: activations[key][i] for key in self.store_activations
|
||||||
doc.activations[self.name][activation] = activations[activation][i]
|
}
|
||||||
|
doc.activations[self.name] = stored_activations
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
|
|
|
@ -135,9 +135,10 @@ class SentenceRecognizer(Tagger):
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef bint overwrite = self.cfg["overwrite"]
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc.activations[self.name] = {}
|
stored_activations = {
|
||||||
for activation in self.store_activations:
|
key: activations[key][i] for key in self.store_activations
|
||||||
doc.activations[self.name][activation] = activations[activation][i]
|
}
|
||||||
|
doc.activations[self.name] = stored_activations
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
|
|
|
@ -183,9 +183,10 @@ class Tagger(TrainablePipe):
|
||||||
cdef bint overwrite = self.cfg["overwrite"]
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
labels = self.labels
|
labels = self.labels
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc.activations[self.name] = {}
|
stored_activations = {
|
||||||
for activation in self.store_activations:
|
key: activations[key][i] for key in self.store_activations
|
||||||
doc.activations[self.name][activation] = activations[activation][i]
|
}
|
||||||
|
doc.activations[self.name] = stored_activations
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user