spaCy/spacy/tests/pipeline/test_span_finder.py
2023-06-01 10:19:22 +00:00

265 lines
8.1 KiB
Python

import pytest
from thinc.api import Config
from thinc.types import Ragged
from spacy.language import Language
from spacy.lang.en import English
from spacy.pipeline.span_finder import span_finder_default_config
from spacy.tokens import Doc
from spacy.training import Example
from spacy import util
from spacy.util import registry
from spacy.util import fix_random_seed, make_tempdir
# Key under which spans are stored in ``doc.spans`` throughout these tests.
SPANS_KEY = "pytest"

# (text, annotations) pairs; each span is a (start_char, end_char) offset pair
# into the text.
TRAIN_DATA = [
    # (7, 17) covers "Shaka Khan"
    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
    # (7, 13) covers "London", (18, 24) covers "Berlin"
    ("I like London and Berlin.", {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}}),
]

# Variant with an additional span (7, 24) = "London and Berlin" that overlaps
# the two shorter spans, plus an empty text with no spans.
TRAIN_DATA_OVERLAPPING = [
    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
    (
        "I like London and Berlin",
        {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
    ),
    ("", {"spans": {SPANS_KEY: []}}),
]
def make_examples(nlp, data=TRAIN_DATA):
    """Turn ``(text, annotations)`` pairs into a list of ``Example`` objects.

    nlp: object providing ``make_doc`` to build a doc from each text.
    data: iterable of entries whose item 0 is the text and item 1 the
        annotation dict; defaults to ``TRAIN_DATA``.
    RETURNS: one example per entry, in input order, each built with
        ``Example.from_dict``.
    """
    return [
        Example.from_dict(nlp.make_doc(entry[0]), entry[1]) for entry in data
    ]
@pytest.mark.parametrize(
    "tokens_predicted, tokens_reference, reference_truths",
    [
        (
            ["Mon", ".", "-", "June", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (0, 0), (1, 1), (0, 0)],
        ),
        (
            ["Mon.", "-", "J", "une", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (1, 0), (0, 1), (0, 0)],
        ),
        (
            ["Mon", ".", "-", "June", "16"],
            ["Mon.", "-", "June", "1", "6"],
            [(0, 0), (0, 0), (0, 0), (1, 1), (0, 0)],
        ),
        # Predicted text contains a space ("e 16"), so it differs from the
        # reference text and the error branch of the test is exercised.
        (
            ["Mon.", "-J", "un", "e 16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (0, 0), (0, 0)],
        ),
        (
            ["Mon.-June", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 1), (0, 0)],
        ),
        (
            ["Mon.-", "June", "16"],
            ["Mon.", "-", "J", "une", "16"],
            [(0, 0), (1, 1), (0, 0)],
        ),
        # Predicted text contains a space ("June 16"), so the texts differ
        # here as well.
        (
            ["Mon.-", "June 16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (1, 0)],
        ),
    ],
)
def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_truths):
    """Check the truth scores aligned from reference to predicted tokens.

    Both docs are built without whitespace between tokens. The reference doc
    carries a single span over characters 5-9 ("June" in "Mon.-June16").
    When the two docs have identical text, the aligned scores must have one
    row per predicted token and equal ``reference_truths``; when the texts
    differ, ``_get_aligned_truth_scores`` is expected to raise a ValueError.
    """
    nlp = Language()
    pred_doc = Doc(
        nlp.vocab,
        words=tokens_predicted,
        spaces=[False] * len(tokens_predicted),
    )
    ref_doc = Doc(
        nlp.vocab,
        words=tokens_reference,
        spaces=[False] * len(tokens_reference),
    )
    example = Example(pred_doc, ref_doc)
    gold_span = example.reference.char_span(5, 9)
    example.reference.spans[SPANS_KEY] = [gold_span]

    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()
    ops = span_finder.model.ops

    if pred_doc.text == ref_doc.text:
        # The second return value (the masks) is not checked by this test.
        truth_scores, _ = span_finder._get_aligned_truth_scores([example], ops)
        assert len(truth_scores) == len(tokens_predicted)
        expected = ops.xp.asarray(reference_truths)
        ops.xp.testing.assert_array_equal(truth_scores, expected)
    else:
        with pytest.raises(
            ValueError, match="must match between reference and predicted"
        ):
            span_finder._get_aligned_truth_scores([example], ops)
def test_span_finder_model():
    """Check the output size of the default span finder model.

    The model resolved from ``span_finder_default_config`` must yield one
    prediction row per token across all input docs, each row of length 2.
    """
    nlp = Language()
    first = nlp("This is an example.")
    second = nlp("This is the second example.")
    first.spans[SPANS_KEY] = [first[3:4]]
    second.spans[SPANS_KEY] = [second[3:5]]
    docs = [first, second]

    expected_rows = sum(len(doc) for doc in docs)

    config = Config().from_str(span_finder_default_config).interpolate()
    model = registry.resolve(config)["model"]
    model.initialize(X=docs)

    predictions = model.predict(docs)
    assert len(predictions) == expected_rows
    assert len(predictions[0]) == 2
def test_span_finder_component():
    """Check that piping docs through the component fills ``SPANS_KEY``.

    The component is added with ``spans_key=SPANS_KEY``, initialized without
    training data, and run over two docs via ``pipe``.
    """
    nlp = Language()
    first = nlp("This is an example.")
    second = nlp("This is the second example.")
    first.spans[SPANS_KEY] = [first[3:4]]
    second.spans[SPANS_KEY] = [second[3:5]]

    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()

    processed = list(span_finder.pipe([first, second]))
    # TODO: update hard-coded name
    assert SPANS_KEY in processed[0].spans
@pytest.mark.parametrize(
    "min_length, max_length, span_count",
    [(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)],
)
def test_set_annotations_span_lengths(min_length, max_length, span_count):
    """Check that ``set_annotations`` honours the configured length limits.

    With ``min_length == max_length == 0`` adding the pipe must raise a
    ValueError. Otherwise fixed start/end scores are fed to
    ``set_annotations`` and the resulting spans must number ``span_count``
    and all fall within ``[min_length, max_length]`` (a bound of ``None`` is
    checked as 1 for the minimum and as unbounded for the maximum).
    """
    nlp = Language()
    doc = nlp("Me and Jenny goes together like peas and carrots.")
    pipe_config = {
        "max_length": max_length,
        "min_length": min_length,
        "spans_key": SPANS_KEY,
    }

    if min_length == 0 and max_length == 0:
        with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
            nlp.add_pipe("span_finder", config=pipe_config)
        return

    span_finder = nlp.add_pipe("span_finder", config=pipe_config)
    nlp.initialize()

    # One (start, end) pair per token.
    # Starts [Me, Jenny, peas]
    # Ends [Jenny, peas, carrots]
    scores = [
        (1, 0),
        (0, 0),
        (1, 1),
        (0, 0),
        (0, 0),
        (0, 0),
        (1, 1),
        (0, 0),
        (0, 1),
        (0, 0),
    ]
    span_finder.set_annotations([doc], scores)

    found = doc.spans[SPANS_KEY]
    assert found
    assert len(found) == span_count

    # The (0, 0) case never reaches this point: it returns above, since the
    # bounds check would fail with max_length set to 0.
    upper = float("inf") if max_length is None else max_length
    lower = 1 if min_length is None else min_length
    assert all(lower <= len(span) <= upper for span in doc.spans[SPANS_KEY])
def test_span_finder_suggester():
    """Check that the span finder suggester returns the predicted spans.

    The component's predictions are written to ``doc.spans[SPANS_KEY]`` and
    the ``spacy.span_finder_suggester.v1`` suggester is built with that key as
    ``candidates_key``. The suggester output must be a ``Ragged`` with one row
    per stored span, each row of length 2.
    """
    nlp = Language()
    docs = [nlp("This is an example."), nlp("This is the second example.")]
    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]

    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()
    span_finder.set_annotations(docs, span_finder.predict(docs))

    make_suggester = registry.misc.get("spacy.span_finder_suggester.v1")
    suggester = make_suggester(candidates_key=SPANS_KEY)
    candidates = suggester(docs)

    span_total = sum(len(doc.spans[SPANS_KEY]) for doc in docs)
    assert span_total == len(candidates.dataXd)
    # isinstance is the idiomatic type check and also accepts subclasses.
    assert isinstance(candidates, Ragged)
    assert len(candidates.dataXd[0]) == 2
