Mirror of https://github.com/explosion/spaCy.git
Native coref component (#7243)
* initial coref_er pipe
* matcher more flexible
* base coref component without actual model
* initial setup of coref_er.score
* rename to include_label
* preliminary score_clusters method
* apply scoring in coref component
* IO fix
* return None loss for now
* rename to CoreferenceResolver
* some preliminary unit tests
* use registry as callable
This commit is contained in:
parent dd99872bb0
commit e0c45c669a
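Editor's note, not part of the commit: a minimal end-to-end sketch of the two new components, assembled from the unit tests added below. Both are experimental stubs at this stage: "coref_er" collects candidate mentions into doc.spans, and "coref" currently turns each mention into its own single-span cluster.

# Hedged usage sketch (not in the commit), based on spacy/tests/pipeline/test_coref.py below.
from spacy.lang.en import English
from spacy.tokens import Doc

nlp = English()
coref_er = nlp.add_pipe("coref_er", last=True)
coref = nlp.add_pipe("coref", last=True)

# The default matcher needs POS tags, so build a Doc that already has them.
words = ["Laura", "walked", "her", "dog", "."]
pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
doc = Doc(nlp.vocab, words=words, pos=pos)
doc = coref_er(doc)
doc = coref(doc)

print(doc.spans[coref_er.span_mentions])  # candidate mentions: "Laura", "her"
cluster_keys = [k for k in doc.spans if k.startswith(coref.span_cluster_prefix)]
print(cluster_keys)  # one key per cluster: coref_clusters0, coref_clusters1, ...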

@@ -1,3 +1,4 @@
+from .coref import *
 from .entity_linker import * # noqa
 from .multi_task import * # noqa
 from .parser import * # noqa

spacy/ml/models/coref.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from typing import List
from thinc.api import Model
from thinc.types import Floats2d

from ...util import registry
from ...tokens import Doc


@registry.architectures("spacy.Coref.v0")
def build_coref_model(
    tok2vec: Model[List[Doc], List[Floats2d]]
) -> Model:
    """Build a coref resolution model, using a provided token-to-vector component.
    TODO.

    tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
    """
    return tok2vec
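Editor's sketch, not in the commit: as the docstring notes, the v0 architecture is still a placeholder and simply returns whatever tok2vec layer it is given. The dummy layer below is hypothetical, just to illustrate that.

from thinc.api import Model

def _forward(model, X, is_train):
    # identity forward pass for the hypothetical dummy layer
    return X, lambda dY: dY

dummy_tok2vec = Model("dummy-tok2vec", _forward)
assert build_coref_model(dummy_tok2vec) is dummy_tok2vec  # pass-through for now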

@@ -1,4 +1,6 @@
 from .attributeruler import AttributeRuler
+from .coref import CoreferenceResolver
+from .coref_er import CorefEntityRecognizer
 from .dep_parser import DependencyParser
 from .entity_linker import EntityLinker
 from .ner import EntityRecognizer

spacy/pipeline/coref.py (new file, 288 lines)
@@ -0,0 +1,288 @@
from typing import Iterable, Tuple, Optional, Dict, Callable, Any

from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
from itertools import islice

from .trainable_pipe import TrainablePipe
from .coref_er import DEFAULT_MENTIONS
from ..language import Language
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from ..tokens import Doc
from ..vocab import Vocab


default_config = """
[model]
@architectures = "spacy.Coref.v0"

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false

[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
"""
DEFAULT_MODEL = Config().from_str(default_config)["model"]

DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
@Language.factory(
    "coref",
    assigns=["doc.spans"],
    requires=["doc.spans"],
    default_config={
        "model": DEFAULT_MODEL,
        "span_mentions": DEFAULT_MENTIONS,
        "span_cluster_prefix": DEFAULT_CLUSTERS_PREFIX,
    },
    default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
)
def make_coref(
    nlp: Language,
    name: str,
    model,
    span_mentions: str,
    span_cluster_prefix: str,
) -> "CoreferenceResolver":
    """Create a CoreferenceResolver component. TODO

    model (Model[List[Doc], List[Floats2d]]): A model instance that predicts ...
    span_mentions (str): Key in doc.spans that holds the candidate coref mentions.
    span_cluster_prefix (str): Prefix for the doc.spans keys that store the clusters.
    """
    return CoreferenceResolver(
        nlp.vocab,
        model,
        name,
        span_mentions=span_mentions,
        span_cluster_prefix=span_cluster_prefix,
    )
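Editor's sketch, not in the commit: the defaults above can be overridden through the config argument of nlp.add_pipe; the keys mirror default_config of this factory.

import spacy

nlp = spacy.blank("en")
coref = nlp.add_pipe(
    "coref",
    config={"span_mentions": "my_mentions", "span_cluster_prefix": "my_clusters"},
)
assert coref.span_mentions == "my_mentions"
assert coref.span_cluster_prefix == "my_clusters"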


class CoreferenceResolver(TrainablePipe):
    """Pipeline component for coreference resolution.

    DOCS: https://spacy.io/api/coref (TODO)
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "coref",
        *,
        span_mentions: str,
        span_cluster_prefix: str,
    ) -> None:
        """Initialize a coreference resolution component.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        span_mentions (str): Key in doc.spans where the candidate coref mentions
            are stored.
        span_cluster_prefix (str): Prefix for the key in doc.spans to store the
            coref clusters in.

        DOCS: https://spacy.io/api/coref#init (TODO)
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.span_mentions = span_mentions
        self.span_cluster_prefix = span_cluster_prefix
        self._rehearsal_model = None
        self.cfg = {}

    def predict(self, docs: Iterable[Doc]):
        """Apply the pipeline's model to a batch of docs, without modifying them.
        TODO: write actual algorithm

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: The model's predictions for each document.

        DOCS: https://spacy.io/api/coref#predict (TODO)
        """
        clusters_by_doc = []
        for i, doc in enumerate(docs):
            clusters = []
            for span in doc.spans[self.span_mentions]:
                clusters.append([span])
            clusters_by_doc.append(clusters)
        return clusters_by_doc

    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
        """Modify a batch of Doc objects, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        clusters_by_doc: The span clusters, produced by CoreferenceResolver.predict.

        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
        """
        if len(docs) != len(clusters_by_doc):
            raise ValueError(
                "Found coref clusters incompatible with the "
                "documents provided to the 'coref' component. "
                "This is likely a bug in spaCy."
            )
        for doc, clusters in zip(docs, clusters_by_doc):
            index = 0
            for cluster in clusters:
                key = self.span_cluster_prefix + str(index)
                if key in doc.spans:
                    raise ValueError(
                        f"Couldn't store the results of {self.name}, as the key "
                        f"{key} already exists in 'doc.spans'."
                    )
                doc.spans[key] = cluster
                index += 1

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/coref#update (TODO)
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        validate_examples(examples, "CoreferenceResolver.update")
        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
            # Handle cases where there are no tokens in any docs.
            return losses
        set_dropout_rate(self.model, drop)
        scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
        # TODO below
        # loss, d_scores = self.get_loss(examples, scores)
        # bp_scores(d_scores)
        if sgd is not None:
            self.finish_update(sgd)
        # losses[self.name] += loss
        return losses
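Editor's sketch, not in the commit: at this stage update() runs the tok2vec forward pass but computes no loss (get_loss is still a TODO), so the reported loss stays at 0.0. The names `coref` and `examples` are assumptions carried over from the earlier sketch and the unit tests below.

from thinc.api import Adam

coref.initialize(lambda: examples)
optimizer = Adam(0.001)
losses = coref.update(examples, sgd=optimizer, losses={})
print(losses)  # {'coref': 0.0}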

    def rehearse(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
        teach the current model to make predictions similar to an initial model,
        to try to address the "catastrophic forgetting" problem. This feature is
        experimental.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/coref#rehearse (TODO)
        """
        if losses is not None:
            losses.setdefault(self.name, 0.0)
        if self._rehearsal_model is None:
            return losses
        validate_examples(examples, "CoreferenceResolver.rehearse")
        docs = [eg.predicted for eg in examples]
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.
            return losses
        set_dropout_rate(self.model, drop)
        scores, bp_scores = self.model.begin_update(docs)
        # TODO below
        target = self._rehearsal_model(examples)
        gradient = scores - target
        bp_scores(gradient)
        if sgd is not None:
            self.finish_update(sgd)
        if losses is not None:
            losses[self.name] += (gradient ** 2).sum()
        return losses

    def add_label(self, label: str) -> int:
        """Technically this method should be implemented from TrainablePipe,
        but it is not relevant for the coref component.
        """
        raise NotImplementedError(
            Errors.E931.format(
                parent="CoreferenceResolver", method="add_label", name=self.name
            )
        )

    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples (Iterable[Example]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/coref#get_loss (TODO)
        """
        validate_examples(examples, "CoreferenceResolver.get_loss")
        # TODO
        return None

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ) -> None:
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref#initialize (TODO)
        """
        validate_get_examples(get_examples, "CoreferenceResolver.initialize")
        subbatch = list(islice(get_examples(), 10))
        doc_sample = [eg.reference for eg in subbatch]
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=doc_sample)

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_clusters.

        DOCS: https://spacy.io/api/coref#score (TODO)
        """
        def clusters_getter(doc, span_key):
            return [
                spans for name, spans in doc.spans.items() if name.startswith(span_key)
            ]

        validate_examples(examples, "CoreferenceResolver.score")
        kwargs.setdefault("getter", clusters_getter)
        kwargs.setdefault("attr", self.span_cluster_prefix)
        kwargs.setdefault("include_label", False)
        return Scorer.score_clusters(examples, **kwargs)
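Editor's sketch, not in the commit: the returned score keys are derived from span_cluster_prefix, mirroring test_coref_er_score below (with the defaults: coref_clusters_p / _r / _f). Again `coref` and `examples` are assumed from the earlier sketches.

scores = coref.score(examples)
assert f"{coref.span_cluster_prefix}_f" in scores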

spacy/pipeline/coref_er.py (new file, 227 lines)
@@ -0,0 +1,227 @@
from typing import Optional, Union, Iterable, Callable, List, Dict, Any
from pathlib import Path
import srsly

from .pipe import Pipe
from ..scorer import Scorer
from ..training import Example
from ..language import Language
from ..tokens import Doc, Span, SpanGroup
from ..matcher import Matcher
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList


DEFAULT_MENTIONS = "coref_mentions"
DEFAULT_MATCHER_KEY = "POS"
DEFAULT_MATCHER_VALUES = ["PROPN", "PRON"]


@Language.factory(
    "coref_er",
    assigns=["doc.spans"],
    requires=["doc.ents", "token.ent_iob", "token.ent_type", "token.pos"],
    default_config={
        "span_mentions": DEFAULT_MENTIONS,
        "matcher_key": DEFAULT_MATCHER_KEY,
        "matcher_values": DEFAULT_MATCHER_VALUES,
    },
    default_score_weights={
        "coref_mentions_f": None,
        "coref_mentions_p": None,
        # the mentions data needs to be consistently annotated for precision rates to make sense
        "coref_mentions_r": 1.0,
    },
)
def make_coref_er(
    nlp: Language,
    name: str,
    span_mentions: str,
    matcher_key: str,
    matcher_values: List[str],
):
    return CorefEntityRecognizer(
        nlp,
        name,
        span_mentions=span_mentions,
        matcher_key=matcher_key,
        matcher_values=matcher_values,
    )


class CorefEntityRecognizer(Pipe):
    """TODO.

    DOCS: https://spacy.io/api/coref_er (TODO)
    USAGE: https://spacy.io/usage (TODO)
    """

    def __init__(
        self,
        nlp: Language,
        name: str = "coref_er",
        *,
        span_mentions: str,
        matcher_key: str,
        matcher_values: List[str],
    ) -> None:
        """Initialize the entity recognizer for coreference mentions. TODO

        nlp (Language): The shared nlp object.
        name (str): Instance name of the current pipeline component. Typically
            passed in automatically from the factory when the component is
            added.
        span_mentions (str): Key in doc.spans to store the coref mentions in.
        matcher_key (str): Field for the matcher to work on (e.g. "POS" or "TAG").
        matcher_values (List[str]): Values to match token sequences as
            plausible coref mentions.

        DOCS: https://spacy.io/api/coref_er#init (TODO)
        """
        self.nlp = nlp
        self.name = name
        self.span_mentions = span_mentions
        self.matcher_key = matcher_key
        self.matcher_values = matcher_values
        self.matcher = Matcher(nlp.vocab)
        # TODO: allow to specify any matcher patterns instead?
        for value in matcher_values:
            self.matcher.add(
                f"{value}_SEQ", [[{matcher_key: value, "OP": "+"}]], greedy="LONGEST"
            )
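Editor's sketch, not in the commit: with the defaults (matcher_key="POS", matcher_values=["PROPN", "PRON"]), the loop above registers a PROPN_SEQ and a PRON_SEQ pattern. An equivalent standalone Matcher setup, using the example sentence from the tests below:

from spacy.lang.en import English
from spacy.matcher import Matcher
from spacy.tokens import Doc

nlp = English()
matcher = Matcher(nlp.vocab)
matcher.add("PROPN_SEQ", [[{"POS": "PROPN", "OP": "+"}]], greedy="LONGEST")
matcher.add("PRON_SEQ", [[{"POS": "PRON", "OP": "+"}]], greedy="LONGEST")

words = ["The", "police", "woman", "talked", "to", "him", "."]
pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
doc = Doc(nlp.vocab, words=words, pos=pos)
print([span.text for span in matcher(doc, as_spans=True)])  # ['him']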

    @staticmethod
    def _string_offset(span: Span):
        return f"{span.start}-{span.end}"

    def __call__(self, doc: Doc) -> Doc:
        """Find relevant coref mentions in the document and add them
        to the doc's relevant span container.

        doc (Doc): The Doc object in the pipeline.
        RETURNS (Doc): The Doc with added entities, if available.

        DOCS: https://spacy.io/api/coref_er#call (TODO)
        """
        error_handler = self.get_error_handler()
        try:
            # Add NER
            spans = list(doc.ents)
            offsets = set()
            offsets.update([self._string_offset(e) for e in doc.ents])

            # pronouns and proper nouns
            try:
                matches = self.matcher(doc, as_spans=True)
            except ValueError:
                raise ValueError(
                    f"Could not run the matcher for 'coref_er'. If {self.matcher_key} tags "
                    "are not available, change the 'matcher_key' in the config, "
                    "or set matcher_values to an empty list."
                )
            spans.extend([m for m in matches if self._string_offset(m) not in offsets])
            offsets.update([self._string_offset(m) for m in matches])

            # noun_chunks - only if implemented and parsing information is available
            try:
                spans.extend(
                    [nc for nc in doc.noun_chunks if self._string_offset(nc) not in offsets]
                )
                offsets.update([self._string_offset(nc) for nc in doc.noun_chunks])
            except (NotImplementedError, ValueError):
                pass

            self.set_annotations(doc, spans)
            return doc
        except Exception as e:
            error_handler(self.name, self, [doc], e)

    def set_annotations(self, doc, spans):
        """Modify the document in place."""
        group = SpanGroup(doc, name=self.span_mentions, spans=spans)
        if self.span_mentions in doc.spans:
            raise ValueError(
                f"Couldn't store the results of {self.name}, as the key "
                f"{self.span_mentions} already exists in 'doc.spans'."
            )
        doc.spans[self.span_mentions] = group

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ):
        """Initialize the pipe for training.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref_er#initialize (TODO)
        """
        pass

    def from_bytes(
        self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a bytestring.

        bytes_data (bytes): The bytestring to load.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_bytes
        """
        cfg = srsly.msgpack_loads(bytes_data)
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        return self

    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
        """Serialize the coreference entity recognizer to a bytestring.

        RETURNS (bytes): The serialized component.

        DOCS: https://spacy.io/api/coref_er#to_bytes (TODO)
        """
        serial = {"span_mentions": self.span_mentions}
        return srsly.msgpack_dumps(serial)

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a file.

        path (str / Path): The directory to load from.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {}
        deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
        from_disk(path, deserializers_cfg, {})
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        return self

    def to_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> None:
        """Save the coreference entity recognizer to a directory.

        path (str / Path): The directory to save to.

        DOCS: https://spacy.io/api/coref_er#to_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": self.matcher_values,
        }
        serializers = {"cfg": lambda p: srsly.write_json(p, cfg)}
        to_disk(path, serializers, {})
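Editor's sketch, not in the commit: a round-trip based on test_coref_serialization below; only the cfg entries (span_mentions, matcher_key, matcher_values) are persisted for this component.

import tempfile
import spacy
from spacy.lang.en import English

nlp = English()
nlp.add_pipe("coref_er", config={"matcher_key": "TAG", "matcher_values": ["NN"]})
with tempfile.TemporaryDirectory() as tmp_dir:
    nlp.to_disk(tmp_dir)
    nlp2 = spacy.load(tmp_dir)
assert nlp2.get_pipe("coref_er").matcher_key == "TAG"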

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

        DOCS: https://spacy.io/api/coref_er#score (TODO)
        """
        def mentions_getter(doc, span_key):
            return doc.spans[span_key]

        # This will work better once PR 7209 is merged
        kwargs.setdefault("getter", mentions_getter)
        kwargs.setdefault("attr", self.span_mentions)
        kwargs.setdefault("include_label", False)
        kwargs.setdefault("allow_overlap", True)
        return Scorer.score_spans(examples, **kwargs)

@@ -56,8 +56,7 @@ class EntityRuler(Pipe):
     """The EntityRuler lets you add spans to the `Doc.ents` using token-based
     rules or exact phrase matches. It can be combined with the statistical
     `EntityRecognizer` to boost accuracy, or used on its own to implement a
-    purely rule-based entity recognition system. After initialization, the
-    component is typically added to the pipeline using `nlp.add_pipe`.
+    purely rule-based entity recognition system.

     DOCS: https://spacy.io/api/entityruler
     USAGE: https://spacy.io/usage/rule-based-matching#entityruler

@@ -1,8 +1,8 @@
+from itertools import islice
 from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
 import numpy
-from itertools import islice

 from .trainable_pipe import TrainablePipe
 from ..language import Language

spacy/scorer.py (117 changed lines)
@@ -341,7 +341,7 @@ class Scorer:
         for label in labels:
             if label not in score_per_type:
                 score_per_type[label] = PRFScore()
-        # Find all predidate labels, for all and per type
+        # Find all instances, for all and per type
         gold_spans = set()
         pred_spans = set()
         for span in getter(gold_doc, attr):

@@ -373,6 +373,114 @@ class Scorer:
            f"{attr}_per_type": None,
        }

    @staticmethod
    def score_clusters(
        examples: Iterable[Example],
        attr: str,
        *,
        getter: Callable[[Doc, str], Iterable[Iterable[Span]]] = getattr,
        has_annotation: Optional[Callable[[Doc], bool]] = None,
        include_label: bool = True,
        **cfg,
    ) -> Dict[str, Any]:
        """Returns PRF scores for clustered spans.

        examples (Iterable[Example]): Examples to score.
        attr (str): The attribute to score.
        getter (Callable[[Doc, str], Iterable[Iterable[Span]]]): Defaults to getattr.
            If provided, getter(doc, attr) should return the lists of spans for the
            individual doc.
        has_annotation (Optional[Callable[[Doc], bool]]): Function that returns
            whether a `Doc` has annotation for this `attr`. Docs without annotation
            are skipped for scoring purposes.
        include_label (bool): Whether or not to include label information in
            the evaluation. If set to 'False', two spans will be considered
            equal if their start and end match, irrespective of their label.

        RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
            the keys attr_p/r/f and the per-type PRF scores under attr_per_type.

        DOCS: https://spacy.io/api/scorer#score_clusters (TODO)
        """
        # Note: the current implementation just scores binary pairs on whether
        # they are in the same cluster or not.
        # TODO: look at different cluster/coreference scoring techniques.
        score = PRFScore()
        score_per_type = dict()
        for example in examples:
            pred_doc = example.predicted
            gold_doc = example.reference
            # Option to handle docs without sents
            if has_annotation is not None:
                if not has_annotation(gold_doc):
                    continue
            # Find all labels in gold and doc
            gold_clusters = list(getter(gold_doc, attr))
            pred_clusters = list(getter(pred_doc, attr))
            labels = set(
                [span.label_ for span_list in gold_clusters for span in span_list]
                + [span.label_ for span_list in pred_clusters for span in span_list]
            )
            # Set up all labels for per type scoring and prepare gold per type
            for label in labels:
                if label not in score_per_type:
                    score_per_type[label] = PRFScore()
            # Find all instances, for all and per type
            gold_instances = set()
            gold_per_type = {label: set() for label in labels}
            for gold_cluster in gold_clusters:
                for span1 in gold_cluster:
                    for span2 in gold_cluster:
                        # only record pairs where span1 comes before span2
                        if (span1.start < span2.start) or (
                            span1.start == span2.start and span1.end < span2.end
                        ):
                            if include_label:
                                gold_rel = (
                                    span1.label_,
                                    span1.start,
                                    span1.end - 1,
                                    span2.label_,
                                    span2.start,
                                    span2.end - 1,
                                )
                            else:
                                gold_rel = (
                                    span1.start,
                                    span1.end - 1,
                                    span2.start,
                                    span2.end - 1,
                                )
                            gold_instances.add(gold_rel)
                            if span1.label_ == span2.label_:
                                gold_per_type[span1.label_].add(gold_rel)
            pred_instances = set()
            pred_per_type = {label: set() for label in labels}
            for pred_cluster in pred_clusters:
                for span1 in pred_cluster:
                    for span2 in pred_cluster:
                        if (span1.start < span2.start) or (
                            span1.start == span2.start and span1.end < span2.end
                        ):
                            if include_label:
                                pred_rel = (
                                    span1.label_,
                                    span1.start,
                                    span1.end - 1,
                                    span2.label_,
                                    span2.start,
                                    span2.end - 1,
                                )
                            else:
                                pred_rel = (
                                    span1.start,
                                    span1.end - 1,
                                    span2.start,
                                    span2.end - 1,
                                )
                            pred_instances.add(pred_rel)
                            if span1.label_ == span2.label_:
                                pred_per_type[span1.label_].add(pred_rel)
            # Scores per label
            if include_label:
                for k, v in score_per_type.items():
                    if k in pred_per_type:
                        v.score_set(pred_per_type[k], gold_per_type[k])
            # Score for all labels
            score.score_set(pred_instances, gold_instances)
        # Assemble final result
        final_scores = {
            f"{attr}_p": None,
            f"{attr}_r": None,
            f"{attr}_f": None,
        }
        if include_label:
            final_scores[f"{attr}_per_type"] = None
        if len(score) > 0:
            final_scores[f"{attr}_p"] = score.precision
            final_scores[f"{attr}_r"] = score.recall
            final_scores[f"{attr}_f"] = score.fscore
        return final_scores

    @staticmethod
    def score_cats(
        examples: Iterable[Example],
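Editor's illustration, not in the commit: the pair encoding that score_clusters uses above, with include_label=False. A cluster of three mentions yields three ordered pairs encoded as (start, end - 1) token offsets.

cluster = [(0, 2), (5, 6), (9, 10)]  # (span.start, span.end) for three mentions
pairs = set()
for s1 in cluster:
    for s2 in cluster:
        if s1 < s2:  # lexicographic, same as the start/end comparison above
            pairs.add((s1[0], s1[1] - 1, s2[0], s2[1] - 1))
print(sorted(pairs))  # [(0, 1, 5, 5), (0, 1, 9, 9), (5, 5, 9, 9)]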

@@ -722,12 +830,7 @@ def get_ner_prf(examples: Iterable[Example]) -> Dict[str, Any]:
             "ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
         }
     else:
-        return {
-            "ents_p": None,
-            "ents_r": None,
-            "ents_f": None,
-            "ents_per_type": None,
-        }
+        return {"ents_p": None, "ents_r": None, "ents_f": None, "ents_per_type": None}


 # The following implementation of roc_auc_score() is adapted from

spacy/tests/pipeline/test_coref.py (new file, 180 lines)
@@ -0,0 +1,180 @@
import pytest
import spacy
from spacy.matcher import PhraseMatcher
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.tokens import Doc
from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
from spacy.pipeline.coref_er import DEFAULT_MENTIONS


# fmt: off
TRAIN_DATA = [
    (
        "John Smith told Laura that he was running late and asked her whether she could pick up their kids.",
        {
            "spans": {
                DEFAULT_MENTIONS: [
                    (0, 10, "MENTION"),
                    (16, 21, "MENTION"),
                    (27, 29, "MENTION"),
                    (57, 60, "MENTION"),
                    (69, 72, "MENTION"),
                    (87, 92, "MENTION"),
                    (87, 97, "MENTION"),
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_1": [
                    (0, 10, "MENTION"),   # John
                    (27, 29, "MENTION"),
                    (87, 92, "MENTION"),  # 'their' refers to John and Laura
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_2": [
                    (16, 21, "MENTION"),  # Laura
                    (57, 60, "MENTION"),
                    (69, 72, "MENTION"),
                    (87, 92, "MENTION"),  # 'their' refers to John and Laura
                ],
            }
        },
    ),
    (
        "Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.",
        {
            "spans": {
                DEFAULT_MENTIONS: [
                    (5, 6, "MENTION"),
                    (40, 42, "MENTION"),
                    (52, 54, "MENTION"),
                    (95, 103, "MENTION"),
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_1": [
                    (5, 6, "MENTION"),    # I
                    (40, 42, "MENTION"),
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_2": [
                    (52, 54, "MENTION"),  # SMS
                    (95, 103, "MENTION"),
                ]
            }
        },
    ),
]
# fmt: on


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def examples(nlp):
    examples = []
    for text, annot in TRAIN_DATA:
        # eg = Example.from_dict(nlp.make_doc(text), annot)
        # if PR #7197 is merged, replace below with above line
        ref_doc = nlp.make_doc(text)
        for key, span_list in annot["spans"].items():
            spans = []
            for span_tuple in span_list:
                start_char = span_tuple[0]
                end_char = span_tuple[1]
                label = span_tuple[2]
                span = ref_doc.char_span(start_char, end_char, label=label)
                spans.append(span)
            ref_doc.spans[key] = spans
        eg = Example(nlp.make_doc(text), ref_doc)
        examples.append(eg)
    return examples


def test_coref_er_no_POS(nlp):
    doc = nlp("The police woman talked to him.")
    coref_er = nlp.add_pipe("coref_er", last=True)
    with pytest.raises(ValueError):
        coref_er(doc)


def test_coref_er_with_POS(nlp):
    words = ["The", "police", "woman", "talked", "to", "him", "."]
    pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
    doc = Doc(nlp.vocab, words=words, pos=pos)
    coref_er = nlp.add_pipe("coref_er", last=True)
    coref_er(doc)
    assert len(doc.spans[coref_er.span_mentions]) == 1
    mention = doc.spans[coref_er.span_mentions][0]
    assert (mention.text, mention.start, mention.end) == ("him", 5, 6)


def test_coref_er_custom_POS(nlp):
    words = ["The", "police", "woman", "talked", "to", "him", "."]
    pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
    doc = Doc(nlp.vocab, words=words, pos=pos)
    config = {"matcher_key": "POS", "matcher_values": ["NOUN"]}
    coref_er = nlp.add_pipe("coref_er", last=True, config=config)
    coref_er(doc)
    assert len(doc.spans[coref_er.span_mentions]) == 1
    mention = doc.spans[coref_er.span_mentions][0]
    assert (mention.text, mention.start, mention.end) == ("police woman", 1, 3)


def test_coref_clusters(nlp, examples):
    coref_er = nlp.add_pipe("coref_er", last=True)
    coref = nlp.add_pipe("coref", last=True)
    coref.initialize(lambda: examples)
    words = ["Laura", "walked", "her", "dog", "."]
    pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
    doc = Doc(nlp.vocab, words=words, pos=pos)
    coref_er(doc)
    coref(doc)
    assert len(doc.spans[coref_er.span_mentions]) > 0
    found_clusters = 0
    for name, spans in doc.spans.items():
        if name.startswith(coref.span_cluster_prefix):
            found_clusters += 1
    assert found_clusters > 0


def test_coref_er_score(nlp, examples):
    config = {"matcher_key": "POS", "matcher_values": []}
    coref_er = nlp.add_pipe("coref_er", last=True, config=config)
    coref = nlp.add_pipe("coref", last=True)
    coref.initialize(lambda: examples)
    mentions_key = coref_er.span_mentions
    cluster_prefix_key = coref.span_cluster_prefix
    matcher = PhraseMatcher(nlp.vocab)
    terms_1 = ["Laura", "her", "she"]
    terms_2 = ["it", "this SMS"]
    matcher.add("A", [nlp.make_doc(text) for text in terms_1])
    matcher.add("B", [nlp.make_doc(text) for text in terms_2])
    for eg in examples:
        pred = eg.predicted
        matches = matcher(pred, as_spans=True)
        pred.set_ents(matches)
        coref_er(pred)
        coref(pred)
        eg.predicted = pred
        # TODO: if #7209 is merged, experiment with 'include_label'
        scores = coref_er.score([eg])
        assert f"{mentions_key}_f" in scores
        scores = coref.score([eg])
        assert f"{cluster_prefix_key}_f" in scores


def test_coref_serialization(nlp):
    # Test that the coref component can be serialized
    config_er = {"matcher_key": "TAG", "matcher_values": ["NN"]}
    nlp.add_pipe("coref_er", last=True, config=config_er)
    nlp.add_pipe("coref", last=True)
    assert "coref_er" in nlp.pipe_names
    assert "coref" in nlp.pipe_names

    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
        assert "coref_er" in nlp2.pipe_names
        assert "coref" in nlp2.pipe_names
        coref_er_2 = nlp2.get_pipe("coref_er")
        assert coref_er_2.matcher_key == "TAG"

@@ -13,7 +13,7 @@ if TYPE_CHECKING:

 # Why inherit from UserDict instead of dict here?
 # Well, the 'dict' class doesn't necessarily delegate everything nicely,
-# for performance reasons. The UserDict is slower by better behaved.
+# for performance reasons. The UserDict is slower but better behaved.
 # See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/
 class SpanGroups(UserDict):
     """A dict-like proxy held by the Doc, to control access to span groups."""