mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
c003aac29a
* span finder integrated into spacy from experimental * black * isort * black * default spankey constant * black * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * rename * rename * max_length and min_length as Optional[int] and strict checking * black * mypy fix for integer type infinity * revert line order * implement all comparison operators for inf int * avoid two for loops over all docs by not precomputing * interleave thresholding with span creation * black * revert to not interleaving (relized its faster) * black * Update spacy/errors.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * update dosctring * enforce that the gold and predicted documents have the same text * new error for ensuring reference and predicted texts are the same * remove todo * adjust test * black * handle misaligned tokenization * return correct variable * failing overfit test * only use a single spans_key like in spancat * black * remove debug lines * typo * remove comment * remove near duplicate reduntant method * use the 'spans_key' variable name everywhere * Update spacy/pipeline/span_finder.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * flaky test fix suggestion, hand set bias terms * only test suggester and test result exhaustively * make it clear that the span_finder_suggester is more general (not specific to span_finder) * Update spacy/tests/pipeline/test_span_finder.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Apply suggestions from code review * remove question comment * move preset_spans_suggester test to spancat tests * Add docs and unify default configs for spancat and span finder * Add `allow_overlap=True` to span finder scorer * Fix offset bug in set_annotations * Ignore labels in span finder scorer * Format * Add span_finder to quickstart template * Move settings to self.cfg, store min/max unset as None * Remove debugging * Update docstrings and docs * Update spacy/pipeline/span_finder.py 
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Fix imports --------- Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
243 lines
7.3 KiB
Python
243 lines
7.3 KiB
Python
import pytest
|
|
from thinc.api import Config
|
|
|
|
from spacy.language import Language
|
|
from spacy.lang.en import English
|
|
from spacy.pipeline.span_finder import span_finder_default_config
|
|
from spacy.tokens import Doc
|
|
from spacy.training import Example
|
|
from spacy import util
|
|
from spacy.util import registry
|
|
from spacy.util import fix_random_seed, make_tempdir
|
|
|
|
|
|
# Key used throughout this module to read and write spans in ``doc.spans``.
SPANS_KEY = "pytest"

# (text, annotations) pairs. Each span is a (start, end) pair of character
# offsets into the text; the covered substring is noted next to each entry.
TRAIN_DATA = [
    # "Shaka Khan"
    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
    # "London", "Berlin"
    ("I like London and Berlin.", {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}}),
]

# Same layout as TRAIN_DATA, but the second text also carries a span that
# overlaps the other two, and an empty text with no spans is included.
TRAIN_DATA_OVERLAPPING = [
    # "Shaka Khan"
    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
    # "London", "Berlin", "London and Berlin"
    (
        "I like London and Berlin",
        {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
    ),
    ("", {"spans": {SPANS_KEY: []}}),
]
|
|
|
|
|
|
def make_examples(nlp, data=TRAIN_DATA):
    """Return one ``Example`` per entry of ``data``.

    For every entry, element 0 (the text) is turned into a Doc with
    ``nlp.make_doc`` and paired with element 1 (the annotation dict) through
    ``Example.from_dict``. The resulting list keeps the order of ``data``.
    """
    return [Example.from_dict(nlp.make_doc(item[0]), item[1]) for item in data]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "tokens_predicted, tokens_reference, reference_truths",
    [
        # Predicted splits "Mon." in two; "June" is a single token on both sides.
        (
            ["Mon", ".", "-", "June", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (0, 0), (1, 1), (0, 0)],
        ),
        # Predicted splits "June": start lands on "J", end lands on "une".
        (
            ["Mon.", "-", "J", "une", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (1, 0), (0, 1), (0, 0)],
        ),
        # Both sides split differently, but "June" is one token in each.
        (
            ["Mon", ".", "-", "June", "16"],
            ["Mon.", "-", "June", "1", "6"],
            [(0, 0), (0, 0), (0, 0), (1, 1), (0, 0)],
        ),
        # Predicted text contains a space the reference lacks, so the texts
        # differ and only the ValueError branch of the test runs.
        (
            ["Mon.", "-J", "un", "e 16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (0, 0), (0, 0)],
        ),
        # Predicted merges "Mon.", "-" and "June" into one token.
        (
            ["Mon.-June", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 1), (0, 0)],
        ),
        # Reference splits "June"; predicted keeps it whole.
        (
            ["Mon.-", "June", "16"],
            ["Mon.", "-", "J", "une", "16"],
            [(0, 0), (1, 1), (0, 0)],
        ),
        # Predicted text contains a space the reference lacks, so the texts
        # differ and only the ValueError branch of the test runs.
        (
            ["Mon.-", "June 16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (1, 0)],
        ),
    ],
)
def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_truths):
    """Check truth scores when predicted and reference tokenization differ.

    The reference doc gets a single gold span covering characters 5-9. When
    both docs render to the same text, ``_get_aligned_truth_scores`` must
    return one score row per predicted token equal to ``reference_truths``;
    when the texts differ it must raise ``ValueError``.
    """
    nlp = Language()

    def build_doc(words):
        # No trailing whitespace after any token, so the doc text is the
        # plain concatenation of the words.
        return Doc(nlp.vocab, words=words, spaces=[False] * len(words))

    predicted = build_doc(tokens_predicted)
    reference = build_doc(tokens_reference)
    example = Example(predicted, reference)
    gold_span = example.reference.char_span(5, 9)
    example.reference.spans[SPANS_KEY] = [gold_span]

    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()
    ops = span_finder.model.ops

    if predicted.text == reference.text:
        truth_scores, _masks = span_finder._get_aligned_truth_scores([example], ops)
        assert len(truth_scores) == len(tokens_predicted)
        expected = ops.xp.asarray(reference_truths)
        ops.xp.testing.assert_array_equal(truth_scores, expected)
    else:
        with pytest.raises(
            ValueError, match="must match between reference and predicted"
        ):
            span_finder._get_aligned_truth_scores([example], ops)
