Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-01 00:17:44 +03:00)
	Fix spancat training on nested entities (#9007)
* overfitting test on non-overlapping entities
* add failing overfitting test for overlapping entities
* failing test for list comprehension
* remove test that was put in separate PR
* bugfix
* cleanup
parent 9cc3dc2b67
commit 4d52d7051c

@@ -398,7 +398,7 @@ class SpanCategorizer(TrainablePipe):
         pass
 
     def _get_aligned_spans(self, eg: Example):
-        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []))
+        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []), allow_overlap=True)
 
     def _make_span_group(
         self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
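The hunk above is the core fix: `SpanCategorizer._get_aligned_spans` now passes `allow_overlap=True` when projecting gold spans onto the predicted tokenization, so gold spans that share tokens are no longer filtered out before the loss is computed. A minimal sketch of the difference (not part of this commit; the span key "my_spans" is arbitrary, and spaCy v3.1+ is assumed):

# Sketch only: shows how overlap filtering drops nested gold spans.
from spacy.lang.en import English
from spacy.training import Example

nlp = English()
doc = nlp.make_doc("I like London and Berlin")
gold = {"spans": {"my_spans": [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}}
eg = Example.from_dict(doc, gold)
spans = eg.reference.spans["my_spans"]

# With allow_overlap=True (the fix), all three gold spans survive alignment.
assert len(eg.get_aligned_spans_y2x(spans, allow_overlap=True)) == 3
# With the default allow_overlap=False (the bug), spans sharing tokens are
# filtered away, so the nested DOUBLE_LOC span never reaches the training loss.
assert len(eg.get_aligned_spans_y2x(spans)) < 3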
@@ -2,11 +2,14 @@ import pytest
 import numpy
 from numpy.testing import assert_array_equal, assert_almost_equal
 from thinc.api import get_current_ops
+
+from spacy import util
+from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tokens.doc import SpanGroups
 from spacy.tokens import SpanGroup
 from spacy.training import Example
-from spacy.util import fix_random_seed, registry
+from spacy.util import fix_random_seed, registry, make_tempdir
 
 OPS = get_current_ops()
 
@@ -20,17 +23,21 @@ TRAIN_DATA = [
     ),
 ]
 
+TRAIN_DATA_OVERLAPPING = [
+    ("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
+    (
+        "I like London and Berlin",
+        {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
+    ),
+]
+
 
-def make_get_examples(nlp):
+def make_examples(nlp, data=TRAIN_DATA):
     train_examples = []
-    for t in TRAIN_DATA:
+    for t in data:
         eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
         train_examples.append(eg)
-
-    def get_examples():
-        return train_examples
-
-    return get_examples
+    return train_examples
 
 
 def test_no_label():
@@ -57,9 +64,7 @@ def test_implicit_labels():
     nlp = Language()
     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
     assert len(spancat.labels) == 0
-    train_examples = []
-    for t in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    train_examples = make_examples(nlp)
     nlp.initialize(get_examples=lambda: train_examples)
     assert spancat.labels == ("PERSON", "LOC")
 
@@ -140,30 +145,6 @@ def test_make_spangroup(max_positive, nr_results):
     assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
 
 
-def test_simple_train():
-    fix_random_seed(0)
-    nlp = Language()
-    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
-    get_examples = make_get_examples(nlp)
-    nlp.initialize(get_examples)
-    sgd = nlp.create_optimizer()
-    assert len(spancat.labels) != 0
-    for i in range(40):
-        losses = {}
-        nlp.update(list(get_examples()), losses=losses, drop=0.1, sgd=sgd)
-    doc = nlp("I like London and Berlin.")
-    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
-    assert len(doc.spans[spancat.key]) == 2
-    assert len(doc.spans[spancat.key].attrs["scores"]) == 2
-    assert doc.spans[spancat.key][0].text == "London"
-    scores = nlp.evaluate(get_examples())
-    assert f"spans_{SPAN_KEY}_f" in scores
-    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
-    # also test that the spancat works for just a single entity in a sentence
-    doc = nlp("London")
-    assert len(doc.spans[spancat.key]) == 1
-
-
 def test_ngram_suggester(en_tokenizer):
     # test different n-gram lengths
     for size in [1, 2, 3]:
@@ -282,3 +263,92 @@ def test_ngram_sizes(en_tokenizer):
     range_suggester = suggester_factory(min_size=2, max_size=4)
     ngrams_3 = range_suggester(docs)
     assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])
+
+
+def test_overfitting_IO():
+    # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+    train_examples = make_examples(nlp)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 2
+    assert set(spancat.labels) == {"LOC", "PERSON"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 2
+    assert len(spans.attrs["scores"]) == 2
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 2
+        assert len(spans2.attrs["scores"]) == 2
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC"}
+
+    # Test scoring
+    scores = nlp.evaluate(train_examples)
+    assert f"spans_{SPAN_KEY}_f" in scores
+    assert scores[f"spans_{SPAN_KEY}_p"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_r"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
+
+    # also test that the spancat works for just a single entity in a sentence
+    doc = nlp("London")
+    assert len(doc.spans[spancat.key]) == 1
+
+
+def test_overfitting_IO_overlapping():
+    # Test for overfitting on overlapping entities
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+
+    train_examples = make_examples(nlp, data=TRAIN_DATA_OVERLAPPING)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 3
+    assert set(spancat.labels) == {"PERSON", "LOC", "DOUBLE_LOC"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 3
+    assert len(spans.attrs["scores"]) == 3
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC", "DOUBLE_LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 3
+        assert len(spans2.attrs["scores"]) == 3
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
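For reference, a minimal end-to-end sketch condensed from the new tests above (not part of this commit; it assumes the fix is applied, and the span key "my_spans" is arbitrary):

# Sketch only: train a spancat on overlapping gold spans and inspect predictions.
from spacy.lang.en import English
from spacy.training import Example
from spacy.util import fix_random_seed

SPAN_KEY = "my_spans"  # must match between the pipe config and the gold data
fix_random_seed(0)
nlp = English()
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
train_data = [
    (
        "I like London and Berlin",
        {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
    ),
]
examples = [Example.from_dict(nlp.make_doc(text), ann) for text, ann in train_data]
optimizer = nlp.initialize(get_examples=lambda: examples)
for _ in range(50):
    nlp.update(examples, sgd=optimizer, losses={})

doc = nlp("I like London and Berlin")
# Before the fix, the covering DOUBLE_LOC span was silently dropped from the
# gold standard; now all three overlapping spans can be learned and predicted.
for span in doc.spans[SPAN_KEY]:
    print(span.text, span.label_)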