Native coref component (#7243)

* initial coref_er pipe

* matcher more flexible

* base coref component without actual model

* initial setup of coref_er.score

* rename to include_label

* preliminary score_clusters method

* apply scoring in coref component

* IO fix

* return None loss for now

* rename to CoreferenceResolver

* some preliminary unit tests

* use registry as callable
Author: Sofie Van Landeghem, 2021-03-03 13:50:14 +01:00 (committed by GitHub)
Parent: dd99872bb0
Commit: e0c45c669a
10 changed files with 829 additions and 11 deletions
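
Before the diff, a quick orientation: the two new components are meant to be chained, with coref_er writing candidate mentions to doc.spans and coref grouping them into clusters. A minimal usage sketch, closely following test_coref_clusters further down (the blank English pipeline and the hand-set POS tags are illustrative assumptions, not part of this commit):

import spacy
from spacy.tokens import Doc

nlp = spacy.blank("en")
coref_er = nlp.add_pipe("coref_er")  # rule-based mention candidates
coref = nlp.add_pipe("coref")        # trainable resolver (model is still a stub)

# The default coref_er matcher works on POS tags, so set them manually here.
words = ["Laura", "walked", "her", "dog", "."]
pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
doc = Doc(nlp.vocab, words=words, pos=pos)
doc = coref_er(doc)
doc = coref(doc)

print(doc.spans["coref_mentions"])  # candidate mentions found by coref_er
print([key for key in doc.spans if key.startswith("coref_clusters")])  # cluster keys written by coref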


@ -1,3 +1,4 @@
from .coref import *
from .entity_linker import * # noqa
from .multi_task import * # noqa
from .parser import * # noqa

spacy/ml/models/coref.py (new file, +18)

@ -0,0 +1,18 @@
from typing import List
from thinc.api import Model
from thinc.types import Floats2d
from ...util import registry
from ...tokens import Doc
@registry.architectures("spacy.Coref.v0")
def build_coref_model(
tok2vec: Model[List[Doc], List[Floats2d]]
) -> Model:
"""Build a coref resolution model, using a provided token-to-vector component.
TODO.
tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
"""
return tok2vec
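
The architecture above is only a pass-through stub, but it is already registered under "spacy.Coref.v0", so it can be looked up by name like any other spaCy architecture. A small sketch (illustration only, not part of the diff):

from spacy.util import registry

# Fetch the factory function registered by the decorator above.
build_coref = registry.architectures.get("spacy.Coref.v0")
# Calling it with a tok2vec Model currently just returns that same layer,
# since the real coref network is still marked TODO.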


@ -1,4 +1,6 @@
from .attributeruler import AttributeRuler
from .coref import CoreferenceResolver
from .coref_er import CorefEntityRecognizer
from .dep_parser import DependencyParser
from .entity_linker import EntityLinker
from .ner import EntityRecognizer

spacy/pipeline/coref.py (new file, +288)

@ -0,0 +1,288 @@
from typing import Iterable, Tuple, Optional, Dict, Callable, Any
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
from itertools import islice
from .trainable_pipe import TrainablePipe
from .coref_er import DEFAULT_MENTIONS
from ..language import Language
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from ..tokens import Doc
from ..vocab import Vocab
default_config = """
[model]
@architectures = "spacy.Coref.v0"
[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"
[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false
[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
"""
DEFAULT_MODEL = Config().from_str(default_config)["model"]
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
@Language.factory(
"coref",
assigns=[f"doc.spans"],
requires=["doc.spans"],
default_config={
"model": DEFAULT_MODEL,
"span_mentions": DEFAULT_MENTIONS,
"span_cluster_prefix": DEFAULT_CLUSTERS_PREFIX,
},
default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
)
def make_coref(
nlp: Language,
name: str,
model,
span_mentions: str,
span_cluster_prefix: str,
) -> "CoreferenceResolver":
"""Create a CoreferenceResolver component. TODO
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts ...
threshold (float): Cutoff to consider a prediction "positive".
"""
return CoreferenceResolver(
nlp.vocab,
model,
name,
span_mentions=span_mentions,
span_cluster_prefix=span_cluster_prefix,
)
class CoreferenceResolver(TrainablePipe):
"""Pipeline component for coreference resolution.
DOCS: https://spacy.io/api/coref (TODO)
"""
def __init__(
self,
vocab: Vocab,
model: Model,
name: str = "coref",
*,
span_mentions: str,
span_cluster_prefix: str,
) -> None:
"""Initialize a coreference resolution component.
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
span_mentions (str): Key in doc.spans where the candidate coref mentions
are stored.
span_cluster_prefix (str): Prefix for the key in doc.spans to store the
coref clusters in.
DOCS: https://spacy.io/api/coref#init (TODO)
"""
self.vocab = vocab
self.model = model
self.name = name
self.span_mentions = span_mentions
self.span_cluster_prefix = span_cluster_prefix
self._rehearsal_model = None
self.cfg = {}
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
TODO: write actual algorithm
docs (Iterable[Doc]): The documents to predict.
RETURNS: The model's predictions for each document.
DOCS: https://spacy.io/api/coref#predict (TODO)
"""
clusters_by_doc = []
for i, doc in enumerate(docs):
clusters = []
for span in doc.spans[self.span_mentions]:
clusters.append([span])
clusters_by_doc.append(clusters)
return clusters_by_doc
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
clusters_by_doc: The span clusters, produced by CoreferenceResolver.predict.
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
"""
if len(docs) != len(clusters_by_doc):
raise ValueError("Found coref clusters incompatible with the "
"documents provided to the 'coref' component. "
"This is likely a bug in spaCy.")
for doc, clusters in zip(docs, clusters_by_doc):
index = 0
for cluster in clusters:
key = self.span_cluster_prefix + str(index)
if key in doc.spans:
raise ValueError(f"Couldn't store the results of {self.name}, as the key "
f"{key} already exists in 'doc.spans'.")
doc.spans[key] = cluster
index += 1
def update(
self,
examples: Iterable[Example],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/coref#update (TODO)
"""
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "CoreferenceResolver.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return losses
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
# TODO below
# loss, d_scores = self.get_loss(examples, scores)
# bp_scores(d_scores)
if sgd is not None:
self.finish_update(sgd)
# losses[self.name] += loss
return losses
def rehearse(
self,
examples: Iterable[Example],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
teach the current model to make predictions similar to an initial model,
to try to address the "catastrophic forgetting" problem. This feature is
experimental.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/coref#rehearse (TODO)
"""
if losses is not None:
losses.setdefault(self.name, 0.0)
if self._rehearsal_model is None:
return losses
validate_examples(examples, "CoreferenceResolver.rehearse")
docs = [eg.predicted for eg in examples]
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return losses
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update(docs)
# TODO below
target = self._rehearsal_model(examples)
gradient = scores - target
bp_scores(gradient)
if sgd is not None:
self.finish_update(sgd)
if losses is not None:
losses[self.name] += (gradient ** 2).sum()
return losses
def add_label(self, label: str) -> int:
"""Technically this method should be implemented from TrainablePipe,
but it is not relevant for the coref component.
"""
raise NotImplementedError(
Errors.E931.format(
parent="CoreferenceResolver", method="add_label", name=self.name
)
)
def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
examples (Iterable[Example]): The batch of examples.
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/coref#get_loss (TODO)
"""
validate_examples(examples, "CoreferenceResolver.get_loss")
# TODO
return None
def initialize(
self,
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
) -> None:
"""Initialize the pipe for training, using a representative set
of data examples.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://spacy.io/api/coref#initialize (TODO)
"""
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
subbatch = list(islice(get_examples(), 10))
doc_sample = [eg.reference for eg in subbatch]
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample)
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples.
examples (Iterable[Example]): The examples to score.
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_clusters.
DOCS: https://spacy.io/api/coref#score (TODO)
"""
def clusters_getter(doc, span_key):
return [spans for name, spans in doc.spans.items() if name.startswith(span_key)]
validate_examples(examples, "CoreferenceResolver.score")
kwargs.setdefault("getter", clusters_getter)
kwargs.setdefault("attr", self.span_cluster_prefix)
kwargs.setdefault("include_label", False)
return Scorer.score_clusters(examples, **kwargs)
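
set_annotations stores every predicted cluster under its own key, built from span_cluster_prefix plus a running index, and score reads them back with a prefix match via clusters_getter. A tiny read-back helper in the same spirit (hypothetical, not part of the diff):

def get_clusters(doc, prefix="coref_clusters"):
    # Collect all span groups whose key starts with the cluster prefix,
    # e.g. "coref_clusters0", "coref_clusters1", ...
    return [spans for name, spans in doc.spans.items() if name.startswith(prefix)]

Note that set_annotations concatenates the index directly ("coref_clusters0"), while the training data in the tests uses an underscore ("coref_clusters_1"); a prefix match covers both.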

spacy/pipeline/coref_er.py (new file, +227)

@ -0,0 +1,227 @@
from typing import Optional, Union, Iterable, Callable, List, Dict, Any
from pathlib import Path
import srsly
from .pipe import Pipe
from ..scorer import Scorer
from ..training import Example
from ..language import Language
from ..tokens import Doc, Span, SpanGroup
from ..matcher import Matcher
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
DEFAULT_MENTIONS = "coref_mentions"
DEFAULT_MATCHER_KEY = "POS"
DEFAULT_MATCHER_VALUES = ["PROPN", "PRON"]
@Language.factory(
"coref_er",
assigns=[f"doc.spans"],
requires=["doc.ents", "token.ent_iob", "token.ent_type", "token.pos"],
default_config={
"span_mentions": DEFAULT_MENTIONS,
"matcher_key": DEFAULT_MATCHER_KEY,
"matcher_values": DEFAULT_MATCHER_VALUES,
},
default_score_weights={
"coref_mentions_f": None,
"coref_mentions_p": None,
"coref_mentions_r": 1.0, # the mentions data needs to be consistently annotated for precision rates to make sense
},
)
def make_coref_er(nlp: Language, name: str, span_mentions: str, matcher_key: str, matcher_values: List[str]):
return CorefEntityRecognizer(
nlp, name, span_mentions=span_mentions, matcher_key=matcher_key, matcher_values=matcher_values
)
class CorefEntityRecognizer(Pipe):
"""TODO.
DOCS: https://spacy.io/api/coref_er (TODO)
USAGE: https://spacy.io/usage (TODO)
"""
def __init__(
self,
nlp: Language,
name: str = "coref_er",
*,
span_mentions: str,
matcher_key: str,
matcher_values: List[str],
) -> None:
"""Initialize the entity recognizer for coreference mentions. TODO
nlp (Language): The shared nlp object.
name (str): Instance name of the current pipeline component. Typically
passed in automatically from the factory when the component is
added.
span_mentions (str): Key in doc.spans to store the coref mentions in.
matcher_key (str): Token attribute for the matcher to work on (e.g. "POS" or "TAG").
matcher_values (List[str]): Attribute values to match token sequences as
plausible coref mentions.
DOCS: https://spacy.io/api/coref_er#init (TODO)
"""
self.nlp = nlp
self.name = name
self.span_mentions = span_mentions
self.matcher_key = matcher_key
self.matcher_values = matcher_values
self.matcher = Matcher(nlp.vocab)
# TODO: allow to specify any matcher patterns instead?
for value in matcher_values:
self.matcher.add(
f"{value}_SEQ", [[{matcher_key: value, "OP": "+"}]], greedy="LONGEST"
)
@staticmethod
def _string_offset(span: Span):
return f"{span.start}-{span.end}"
def __call__(self, doc: Doc) -> Doc:
"""Find relevant coref mentions in the document and add them
to the doc's relevant span container.
doc (Doc): The Doc object in the pipeline.
RETURNS (Doc): The Doc with added entities, if available.
DOCS: https://spacy.io/api/coref_er#call (TODO)
"""
error_handler = self.get_error_handler()
try:
# Add NER
spans = list(doc.ents)
offsets = set()
offsets.update([self._string_offset(e) for e in doc.ents])
# pronouns and proper nouns
try:
matches = self.matcher(doc, as_spans=True)
except ValueError:
raise ValueError(f"Could not run the matcher for 'coref_er'. If {self.matcher_key} tags "
"are not available, change the 'matcher_key' in the config, "
"or set matcher_values to an empty list.")
spans.extend([m for m in matches if self._string_offset(m) not in offsets])
offsets.update([self._string_offset(m) for m in matches])
# noun_chunks - only if implemented and parsing information is available
try:
spans.extend(
[nc for nc in doc.noun_chunks if self._string_offset(nc) not in offsets]
)
offsets.update([self._string_offset(nc) for nc in doc.noun_chunks])
except (NotImplementedError, ValueError):
pass
self.set_annotations(doc, spans)
return doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def set_annotations(self, doc, spans):
"""Modify the document in place"""
group = SpanGroup(doc, name=self.span_mentions, spans=spans)
if self.span_mentions in doc.spans:
raise ValueError(f"Couldn't store the results of {self.name}, as the key "
f"{self.span_mentions} already exists in 'doc.spans'.")
doc.spans[self.span_mentions] = group
def initialize(
self,
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
):
"""Initialize the pipe for training.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://spacy.io/api/coref_er#initialize (TODO)
"""
pass
def from_bytes(
self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
) -> "CorefEntityRecognizer":
"""Load the coreference entity recognizer from a bytestring.
bytes_data (bytes): The bytestring to load.
RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.
DOCS: https://spacy.io/api/coref_er#from_bytes
"""
cfg = srsly.msgpack_loads(bytes_data)
self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
return self
def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the coreference entity recognizer to a bytestring.
RETURNS (bytes): The serialized component.
DOCS: https://spacy.io/api/coref_er#to_bytes (TODO)
"""
serial = {"span_mentions": self.span_mentions, "matcher_key": self.matcher_key, "matcher_values": self.matcher_values}
return srsly.msgpack_dumps(serial)
def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> "CorefEntityRecognizer":
"""Load the coreference entity recognizer from a file.
path (str / Path): The directory to load from.
RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.
DOCS: https://spacy.io/api/coref_er#from_disk (TODO)
"""
path = ensure_path(path)
cfg = {}
deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
from_disk(path, deserializers_cfg, {})
self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
return self
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""Save the coreference entity recognizer to a directory.
path (str / Path): The directory to save to.
DOCS: https://spacy.io/api/coref_er#to_disk (TODO)
"""
path = ensure_path(path)
cfg = {
"span_mentions": self.span_mentions,
"matcher_key": self.matcher_key,
"matcher_values": self.matcher_values,
}
serializers = {"cfg": lambda p: srsly.write_json(p, cfg)}
to_disk(path, serializers, {})
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples.
examples (Iterable[Example]): The examples to score.
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
DOCS: https://spacy.io/api/coref_er#score (TODO)
"""
def mentions_getter(doc, span_key):
return doc.spans[span_key]
# This will work better once PR 7209 is merged
kwargs.setdefault("getter", mentions_getter)
kwargs.setdefault("attr", self.span_mentions)
kwargs.setdefault("include_label", False)
kwargs.setdefault("allow_overlap", True)
return Scorer.score_spans(examples, **kwargs)
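
For each configured value, CorefEntityRecognizer registers one greedy pattern of the form [[{matcher_key: value, "OP": "+"}]], so the longest run of consecutive tokens with that attribute value becomes a single mention candidate. The same pattern shown standalone (illustration only, not part of the diff):

from spacy.matcher import Matcher
from spacy.tokens import Doc
from spacy.vocab import Vocab

vocab = Vocab()
matcher = Matcher(vocab)
# Same shape of pattern as the component builds for matcher_key="POS", value="PROPN".
matcher.add("PROPN_SEQ", [[{"POS": "PROPN", "OP": "+"}]], greedy="LONGEST")

doc = Doc(vocab, words=["John", "Smith", "called", "."], pos=["PROPN", "PROPN", "VERB", "PUNCT"])
spans = matcher(doc, as_spans=True)  # one span covering "John Smith"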


@ -56,8 +56,7 @@ class EntityRuler(Pipe):
"""The EntityRuler lets you add spans to the `Doc.ents` using token-based
rules or exact phrase matches. It can be combined with the statistical
`EntityRecognizer` to boost accuracy, or used on its own to implement a
purely rule-based entity recognition system. After initialization, the
component is typically added to the pipeline using `nlp.add_pipe`.
purely rule-based entity recognition system.
DOCS: https://spacy.io/api/entityruler
USAGE: https://spacy.io/usage/rule-based-matching#entityruler


@ -1,8 +1,8 @@
from itertools import islice
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
from thinc.types import Floats2d
import numpy
from itertools import islice
from .trainable_pipe import TrainablePipe
from ..language import Language


@ -341,7 +341,7 @@ class Scorer:
for label in labels:
if label not in score_per_type:
score_per_type[label] = PRFScore()
# Find all predidate labels, for all and per type
# Find all instances, for all and per type
gold_spans = set()
pred_spans = set()
for span in getter(gold_doc, attr):
@ -373,6 +373,114 @@ class Scorer:
f"{attr}_per_type": None,
}
@staticmethod
def score_clusters(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Doc, str], Iterable[Iterable[Span]]] = getattr,
has_annotation: Optional[Callable[[Doc], bool]] = None,
include_label: bool = True,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF scores for clustered spans.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (Callable[[Doc, str], Iterable[Iterable[Span]]]): Defaults to getattr.
If provided, getter(doc, attr) should return the span clusters (lists of
spans) for the individual doc.
has_annotation (Optional[Callable[[Doc], bool]]): Should return whether a `Doc`
has annotation for this `attr`. Docs without annotation are skipped for
scoring purposes.
include_label (bool): Whether or not to include label information in
the evaluation. If set to 'False', two spans will be considered
equal if their start and end match, irrespective of their label.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
DOCS: https://spacy.io/api/scorer#score_clusters (TODO)
"""
# Note: the current implementation just scores binary pairs on whether they
# are in the same cluster or not.
# TODO: look at different cluster/coreference scoring techniques.
score = PRFScore()
score_per_type = dict()
for example in examples:
pred_doc = example.predicted
gold_doc = example.reference
# Optionally skip docs that lack annotation for this attr
if has_annotation is not None:
if not has_annotation(gold_doc):
continue
# Find all labels in gold and doc
gold_clusters = list(getter(gold_doc, attr))
pred_clusters = list(getter(pred_doc, attr))
labels = set(
[span.label_ for span_list in gold_clusters for span in span_list]
+ [span.label_ for span_list in pred_clusters for span in span_list]
)
# Set up all labels for per type scoring and prepare gold per type
for label in labels:
if label not in score_per_type:
score_per_type[label] = PRFScore()
# Find all instances, for all and per type
gold_instances = set()
gold_per_type = {label: set() for label in labels}
for gold_cluster in gold_clusters:
for span1 in gold_cluster:
for span2 in gold_cluster:
# only record pairs where span1 comes before span2
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
if include_label:
gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
else:
gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
gold_instances.add(gold_rel)
if span1.label_ == span2.label_:
gold_per_type[span1.label_].add(gold_rel)
pred_instances = set()
pred_per_type = {label: set() for label in labels}
for pred_cluster in pred_clusters:
for span1 in pred_cluster:
for span2 in pred_cluster:
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
if include_label:
pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
else:
pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
pred_instances.add(pred_rel)
if span1.label_ == span2.label_:
pred_per_type[span1.label_].add(pred_rel)
# Scores per label
if include_label:
for k, v in score_per_type.items():
if k in pred_per_type:
v.score_set(pred_per_type[k], gold_per_type[k])
# Score for all labels
score.score_set(pred_instances, gold_instances)
# Assemble final result
final_scores = {
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
}
if include_label:
final_scores[f"{attr}_per_type"] = None
if len(score) > 0:
final_scores[f"{attr}_p"] = score.precision
final_scores[f"{attr}_r"] = score.recall
final_scores[f"{attr}_f"] = score.fscore
return final_scores
@staticmethod
def score_cats(
examples: Iterable[Example],
@ -722,12 +830,7 @@ def get_ner_prf(examples: Iterable[Example]) -> Dict[str, Any]:
"ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
}
else:
return {
"ents_p": None,
"ents_r": None,
"ents_f": None,
"ents_per_type": None,
}
return {"ents_p": None, "ents_r": None, "ents_f": None, "ents_per_type": None}
# The following implementation of roc_auc_score() is adapted from
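
The new score_clusters method reduces cluster quality to binary links: every ordered pair of mentions inside the same cluster becomes one instance, and precision/recall/F are computed over those pair sets. A worked illustration of that reduction with include_label=False (hypothetical offsets, not from the diff):

# Gold: one cluster of three mentions; predicted: only two of them recovered.
gold_cluster = [(0, 1), (5, 6), (12, 13)]  # (start, end) token offsets
pred_cluster = [(0, 1), (5, 6)]

def pairs(cluster):
    out = set()
    for s1 in cluster:
        for s2 in cluster:
            if s1 < s2:  # only record pairs where the first span comes first
                out.add((s1[0], s1[1] - 1, s2[0], s2[1] - 1))
    return out

gold = pairs(gold_cluster)  # 3 links
pred = pairs(pred_cluster)  # 1 link
precision = len(gold & pred) / len(pred)   # 1.0
recall = len(gold & pred) / len(gold)      # 0.33...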


@ -0,0 +1,180 @@
import pytest
import spacy
from spacy.matcher import PhraseMatcher
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.tokens import Doc
from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
from spacy.pipeline.coref_er import DEFAULT_MENTIONS
# fmt: off
TRAIN_DATA = [
(
"John Smith told Laura that he was running late and asked her whether she could pick up their kids.",
{
"spans": {
DEFAULT_MENTIONS: [
(0, 10, "MENTION"),
(16, 21, "MENTION"),
(27, 29, "MENTION"),
(57, 60, "MENTION"),
(69, 72, "MENTION"),
(87, 92, "MENTION"),
(87, 97, "MENTION"),
],
f"{DEFAULT_CLUSTERS_PREFIX}_1": [
(0, 10, "MENTION"), # John
(27, 29, "MENTION"),
(87, 92, "MENTION"), # 'their' refers to John and Laura
],
f"{DEFAULT_CLUSTERS_PREFIX}_2": [
(16, 21, "MENTION"), # Laura
(57, 60, "MENTION"),
(69, 72, "MENTION"),
(87, 92, "MENTION"), # 'their' refers to John and Laura
],
}
},
),
(
"Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.",
{
"spans": {
DEFAULT_MENTIONS: [
(5, 6, "MENTION"),
(40, 42, "MENTION"),
(52, 54, "MENTION"),
(95, 103, "MENTION"),
],
f"{DEFAULT_CLUSTERS_PREFIX}_1": [
(5, 6, "MENTION"), # I
(40, 42, "MENTION"),
],
f"{DEFAULT_CLUSTERS_PREFIX}_2": [
(52, 54, "MENTION"), # SMS
(95, 103, "MENTION"),
]
}
},
),
]
# fmt: on
@pytest.fixture
def nlp():
return English()
@pytest.fixture
def examples(nlp):
examples = []
for text, annot in TRAIN_DATA:
# eg = Example.from_dict(nlp.make_doc(text), annot)
# if PR #7197 is merged, replace below with above line
ref_doc = nlp.make_doc(text)
for key, span_list in annot["spans"].items():
spans = []
for span_tuple in span_list:
start_char = span_tuple[0]
end_char = span_tuple[1]
label = span_tuple[2]
span = ref_doc.char_span(start_char, end_char, label=label)
spans.append(span)
ref_doc.spans[key] = spans
eg = Example(nlp.make_doc(text), ref_doc)
examples.append(eg)
return examples
def test_coref_er_no_POS(nlp):
doc = nlp("The police woman talked to him.")
coref_er = nlp.add_pipe("coref_er", last=True)
with pytest.raises(ValueError):
coref_er(doc)
def test_coref_er_with_POS(nlp):
words = ["The", "police", "woman", "talked", "to", "him", "."]
pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
doc = Doc(nlp.vocab, words=words, pos=pos)
coref_er = nlp.add_pipe("coref_er", last=True)
coref_er(doc)
assert len(doc.spans[coref_er.span_mentions]) == 1
mention = doc.spans[coref_er.span_mentions][0]
assert (mention.text, mention.start, mention.end) == ("him", 5, 6)
def test_coref_er_custom_POS(nlp):
words = ["The", "police", "woman", "talked", "to", "him", "."]
pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
doc = Doc(nlp.vocab, words=words, pos=pos)
config = {"matcher_key": "POS", "matcher_values": ["NOUN"]}
coref_er = nlp.add_pipe("coref_er", last=True, config=config)
coref_er(doc)
assert len(doc.spans[coref_er.span_mentions]) == 1
mention = doc.spans[coref_er.span_mentions][0]
assert (mention.text, mention.start, mention.end) == ("police woman", 1, 3)
def test_coref_clusters(nlp, examples):
coref_er = nlp.add_pipe("coref_er", last=True)
coref = nlp.add_pipe("coref", last=True)
coref.initialize(lambda: examples)
words = ["Laura", "walked", "her", "dog", "."]
pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
doc = Doc(nlp.vocab, words=words, pos=pos)
coref_er(doc)
coref(doc)
assert len(doc.spans[coref_er.span_mentions]) > 0
found_clusters = 0
for name, spans in doc.spans.items():
if name.startswith(coref.span_cluster_prefix):
found_clusters += 1
assert found_clusters > 0
def test_coref_er_score(nlp, examples):
config = {"matcher_key": "POS", "matcher_values": []}
coref_er = nlp.add_pipe("coref_er", last=True, config=config)
coref = nlp.add_pipe("coref", last=True)
coref.initialize(lambda: examples)
mentions_key = coref_er.span_mentions
cluster_prefix_key = coref.span_cluster_prefix
matcher = PhraseMatcher(nlp.vocab)
terms_1 = ["Laura", "her", "she"]
terms_2 = ["it", "this SMS"]
matcher.add("A", [nlp.make_doc(text) for text in terms_1])
matcher.add("B", [nlp.make_doc(text) for text in terms_2])
for eg in examples:
pred = eg.predicted
matches = matcher(pred, as_spans=True)
pred.set_ents(matches)
coref_er(pred)
coref(pred)
eg.predicted = pred
# TODO: if #7209 is merged, experiment with 'include_label'
scores = coref_er.score([eg])
assert f"{mentions_key}_f" in scores
scores = coref.score([eg])
assert f"{cluster_prefix_key}_f" in scores
def test_coref_serialization(nlp):
# Test that the coref component can be serialized
config_er = {"matcher_key": "TAG", "matcher_values": ["NN"]}
nlp.add_pipe("coref_er", last=True, config=config_er)
nlp.add_pipe("coref", last=True)
assert "coref_er" in nlp.pipe_names
assert "coref" in nlp.pipe_names
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = spacy.load(tmp_dir)
assert "coref_er" in nlp2.pipe_names
assert "coref" in nlp2.pipe_names
coref_er_2 = nlp2.get_pipe("coref_er")
assert coref_er_2.matcher_key == "TAG"


@ -13,7 +13,7 @@ if TYPE_CHECKING:
# Why inherit from UserDict instead of dict here?
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
# for performance reasons. The UserDict is slower by better behaved.
# for performance reasons. The UserDict is slower but better behaved.
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww
class SpanGroups(UserDict):
"""A dict-like proxy held by the Doc, to control access to span groups."""