mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Add basic span predictor tests
This commit is contained in:
parent
201731df2d
commit
1a4dbb702d
129
spacy/tests/pipeline/test_span_predictor.py
Normal file
129
spacy/tests/pipeline/test_span_predictor.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import pytest
|
||||
import spacy
|
||||
|
||||
from spacy import util
|
||||
from spacy.training import Example
|
||||
from spacy.lang.en import English
|
||||
from spacy.tests.util import make_tempdir
|
||||
from spacy.ml.models.coref_util import (
|
||||
DEFAULT_CLUSTER_PREFIX,
|
||||
select_non_crossing_spans,
|
||||
get_sentence_ids,
|
||||
spans2ints,
|
||||
)
|
||||
|
||||
from thinc.util import has_torch
|
||||
|
||||
# fmt: off
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
"John Smith picked up the red ball and he threw it away.",
|
||||
{
|
||||
"spans": {
|
||||
f"{DEFAULT_CLUSTER_PREFIX}_1": [
|
||||
(0, 11, "MENTION"), # John Smith
|
||||
(38, 41, "MENTION"), # he
|
||||
|
||||
],
|
||||
f"{DEFAULT_CLUSTER_PREFIX}_2": [
|
||||
(25, 33, "MENTION"), # red ball
|
||||
(47, 50, "MENTION"), # it
|
||||
],
|
||||
f"coref_head_clusters_1": [
|
||||
(5, 11, "MENTION"), # Smith
|
||||
(38, 41, "MENTION"), # he
|
||||
|
||||
],
|
||||
f"coref_head_clusters_2": [
|
||||
(29, 33, "MENTION"), # red ball
|
||||
(47, 50, "MENTION"), # it
|
||||
]
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nlp():
|
||||
return English()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def snlp():
|
||||
en = English()
|
||||
en.add_pipe("sentencizer")
|
||||
return en
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_add_pipe(nlp):
|
||||
nlp.add_pipe("span_predictor")
|
||||
assert nlp.pipe_names == ["span_predictor"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_not_initialized(nlp):
|
||||
nlp.add_pipe("span_predictor")
|
||||
text = "She gave me her pen."
|
||||
with pytest.raises(ValueError, match="E109"):
|
||||
nlp(text)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_span_predictor_serialization(nlp):
|
||||
# Test that the span predictor component can be serialized
|
||||
nlp.add_pipe("span_predictor", last=True)
|
||||
nlp.initialize()
|
||||
assert nlp.pipe_names == ["span_predictor"]
|
||||
text = "She gave me her pen."
|
||||
doc = nlp(text)
|
||||
|
||||
with make_tempdir() as tmp_dir:
|
||||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = spacy.load(tmp_dir)
|
||||
assert nlp2.pipe_names == ["span_predictor"]
|
||||
doc2 = nlp2(text)
|
||||
|
||||
assert spans2ints(doc) == spans2ints(doc2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_overfitting_IO(nlp):
|
||||
# Simple test to try and quickly overfit - ensuring the ML models work correctly
|
||||
train_examples = []
|
||||
for text, annot in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
||||
|
||||
nlp.add_pipe("span_predictor")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
||||
# Needs ~12 epochs to converge
|
||||
for i in range(15):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
doc = nlp(test_text)
|
||||
|
||||
# test the trained model
|
||||
doc = nlp(test_text)
|
||||
|
||||
# Also test the results are still the same after IO
|
||||
with make_tempdir() as tmp_dir:
|
||||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
doc2 = nlp2(test_text)
|
||||
|
||||
# Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
|
||||
texts = [
|
||||
test_text,
|
||||
"I noticed many friends around me",
|
||||
"They received it. They received the SMS.",
|
||||
]
|
||||
docs1 = list(nlp.pipe(texts))
|
||||
docs2 = list(nlp.pipe(texts))
|
||||
docs3 = [nlp(text) for text in texts]
|
||||
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
|
||||
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
|
Loading…
Reference in New Issue
Block a user