def test_overfitting_IO():
    """Overfit the span_finder component and verify it survives a round trip.

    Trains on ``TRAIN_DATA`` for 50 updates, then checks the loss, the spans
    predicted for a known sentence, the same predictions after saving to and
    loading from disk, the evaluation score, and a single-token input.
    """
    # Simple test to try and quickly overfit the span_finder component -
    # ensuring the ML models work correctly
    fix_random_seed(0)
    nlp = English()
    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    train_examples = make_examples(nlp)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    assert span_finder.model.get_dim("nO") == 2

    for _ in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["span_finder"] < 0.001

    # test the trained model
    expected_texts = {"London", "Berlin", "London and Berlin"}
    test_text = "I like London and Berlin"
    doc = nlp(test_text)
    spans = doc.spans[span_finder.spans_key]
    assert len(spans) == 3
    assert {span.text for span in spans} == expected_texts

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)
        spans2 = doc2.spans[span_finder.spans_key]
        assert len(spans2) == 3
        assert {span.text for span in spans2} == expected_texts

    # Test scoring
    scores = nlp.evaluate(train_examples)
    score_key = f"span_finder_{span_finder.spans_key}_f"
    assert score_key in scores
    # XXX It's not perfect 1.0 F1 because we want it to overgenerate for now.
    assert scores[score_key] == 0.4

    # also test that the span_finder works for just a single entity in a sentence
    doc = nlp("London")
    assert len(doc.spans[span_finder.spans_key]) == 1