|
|
|
|
|
|
def test_span_finder_model():
    """Check the output shape of the default span finder model.

    The model resolved from ``span_finder_default_config`` must predict one
    row per token over all input docs, and each row must hold two values.
    """
    nlp = Language()
    texts = ["This is an example.", "This is the second example."]
    docs = [nlp(text) for text in texts]
    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]

    token_count = sum(len(doc) for doc in docs)

    config = Config().from_str(span_finder_default_config).interpolate()
    resolved = registry.resolve(config)
    model = resolved["model"]
    model.initialize(X=docs)

    predictions = model.predict(docs)
    assert len(predictions) == token_count
    assert len(predictions[0]) == 2
|
|
|
|
|
|
def test_span_finder_component():
    """Run the span_finder pipe over two docs and check the spans key exists.

    Both docs are given a span under ``SPANS_KEY`` before the component is
    added; after piping them through the initialized component, the key must
    be present on the first doc.
    """
    nlp = Language()
    first_doc = nlp("This is an example.")
    second_doc = nlp("This is the second example.")
    first_doc.spans[SPANS_KEY] = [first_doc[3:4]]
    second_doc.spans[SPANS_KEY] = [second_doc[3:5]]

    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()

    processed = list(span_finder.pipe([first_doc, second_doc]))
    assert SPANS_KEY in processed[0].spans
|
|
|
|
|
|
@pytest.mark.parametrize(
    "min_length, max_length, span_count",
    [(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)],
)
def test_set_annotations_span_lengths(min_length, max_length, span_count):
    """Check that ``set_annotations`` honours ``min_length``/``max_length``.

    Hand-written per-token scores are fed to ``set_annotations``; the number
    of spans written under ``SPANS_KEY`` must equal ``span_count`` and every
    span length must fall within the configured bounds. The ``(0, 0)`` case
    is expected to be rejected with a ``ValueError`` when the component is
    added to the pipeline.
    """
    nlp = Language()
    doc = nlp("Me and Jenny goes together like peas and carrots.")
    # Built once so the failing and the passing path use the same settings.
    config = {
        "max_length": max_length,
        "min_length": min_length,
        "spans_key": SPANS_KEY,
    }

    if min_length == 0 and max_length == 0:
        # Return early: no component exists after the error, and the length
        # assertions further down would fail with max_length set to 0.
        with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
            nlp.add_pipe("span_finder", config=config)
        return

    span_finder = nlp.add_pipe("span_finder", config=config)
    nlp.initialize()

    # One (start, end) score pair per token of the doc.
    # Starts [Me, Jenny, peas]
    # Ends [Jenny, peas, carrots]
    scores = [
        (1, 0),
        (0, 0),
        (1, 1),
        (0, 0),
        (0, 0),
        (0, 0),
        (1, 1),
        (0, 0),
        (0, 1),
        (0, 0),
    ]
    span_finder.set_annotations([doc], scores)

    assert doc.spans[SPANS_KEY]
    assert len(doc.spans[SPANS_KEY]) == span_count

    # Replace unset bounds with concrete values so the range check below
    # can be written as a single chained comparison.
    if max_length is None:
        max_length = float("inf")
    if min_length is None:
        min_length = 1

    assert all(min_length <= len(span) <= max_length for span in doc.spans[SPANS_KEY])
|
|
|
|
|
|
def test_overfitting_IO():
    """Overfit span_finder on ``TRAIN_DATA`` and check predictions, IO and scores.

    Trains for 50 updates with a fixed seed, then verifies that the loss is
    tiny, the expected spans are predicted for a held-out sentence, the same
    spans are predicted after a save/load round trip through a temporary
    directory, and the evaluation scores contain the expected precision and
    recall values.
    """
    # Simple test to try and quickly overfit the span_finder component -
    # ensuring the ML models work correctly
    fix_random_seed(0)
    nlp = English()
    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    train_examples = make_examples(nlp)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    assert span_finder.model.get_dim("nO") == 2

    for _ in range(50):
        # Fresh dict each step, so only the loss of the last update is checked.
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["span_finder"] < 0.001

    # test the trained model
    test_text = "I like London and Berlin"
    expected_texts = {"London", "Berlin", "London and Berlin"}
    doc = nlp(test_text)
    spans = doc.spans[SPANS_KEY]
    assert len(spans) == 3
    assert {span.text for span in spans} == expected_texts

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)
        spans2 = doc2.spans[SPANS_KEY]
        assert len(spans2) == 3
        assert {span.text for span in spans2} == expected_texts

    # Test scoring
    scores = nlp.evaluate(train_examples)
    assert f"span_finder_{SPANS_KEY}_f" in scores
    # It's not perfect 1.0 F1 because it's designed to overgenerate for now.
    assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75
    assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0

    # also test that the span_finder works for just a single entity in a sentence
    doc = nlp("London")
    assert len(doc.spans[SPANS_KEY]) == 1
|