mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
output tensors as part of predict
This commit is contained in:
parent
21176517a7
commit
41fb5204ba
|
@ -1212,15 +1212,15 @@ class EntityLinker(Pipe):
|
|||
return loss, d_scores
|
||||
|
||||
def __call__(self, doc):
|
||||
kb_ids = self.predict([doc])
|
||||
self.set_annotations([doc], kb_ids)
|
||||
kb_ids, tensors = self.predict([doc])
|
||||
self.set_annotations([doc], kb_ids, tensors=tensors)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
docs = list(docs)
|
||||
kb_ids = self.predict(docs)
|
||||
self.set_annotations(docs, kb_ids)
|
||||
kb_ids, tensors = self.predict(docs)
|
||||
self.set_annotations(docs, kb_ids, tensors=tensors)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
|
@ -1230,6 +1230,7 @@ class EntityLinker(Pipe):
|
|||
|
||||
entity_count = 0
|
||||
final_kb_ids = []
|
||||
final_tensors = []
|
||||
|
||||
if not docs:
|
||||
return final_kb_ids
|
||||
|
@ -1244,6 +1245,7 @@ class EntityLinker(Pipe):
|
|||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
context_encoding = context_encodings[i]
|
||||
for ent in doc.ents:
|
||||
entity_count += 1
|
||||
|
@ -1254,6 +1256,7 @@ class EntityLinker(Pipe):
|
|||
candidates = self.kb.get_candidates(ent.text)
|
||||
if not candidates:
|
||||
final_kb_ids.append(self.NIL) # no prediction possible for this entity
|
||||
final_tensors.append(context_encoding)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
||||
|
@ -1274,12 +1277,16 @@ class EntityLinker(Pipe):
|
|||
best_index = scores.argmax()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(context_encoding)
|
||||
|
||||
assert len(final_kb_ids) == entity_count
|
||||
assert len(final_tensors) == len(final_kb_ids) == entity_count
|
||||
|
||||
return final_kb_ids
|
||||
return final_kb_ids, final_tensors
|
||||
|
||||
def set_annotations(self, docs, kb_ids, tensors=None):
|
||||
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||
assert count_ents == len(kb_ids)
|
||||
|
||||
i=0
|
||||
for doc in docs:
|
||||
for ent in doc.ents:
|
||||
|
|
Loading…
Reference in New Issue
Block a user