Mirror of https://github.com/explosion/spaCy.git
	Fix EL failure with sentence-crossing entities (#12398)
* Add test reproducing EL failure in sentence-crossing entities.
* Format.
* Draft fix.
* Format.
* Fix case for len(ent.sents) == 1.
* Format.
* Format.
* Format.
* Fix mypy error.
* Merge EL sentence crossing tests.
* Remove unneeded sentencizer component.
* Fix or ignore mypy issues in test.
* Simplify ent.sents handling.
* Format. Update assert in ent.sents handling.
* Small rewrite.

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
parent 2ce9a220db
commit 96b61d0671
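For context, here is a minimal sketch of the failure mode this commit addresses. It is illustrative only: it assumes a recent spaCy v3.x (where Span.sents is available) and forces the sentence boundary inside the entity by hand, mirroring the test below.

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("Mahler 's Symphony No. 8 was beautiful.")
    # Force a sentence boundary inside the entity: a new sentence starts at "8".
    doc[5].is_sent_start = True
    ent = doc.char_span(10, 24, label="WORK")  # "Symphony No. 8"
    sentences = list(doc.sents)

    # Old logic: sentences.index(ent.sent) assumed the entity sat inside a
    # single sentence, which breaks down for sentence-crossing entities.
    # New logic: bound the context window by every sentence the entity overlaps.
    sents = list(ent.sents)
    first, last = sentences.index(sents[0]), sentences.index(sents[-1])
    assert (first, last) == (0, 1)  # the entity spans both sentences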
spacy/pipeline/entity_linker.py

@@ -474,18 +474,24 @@ class EntityLinker(TrainablePipe):

                 # Looping through each entity in batch (TODO: rewrite)
                 for j, ent in enumerate(ent_batch):
-                    sent_index = sentences.index(ent.sent)
-                    assert sent_index >= 0
+                    assert hasattr(ent, "sents")
+                    sents = list(ent.sents)
+                    sent_indices = (
+                        sentences.index(sents[0]),
+                        sentences.index(sents[-1]),
+                    )
+                    assert sent_indices[1] >= sent_indices[0] >= 0

                     if self.incl_context:
                         # get n_neighbour sentences, clipped to the length of the document
-                        start_sentence = max(0, sent_index - self.n_sents)
+                        start_sentence = max(0, sent_indices[0] - self.n_sents)
                         end_sentence = min(
-                            len(sentences) - 1, sent_index + self.n_sents
+                            len(sentences) - 1, sent_indices[1] + self.n_sents
                         )
                         start_token = sentences[start_sentence].start
                         end_token = sentences[end_sentence].end
                         sent_doc = doc[start_token:end_token].as_doc()
+
                         # currently, the context is the same for each entity in a sentence (should be refined)
                         sentence_encoding = self.model.predict([sent_doc])[0]
                         sentence_encoding_t = sentence_encoding.T
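Pulled out of the pipe, the new windowing logic amounts to the following sketch. The helper and its name are illustrative, not spaCy API; it mirrors the diff above under the assumption that sentence boundaries are set on the doc.

    from spacy.tokens import Doc, Span

    def context_window(doc: Doc, ent: Span, n_sents: int) -> Span:
        """Clip a window of n_sents sentences on either side of an entity
        that may cross sentence boundaries."""
        sentences = list(doc.sents)
        sents = list(ent.sents)  # every sentence the entity overlaps
        first = sentences.index(sents[0])
        last = sentences.index(sents[-1])
        start_sentence = max(0, first - n_sents)
        end_sentence = min(len(sentences) - 1, last + n_sents)
        return doc[sentences[start_sentence].start : sentences[end_sentence].end]

Because the window is bounded by both the first and the last overlapping sentence, the context passed to the model always contains the whole entity; the old single-index version only guaranteed that for entities confined to one sentence.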
spacy/tests/pipeline/test_entity_linker.py

@@ -1,9 +1,9 @@
-from typing import Callable, Iterable, Dict, Any
+from typing import Callable, Iterable, Dict, Any, Tuple

 import pytest
 from numpy.testing import assert_equal

-from spacy import registry, util
+from spacy import registry, util, Language
 from spacy.attrs import ENT_KB_ID
 from spacy.compat import pickle
 from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
@@ -108,18 +108,23 @@ def test_issue7065():


 @pytest.mark.issue(7065)
-def test_issue7065_b():
+@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
+def test_sentence_crossing_ents(entity_in_first_sentence: bool):
+    """Tests that NEL doesn't crash when entities cross sentence boundaries and
+    the first associated sentence doesn't have an entity.
+    entity_in_first_sentence (bool): Whether to include an entity in the first
+    sentence associated with the sentence-crossing entity.
+    """
     # Test that the NEL doesn't crash when an entity crosses a sentence boundary
     nlp = English()
     vector_length = 3
-    nlp.add_pipe("sentencizer")
     text = "Mahler 's Symphony No. 8 was beautiful."
-    entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
-    links = {
-        (0, 6): {"Q7304": 1.0, "Q270853": 0.0},
-        (10, 24): {"Q7304": 0.0, "Q270853": 1.0},
-    }
-    sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
+    entities = [(10, 24, "WORK")]
+    links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
+    if entity_in_first_sentence:
+        entities.append((0, 6, "PERSON"))
+        links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
+    sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
     doc = nlp(text)
     example = Example.from_dict(
         doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
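A note on the sent_starts convention used here: in the training dict, 1 marks a token that starts a sentence, -1 one that explicitly doesn't, and 0 leaves it unannotated. The new list therefore forces a boundary at token 5, "8", so the WORK entity at (10, 24) crosses sentences. A quick, illustrative check of that alignment:

    # Token/boundary alignment assumed by the test (tokenization per spaCy's
    # English rules, which split "No." into "No" and ".").
    tokens = ["Mahler", "'s", "Symphony", "No", ".", "8", "was", "beautiful", "."]
    sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
    assert tokens[sent_starts.index(1, 1)] == "8"  # second sentence starts at "8"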
@@ -145,31 +150,14 @@ def test_issue7065_b():


     # Create the Entity Linker component and add it to the pipeline
     entity_linker = nlp.add_pipe("entity_linker", last=True)
-    entity_linker.set_kb(create_kb)
+    entity_linker.set_kb(create_kb)  # type: ignore
     # train the NEL pipe
     optimizer = nlp.initialize(get_examples=lambda: train_examples)
     for i in range(2):
-        losses = {}
-        nlp.update(train_examples, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer)

-    # Add a custom rule-based component to mimick NER
-    patterns = [
-        {"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
-        {
-            "label": "WORK",
-            "pattern": [
-                {"LOWER": "symphony"},
-                {"LOWER": "no"},
-                {"LOWER": "."},
-                {"LOWER": "8"},
-            ],
-        },
-    ]
-    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
-    ruler.add_patterns(patterns)
-    # test the trained model - this should not throw E148
-    doc = nlp(text)
-    assert doc
+    # This shouldn't crash.
+    entity_linker.predict([example.reference])  # type: ignore


 def test_no_entities():
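To run both parametrizations of the new test locally, something like the following should work (the path assumes spaCy's usual test layout):

    pytest spacy/tests/pipeline/test_entity_linker.py -k test_sentence_crossing_ents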