Mirror of https://github.com/explosion/spaCy.git, synced 2025-01-13 18:56:36 +03:00
Fix spancat training on nested entities (#9007)
* overfitting test on non-overlapping entities
* add failing overfitting test for overlapping entities
* failing test for list comprehension
* remove test that was put in separate PR
* bugfix
* cleanup
parent 9cc3dc2b67
commit 4d52d7051c
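
For context (this paragraph and the sketch below are editorial additions, not part of the commit): `Example.get_aligned_spans_y2x` projects the gold-standard spans of the reference doc onto the predicted tokenization, and with its default settings it discards overlapping spans, so nested annotations never reached the spancat training loop. A minimal sketch of the behavior the fix enables, assuming spaCy v3.1's span training format and an arbitrary span key "sc":

    from spacy.lang.en import English
    from spacy.training import Example

    nlp = English()
    doc = nlp.make_doc("I like London and Berlin")
    # Three gold spans, two of them nested inside a third.
    eg = Example.from_dict(
        doc,
        {"spans": {"sc": [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
    )
    gold = eg.reference.spans["sc"]
    # With allow_overlap=True (the fix below), all three spans survive
    # alignment; with the default, the overlapping ones are filtered out.
    aligned = eg.get_aligned_spans_y2x(gold, allow_overlap=True)
    assert len(aligned) == 3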
spacy/pipeline/spancat.py
@@ -398,7 +398,7 @@ class SpanCategorizer(TrainablePipe):
         pass

     def _get_aligned_spans(self, eg: Example):
-        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []))
+        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []), allow_overlap=True)

     def _make_span_group(
         self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
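
The one-argument change above is the entire bugfix: `_get_aligned_spans` supplies the gold spans used by the spancat loss, so silently dropping nested spans meant the model was never penalized for missing them. The default filtering behaves like `spacy.util.filter_spans`, which keeps only the longest span of each overlapping group; a hedged sketch of that behavior (offsets match the test data below):

    from spacy.lang.en import English
    from spacy.util import filter_spans

    nlp = English()
    doc = nlp.make_doc("I like London and Berlin")
    london = doc.char_span(7, 13, label="LOC")
    berlin = doc.char_span(18, 24, label="LOC")
    double = doc.char_span(7, 24, label="DOUBLE_LOC")
    # filter_spans keeps the longest span of each overlapping group,
    # so both nested LOC annotations are silently discarded.
    kept = filter_spans([london, berlin, double])
    assert [span.text for span in kept] == ["London and Berlin"]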
spacy/tests/pipeline/test_spancat.py
@@ -2,11 +2,14 @@ import pytest
 import numpy
 from numpy.testing import assert_array_equal, assert_almost_equal
 from thinc.api import get_current_ops
+
+from spacy import util
+from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tokens.doc import SpanGroups
 from spacy.tokens import SpanGroup
 from spacy.training import Example
-from spacy.util import fix_random_seed, registry
+from spacy.util import fix_random_seed, registry, make_tempdir

 OPS = get_current_ops()

@@ -20,18 +23,22 @@ TRAIN_DATA = [
     ),
 ]

+TRAIN_DATA_OVERLAPPING = [
+    ("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
+    (
+        "I like London and Berlin",
+        {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
+    ),
+]
+

-def make_get_examples(nlp):
+def make_examples(nlp, data=TRAIN_DATA):
     train_examples = []
-    for t in TRAIN_DATA:
+    for t in data:
         eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
         train_examples.append(eg)
-
-    def get_examples():
-        return train_examples
-
-    return get_examples
+    return train_examples


 def test_no_label():
     nlp = Language()
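
The nesting in TRAIN_DATA_OVERLAPPING is encoded directly in the character offsets; a quick sanity check (editorial, not part of the commit) of the slices used above:

    text = "I like London and Berlin"
    assert text[7:13] == "London"             # first LOC
    assert text[18:24] == "Berlin"            # second LOC
    assert text[7:24] == "London and Berlin"  # DOUBLE_LOC covering both
    assert "Who is Shaka Khan?"[7:17] == "Shaka Khan"  # PERSON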
@@ -57,9 +64,7 @@ def test_implicit_labels():
     nlp = Language()
     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
     assert len(spancat.labels) == 0
-    train_examples = []
-    for t in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    train_examples = make_examples(nlp)
     nlp.initialize(get_examples=lambda: train_examples)
     assert spancat.labels == ("PERSON", "LOC")

@ -140,30 +145,6 @@ def test_make_spangroup(max_positive, nr_results):
|
||||||
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
||||||
|
|
||||||
|
|
||||||
def test_simple_train():
|
|
||||||
fix_random_seed(0)
|
|
||||||
nlp = Language()
|
|
||||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
|
||||||
get_examples = make_get_examples(nlp)
|
|
||||||
nlp.initialize(get_examples)
|
|
||||||
sgd = nlp.create_optimizer()
|
|
||||||
assert len(spancat.labels) != 0
|
|
||||||
for i in range(40):
|
|
||||||
losses = {}
|
|
||||||
nlp.update(list(get_examples()), losses=losses, drop=0.1, sgd=sgd)
|
|
||||||
doc = nlp("I like London and Berlin.")
|
|
||||||
assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
|
|
||||||
assert len(doc.spans[spancat.key]) == 2
|
|
||||||
assert len(doc.spans[spancat.key].attrs["scores"]) == 2
|
|
||||||
assert doc.spans[spancat.key][0].text == "London"
|
|
||||||
scores = nlp.evaluate(get_examples())
|
|
||||||
assert f"spans_{SPAN_KEY}_f" in scores
|
|
||||||
assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
|
|
||||||
# also test that the spancat works for just a single entity in a sentence
|
|
||||||
doc = nlp("London")
|
|
||||||
assert len(doc.spans[spancat.key]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_suggester(en_tokenizer):
|
def test_ngram_suggester(en_tokenizer):
|
||||||
# test different n-gram lengths
|
# test different n-gram lengths
|
||||||
for size in [1, 2, 3]:
|
for size in [1, 2, 3]:
|
||||||
@@ -282,3 +263,92 @@ def test_ngram_sizes(en_tokenizer):
     range_suggester = suggester_factory(min_size=2, max_size=4)
     ngrams_3 = range_suggester(docs)
     assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])
+
+
+def test_overfitting_IO():
+    # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+    train_examples = make_examples(nlp)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 2
+    assert set(spancat.labels) == {"LOC", "PERSON"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 2
+    assert len(spans.attrs["scores"]) == 2
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 2
+        assert len(spans2.attrs["scores"]) == 2
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC"}
+
+    # Test scoring
+    scores = nlp.evaluate(train_examples)
+    assert f"spans_{SPAN_KEY}_f" in scores
+    assert scores[f"spans_{SPAN_KEY}_p"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_r"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
+
+    # also test that the spancat works for just a single entity in a sentence
+    doc = nlp("London")
+    assert len(doc.spans[spancat.key]) == 1
+
+
+def test_overfitting_IO_overlapping():
+    # Test for overfitting on overlapping entities
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+
+    train_examples = make_examples(nlp, data=TRAIN_DATA_OVERLAPPING)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 3
+    assert set(spancat.labels) == {"PERSON", "LOC", "DOUBLE_LOC"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 3
+    assert len(spans.attrs["scores"]) == 3
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC", "DOUBLE_LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 3
+        assert len(spans2.attrs["scores"]) == 3
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
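
To run just these tests from a source checkout, something like `pytest spacy/tests/pipeline/test_spancat.py -k overfitting` should select the two overfitting tests (the path is assumed from spaCy's usual test layout).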