Draft fix.

Raphael Mitsch 2023-03-10 14:06:10 +01:00
parent 9696a0df01
commit d851b3656d
2 changed files with 16 additions and 5 deletions

View File

@@ -474,18 +474,29 @@ class EntityLinker(TrainablePipe):
 # Looping through each entity in batch (TODO: rewrite)
 for j, ent in enumerate(ent_batch):
-    sent_index = sentences.index(ent.sent)
-    assert sent_index >= 0
+    sents = list(ent.sents)
+    # Note: the last sentence associated with a sentence-crossing entity isn't complete. E.g. if you
+    # have "Mahler's Symphony No. 8 was beautiful", the entity being "No. 8", ent.sents would be:
+    # 1. "Mahler's Symphony No."
+    # 2. "8"
+    # whereas doc.sents would be:
+    # 1. "Mahler's Symphony No."
+    # 2. "8 was beautiful"
+    # This makes it tricky to retrieve the last sentence by indexing doc.sents - hence we use an offset
+    # to determine sent_indices[1].
+    sent_indices = (sentences.index(sents[0]), sentences.index(sents[0]) + len(sents) - 1)
+    assert all([si >= 0 for si in sent_indices])
     if self.incl_context:
         # get n_neighbour sentences, clipped to the length of the document
-        start_sentence = max(0, sent_index - self.n_sents)
+        start_sentence = max(0, sent_indices[0] - self.n_sents)
         end_sentence = min(
-            len(sentences) - 1, sent_index + self.n_sents
+            len(sentences) - 1, sent_indices[1] + self.n_sents
         )
         start_token = sentences[start_sentence].start
         end_token = sentences[end_sentence].end
         sent_doc = doc[start_token:end_token].as_doc()
         # currently, the context is the same for each entity in a sentence (should be refined)
         sentence_encoding = self.model.predict([sent_doc])[0]
         sentence_encoding_t = sentence_encoding.T
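
For readers unfamiliar with the Span.sents behaviour the new comment describes, here is a minimal standalone sketch (not part of the commit) of the sent_indices arithmetic. The toy Doc, the sentence split after "No.", and the WORK_OF_ART label are illustrative assumptions chosen to mirror the Mahler example; it assumes spaCy >= 3.2, where Span.sents is available.

import spacy
from spacy.tokens import Doc, Span

# Build a two-sentence doc whose sentence boundary falls inside an entity.
vocab = spacy.blank("en").vocab
words = ["Mahler", "'s", "Symphony", "No.", "8", "was", "beautiful"]
sent_starts = [True, False, False, False, True, False, False]  # split after "No."
doc = Doc(vocab, words=words, sent_starts=sent_starts)
doc.ents = [Span(doc, 3, 5, label="WORK_OF_ART")]  # "No. 8" crosses the boundary

sentences = list(doc.sents)
ent = doc.ents[0]
sents = list(ent.sents)
# Look up the first sentence directly; derive the last one via an offset, since
# the final span yielded by ent.sents may not equal the full doc sentence.
first = sentences.index(sents[0])
sent_indices = (first, first + len(sents) - 1)
print(sent_indices)  # (0, 1)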

View File

@@ -1221,7 +1221,6 @@ def test_span_maker_forward_with_empty():
     span_maker([doc1, doc2], False)


-@pytest.mark.skip(reason="Not fixed yet, expected to fail")
 def test_sentence_crossing_ents():
     """Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
     entity.
@@ -1239,6 +1238,7 @@ def test_sentence_crossing_ents():
     example = Example.from_dict(
         doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
     )
+    assert len(list(example.reference.ents[0].sents)) == 2
     train_examples = [example]

     def create_kb(vocab):
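
With the @pytest.mark.skip marker removed above, the test runs as part of the regular suite. A possible way to run just this test from the repository root, sketched here with pytest's programmatic entry point:

import pytest

# Select only the previously-skipped test by name.
pytest.main(["-k", "test_sentence_crossing_ents", "-q"])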