mirror of
https://github.com/explosion/spaCy.git
synced 2025-10-22 03:34:15 +03:00
* initial coref_er pipe * matcher more flexible * base coref component without actual model * initial setup of coref_er.score * rename to include_label * preliminary score_clusters method * apply scoring in coref component * IO fix * return None loss for now * rename to CoreferenceResolver * some preliminary unit tests * use registry as callable
181 lines
6.2 KiB
Python
181 lines
6.2 KiB
Python
import pytest
|
|
import spacy
|
|
from spacy.matcher import PhraseMatcher
|
|
from spacy.training import Example
|
|
from spacy.lang.en import English
|
|
from spacy.tests.util import make_tempdir
|
|
from spacy.tokens import Doc
|
|
from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
|
|
from spacy.pipeline.coref_er import DEFAULT_MENTIONS
|
|
|
|
|
|
# fmt: off
|
|
TRAIN_DATA = [
|
|
(
|
|
"John Smith told Laura that he was running late and asked her whether she could pick up their kids.",
|
|
{
|
|
"spans": {
|
|
DEFAULT_MENTIONS: [
|
|
(0, 10, "MENTION"),
|
|
(16, 21, "MENTION"),
|
|
(27, 29, "MENTION"),
|
|
(57, 60, "MENTION"),
|
|
(69, 72, "MENTION"),
|
|
(87, 92, "MENTION"),
|
|
(87, 97, "MENTION"),
|
|
],
|
|
f"{DEFAULT_CLUSTERS_PREFIX}_1": [
|
|
(0, 10, "MENTION"), # John
|
|
(27, 29, "MENTION"),
|
|
(87, 92, "MENTION"), # 'their' refers to John and Laur
|
|
],
|
|
f"{DEFAULT_CLUSTERS_PREFIX}_2": [
|
|
(16, 21, "MENTION"), # Laura
|
|
(57, 60, "MENTION"),
|
|
(69, 72, "MENTION"),
|
|
(87, 92, "MENTION"), # 'their' refers to John and Laura
|
|
],
|
|
}
|
|
},
|
|
),
|
|
(
|
|
"Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.",
|
|
{
|
|
"spans": {
|
|
DEFAULT_MENTIONS: [
|
|
(5, 6, "MENTION"),
|
|
(40, 42, "MENTION"),
|
|
(52, 54, "MENTION"),
|
|
(95, 103, "MENTION"),
|
|
],
|
|
f"{DEFAULT_CLUSTERS_PREFIX}_1": [
|
|
(5, 6, "MENTION"), # I
|
|
(40, 42, "MENTION"),
|
|
|
|
],
|
|
f"{DEFAULT_CLUSTERS_PREFIX}_2": [
|
|
(52, 54, "MENTION"), # SMS
|
|
(95, 103, "MENTION"),
|
|
]
|
|
}
|
|
},
|
|
),
|
|
]
|
|
# fmt: on
|
|
|
|
|
|
@pytest.fixture
|
|
def nlp():
|
|
return English()
|
|
|
|
|
|
@pytest.fixture
|
|
def examples(nlp):
|
|
examples = []
|
|
for text, annot in TRAIN_DATA:
|
|
# eg = Example.from_dict(nlp.make_doc(text), annot)
|
|
# if PR #7197 is merged, replace below with above line
|
|
ref_doc = nlp.make_doc(text)
|
|
for key, span_list in annot["spans"].items():
|
|
spans = []
|
|
for span_tuple in span_list:
|
|
start_char = span_tuple[0]
|
|
end_char = span_tuple[1]
|
|
label = span_tuple[2]
|
|
span = ref_doc.char_span(start_char, end_char, label=label)
|
|
spans.append(span)
|
|
ref_doc.spans[key] = spans
|
|
eg = Example(nlp.make_doc(text), ref_doc)
|
|
examples.append(eg)
|
|
return examples
|
|
|
|
|
|
def test_coref_er_no_POS(nlp):
|
|
doc = nlp("The police woman talked to him.")
|
|
coref_er = nlp.add_pipe("coref_er", last=True)
|
|
with pytest.raises(ValueError):
|
|
coref_er(doc)
|
|
|
|
|
|
def test_coref_er_with_POS(nlp):
|
|
words = ["The", "police", "woman", "talked", "to", "him", "."]
|
|
pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
|
|
doc = Doc(nlp.vocab, words=words, pos=pos)
|
|
coref_er = nlp.add_pipe("coref_er", last=True)
|
|
coref_er(doc)
|
|
assert len(doc.spans[coref_er.span_mentions]) == 1
|
|
mention = doc.spans[coref_er.span_mentions][0]
|
|
assert (mention.text, mention.start, mention.end) == ("him", 5, 6)
|
|
|
|
|
|
def test_coref_er_custom_POS(nlp):
|
|
words = ["The", "police", "woman", "talked", "to", "him", "."]
|
|
pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
|
|
doc = Doc(nlp.vocab, words=words, pos=pos)
|
|
config = {"matcher_key": "POS", "matcher_values": ["NOUN"]}
|
|
coref_er = nlp.add_pipe("coref_er", last=True, config=config)
|
|
coref_er(doc)
|
|
assert len(doc.spans[coref_er.span_mentions]) == 1
|
|
mention = doc.spans[coref_er.span_mentions][0]
|
|
assert (mention.text, mention.start, mention.end) == ("police woman", 1, 3)
|
|
|
|
|
|
def test_coref_clusters(nlp, examples):
|
|
coref_er = nlp.add_pipe("coref_er", last=True)
|
|
coref = nlp.add_pipe("coref", last=True)
|
|
coref.initialize(lambda: examples)
|
|
words = ["Laura", "walked", "her", "dog", "."]
|
|
pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
|
|
doc = Doc(nlp.vocab, words=words, pos=pos)
|
|
coref_er(doc)
|
|
coref(doc)
|
|
assert len(doc.spans[coref_er.span_mentions]) > 0
|
|
found_clusters = 0
|
|
for name, spans in doc.spans.items():
|
|
if name.startswith(coref.span_cluster_prefix):
|
|
found_clusters += 1
|
|
assert found_clusters > 0
|
|
|
|
|
|
def test_coref_er_score(nlp, examples):
|
|
config = {"matcher_key": "POS", "matcher_values": []}
|
|
coref_er = nlp.add_pipe("coref_er", last=True, config=config)
|
|
coref = nlp.add_pipe("coref", last=True)
|
|
coref.initialize(lambda: examples)
|
|
mentions_key = coref_er.span_mentions
|
|
cluster_prefix_key = coref.span_cluster_prefix
|
|
matcher = PhraseMatcher(nlp.vocab)
|
|
terms_1 = ["Laura", "her", "she"]
|
|
terms_2 = ["it", "this SMS"]
|
|
matcher.add("A", [nlp.make_doc(text) for text in terms_1])
|
|
matcher.add("B", [nlp.make_doc(text) for text in terms_2])
|
|
for eg in examples:
|
|
pred = eg.predicted
|
|
matches = matcher(pred, as_spans=True)
|
|
pred.set_ents(matches)
|
|
coref_er(pred)
|
|
coref(pred)
|
|
eg.predicted = pred
|
|
# TODO: if #7209 is merged, experiment with 'include_label'
|
|
scores = coref_er.score([eg])
|
|
assert f"{mentions_key}_f" in scores
|
|
scores = coref.score([eg])
|
|
assert f"{cluster_prefix_key}_f" in scores
|
|
|
|
|
|
def test_coref_serialization(nlp):
|
|
# Test that the coref component can be serialized
|
|
config_er = {"matcher_key": "TAG", "matcher_values": ["NN"]}
|
|
nlp.add_pipe("coref_er", last=True, config=config_er)
|
|
nlp.add_pipe("coref", last=True)
|
|
assert "coref_er" in nlp.pipe_names
|
|
assert "coref" in nlp.pipe_names
|
|
|
|
with make_tempdir() as tmp_dir:
|
|
nlp.to_disk(tmp_dir)
|
|
nlp2 = spacy.load(tmp_dir)
|
|
assert "coref_er" in nlp2.pipe_names
|
|
assert "coref" in nlp2.pipe_names
|
|
coref_er_2 = nlp2.get_pipe("coref_er")
|
|
assert coref_er_2.matcher_key == "TAG"
|