Mirror of https://github.com/explosion/spaCy.git
	Fix EL failure with sentence-crossing entities (#12398)
* Add test reproducing EL failure in sentence-crossing entities.
* Format.
* Draft fix.
* Format.
* Fix case for len(ent.sents) == 1.
* Format.
* Format.
* Format.
* Fix mypy error.
* Merge EL sentence crossing tests.
* Remove unneeded sentencizer component.
* Fix or ignore mypy issues in test.
* Simplify ent.sents handling.
* Format. Update assert in ent.sents handling.
* Small rewrite.

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
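In short: the entity linker located each entity's context by looking up a single owning sentence with `sentences.index(ent.sent)`, which assumes the entity sits inside exactly one element of `doc.sents`; an entity that crosses a sentence boundary breaks that assumption. The patch instead takes the first and last sentence from `ent.sents`. Below is a minimal sketch of that lookup, not the spaCy source: the Doc, its sentence boundaries and the WORK entity are fabricated here to reproduce the failing shape, and it assumes a spaCy v3 release where `Span.sents` is available.

# A minimal sketch (not the spaCy source) of the lookup this commit changes.
# The Doc, sentence boundaries and the WORK entity below are invented to
# reproduce the failing shape: an entity that spans two sentences.
import spacy
from spacy.tokens import Doc, Span

nlp = spacy.blank("en")
words = ["Mahler", "'s", "Symphony", "No", ".", "8", "was", "beautiful", "."]
# Force a sentence break inside "Symphony No. 8", as the new test does.
sent_starts = [True, False, False, False, False, True, False, False, False]
doc = Doc(nlp.vocab, words=words, sent_starts=sent_starts)
doc.ents = [Span(doc, 2, 6, label="WORK")]  # tokens 2-5 cross the boundary

sentences = list(doc.sents)
ent = doc.ents[0]

# Old logic: ent.sent is not guaranteed to be an element of doc.sents for a
# sentence-crossing entity, so list.index() can raise ValueError.
# sent_index = sentences.index(ent.sent)

# New logic: use the first and last sentence the entity touches.
sents = list(ent.sents)
sent_indices = (sentences.index(sents[0]), sentences.index(sents[-1]))
print(sent_indices)  # (0, 1) for this example

When the entity stays inside one sentence, `list(ent.sents)` has length one and both indices coincide, which is the `len(ent.sents) == 1` case mentioned in the commit message.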
commit 96b61d0671
parent 2ce9a220db
@@ -474,18 +474,24 @@ class EntityLinker(TrainablePipe):
 
                 # Looping through each entity in batch (TODO: rewrite)
                 for j, ent in enumerate(ent_batch):
-                    sent_index = sentences.index(ent.sent)
-                    assert sent_index >= 0
+                    assert hasattr(ent, "sents")
+                    sents = list(ent.sents)
+                    sent_indices = (
+                        sentences.index(sents[0]),
+                        sentences.index(sents[-1]),
+                    )
+                    assert sent_indices[1] >= sent_indices[0] >= 0
 
                     if self.incl_context:
                         # get n_neighbour sentences, clipped to the length of the document
-                        start_sentence = max(0, sent_index - self.n_sents)
+                        start_sentence = max(0, sent_indices[0] - self.n_sents)
                         end_sentence = min(
-                            len(sentences) - 1, sent_index + self.n_sents
+                            len(sentences) - 1, sent_indices[1] + self.n_sents
                         )
                         start_token = sentences[start_sentence].start
                         end_token = sentences[end_sentence].end
                         sent_doc = doc[start_token:end_token].as_doc()
 
                         # currently, the context is the same for each entity in a sentence (should be refined)
                         sentence_encoding = self.model.predict([sent_doc])[0]
                         sentence_encoding_t = sentence_encoding.T
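The second half of the hunk above anchors the context window on both ends of the entity: it starts n_sents before the entity's first sentence and ends n_sents after its last one, clipped to the document, before slicing doc[start_token:end_token].as_doc() as the context. A small illustrative calculation of that clipping (all values invented):

# Illustrative only: the clipped sentence window around a sentence-crossing
# entity. n_sents, the document length and sent_indices are invented values.
n_sentences_in_doc = 10
n_sents = 2                  # the EntityLinker's n_sents window setting
sent_indices = (4, 5)        # first/last sentence index touched by the entity

start_sentence = max(0, sent_indices[0] - n_sents)                      # 2
end_sentence = min(n_sentences_in_doc - 1, sent_indices[1] + n_sents)   # 7
print(start_sentence, end_sentence)  # the context covers sentences 2..7

The remaining hunks update the entity linker tests.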
@@ -1,9 +1,9 @@
-from typing import Callable, Iterable, Dict, Any
+from typing import Callable, Iterable, Dict, Any, Tuple
 
 import pytest
 from numpy.testing import assert_equal
 
-from spacy import registry, util
+from spacy import registry, util, Language
 from spacy.attrs import ENT_KB_ID
 from spacy.compat import pickle
 from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
@@ -108,18 +108,23 @@ def test_issue7065():
 
 
-@pytest.mark.issue(7065)
-def test_issue7065_b():
+@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
+def test_sentence_crossing_ents(entity_in_first_sentence: bool):
+    """Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
+    entity.
+    entity_in_prior_sentence (bool): Whether to include an entity in the first sentence associated with the
+    sentence-crossing entity.
+    """
-    # Test that the NEL doesn't crash when an entity crosses a sentence boundary
     nlp = English()
     vector_length = 3
     nlp.add_pipe("sentencizer")
     text = "Mahler 's Symphony No. 8 was beautiful."
-    entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
-    links = {
-        (0, 6): {"Q7304": 1.0, "Q270853": 0.0},
-        (10, 24): {"Q7304": 0.0, "Q270853": 1.0},
-    }
-    sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
+    entities = [(10, 24, "WORK")]
+    links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
+    if entity_in_first_sentence:
+        entities.append((0, 6, "PERSON"))
+        links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
+    sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
     doc = nlp(text)
     example = Example.from_dict(
         doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
@@ -145,31 +150,14 @@
 
     # Create the Entity Linker component and add it to the pipeline
     entity_linker = nlp.add_pipe("entity_linker", last=True)
-    entity_linker.set_kb(create_kb)
+    entity_linker.set_kb(create_kb)  # type: ignore
     # train the NEL pipe
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
     for i in range(2):
-        losses = {}
-        nlp.update(train_examples, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer)
 
-    # Add a custom rule-based component to mimick NER
-    patterns = [
-        {"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
-        {
-            "label": "WORK",
-            "pattern": [
-                {"LOWER": "symphony"},
-                {"LOWER": "no"},
-                {"LOWER": "."},
-                {"LOWER": "8"},
-            ],
-        },
-    ]
-    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
-    ruler.add_patterns(patterns)
-    # test the trained model - this should not throw E148
-    doc = nlp(text)
-    assert doc
+    # This shouldn't crash.
+    entity_linker.predict([example.reference])  # type: ignore
 
 
 def test_no_entities():
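The test passes a create_kb callback to entity_linker.set_kb(), but the callback body sits outside the shown diff context. As a rough, hypothetical sketch of what such a callback looks like (the frequencies, vectors and aliases below are invented; only InMemoryLookupKB, the two Q-ids and vector_length = 3 appear in the shown test code):

from spacy.kb import InMemoryLookupKB

def create_kb(vocab):
    # Sketch of a KB-creation callback for entity_linker.set_kb(); the
    # frequencies, vectors and aliases below are invented for illustration.
    kb = InMemoryLookupKB(vocab, entity_vector_length=3)
    kb.add_entity(entity="Q7304", freq=12, entity_vector=[1, 0, 0])
    kb.add_entity(entity="Q270853", freq=8, entity_vector=[0, 1, 0])
    kb.add_alias(alias="Mahler", entities=["Q7304"], probabilities=[1.0])
    kb.add_alias(
        alias="Symphony No. 8", entities=["Q270853"], probabilities=[1.0]
    )
    return kb

With a KB like this in place, either parametrization of the test (entity_in_first_sentence True or False) should train the pipeline and run entity_linker.predict() on the sentence-crossing example without raising.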