Draft fix.

This commit is contained in:
Raphael Mitsch 2023-03-10 14:06:10 +01:00
parent 9696a0df01
commit d851b3656d
2 changed files with 16 additions and 5 deletions

View File

@ -474,18 +474,29 @@ class EntityLinker(TrainablePipe):
# Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch):
sent_index = sentences.index(ent.sent)
assert sent_index >= 0
sents = list(ent.sents)
# Note: the last sentence associated with a sentence-crossing entity isn't complete. E.g. if you
# have "Mahler's Symphony No. 8 was beautiful", the entity being "No. 8", ent.sents would be:
# 1. "Mahler's Symphony No."
# 2. "8"
# whereas doc.sents would be:
# 1. "Mahler's Symphony No."
# 2. "8 was beautiful"
# This makes it tricky to retrieve the last sentence by indexing doc.sents - hence we use an offset
# to determine sent_indices[1].
sent_indices = (sentences.index(sents[0]), sentences.index(sents[0]) + len(sents) - 1)
assert all([si >= 0 for si in sent_indices])
if self.incl_context:
# get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
start_sentence = max(0, sent_indices[0] - self.n_sents)
end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents
len(sentences) - 1, sent_indices[1] + self.n_sents
)
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T

View File

@ -1221,7 +1221,6 @@ def test_span_maker_forward_with_empty():
span_maker([doc1, doc2], False)
@pytest.mark.skip(reason="Not fixed yet, expected to fail")
def test_sentence_crossing_ents():
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
entity.
@ -1239,6 +1238,7 @@ def test_sentence_crossing_ents():
example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
)
assert len(list(example.reference.ents[0].sents)) == 2
train_examples = [example]
def create_kb(vocab):