From d851b3656d0a6fc902c0790223bc233529c17d80 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 10 Mar 2023 14:06:10 +0100 Subject: [PATCH] Draft fix. --- spacy/pipeline/entity_linker.py | 19 +++++++++++++++---- spacy/tests/pipeline/test_entity_linker.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index f2dae0529..fdfcafc92 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -474,18 +474,29 @@ class EntityLinker(TrainablePipe): # Looping through each entity in batch (TODO: rewrite) for j, ent in enumerate(ent_batch): - sent_index = sentences.index(ent.sent) - assert sent_index >= 0 + sents = list(ent.sents) + # Note: the last sentence associated with a sentence-crossing entity isn't complete. E.g. if you + # have "Mahler's Symphony No. 8 was beautiful", the entity being "No. 8", ent.sents would be: + # 1. "Mahler's Symphony No." + # 2. "8" + # whereas doc.sents would be: + # 1. "Mahler's Symphony No." + # 2. "8 was beautiful" + # This makes it tricky to retrieve the last sentence by indexing doc.sents - hence we use an offset + # to determine sent_indices[1]. 
+ sent_indices = (sentences.index(sents[0]), sentences.index(sents[0]) + len(sents) - 1) + assert all([si >= 0 for si in sent_indices]) if self.incl_context: # get n_neighbour sentences, clipped to the length of the document - start_sentence = max(0, sent_index - self.n_sents) + start_sentence = max(0, sent_indices[0] - self.n_sents) end_sentence = min( - len(sentences) - 1, sent_index + self.n_sents + len(sentences) - 1, sent_indices[1] + self.n_sents ) start_token = sentences[start_sentence].start end_token = sentences[end_sentence].end sent_doc = doc[start_token:end_token].as_doc() + # currently, the context is the same for each entity in a sentence (should be refined) sentence_encoding = self.model.predict([sent_doc])[0] sentence_encoding_t = sentence_encoding.T diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index aac3cf790..a25243904 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1221,7 +1221,6 @@ def test_span_maker_forward_with_empty(): span_maker([doc1, doc2], False) -@pytest.mark.skip(reason="Not fixed yet, expected to fail") def test_sentence_crossing_ents(): """Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an entity. @@ -1239,6 +1238,7 @@ def test_sentence_crossing_ents(): example = Example.from_dict( doc, {"entities": entities, "links": links, "sent_starts": sent_starts} ) + assert len(list(example.reference.ents[0].sents)) == 2 train_examples = [example] def create_kb(vocab):