mirror of
https://github.com/explosion/spaCy.git
synced 2025-10-21 19:24:39 +03:00
* initial coref_er pipe * matcher more flexible * base coref component without actual model * initial setup of coref_er.score * rename to include_label * preliminary score_clusters method * apply scoring in coref component * IO fix * return None loss for now * rename to CoreferenceResolver * some preliminary unit tests * use registry as callable
228 lines
8.5 KiB
Python
228 lines
8.5 KiB
Python
from typing import Optional, Union, Iterable, Callable, List, Dict, Any
|
|
from pathlib import Path
|
|
import srsly
|
|
|
|
from .pipe import Pipe
|
|
from ..scorer import Scorer
|
|
from ..training import Example
|
|
from ..language import Language
|
|
from ..tokens import Doc, Span, SpanGroup
|
|
from ..matcher import Matcher
|
|
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
|
|
|
|
|
|
# Default key in `doc.spans` under which the coref mentions are stored.
DEFAULT_MENTIONS = "coref_mentions"
|
|
# Default token field the matcher patterns work on (e.g. "POS" or "TAG").
DEFAULT_MATCHER_KEY = "POS"
|
|
# Default values of the matcher field whose token sequences are treated as
# plausible coref mentions (proper nouns and pronouns).
DEFAULT_MATCHER_VALUES = ["PROPN", "PRON"]
|
|
|
|
|
|
@Language.factory(
    "coref_er",
    assigns=["doc.spans"],
    requires=["doc.ents", "token.ent_iob", "token.ent_type", "token.pos"],
    default_config={
        "span_mentions": DEFAULT_MENTIONS,
        "matcher_key": DEFAULT_MATCHER_KEY,
        "matcher_values": DEFAULT_MATCHER_VALUES,
    },
    default_score_weights={
        "coref_mentions_f": None,
        "coref_mentions_p": None,
        # The mentions data needs to be consistently annotated for precision
        # rates to make sense, so only recall is given a weight here.
        "coref_mentions_r": 1.0,
    },
)
def make_coref_er(
    nlp: Language,
    name: str,
    span_mentions: str,
    matcher_key: str,
    matcher_values: List[str],
) -> "CorefEntityRecognizer":
    """Factory for the "coref_er" pipeline component.

    nlp (Language): The shared nlp object.
    name (str): Instance name of the pipeline component.
    span_mentions (str): Key in doc.spans to store the coref mentions in.
    matcher_key (str): Field for the matcher to work on (e.g. "POS" or "TAG").
    matcher_values (List[str]): Values of that field whose token sequences
        are matched as plausible coref mentions.
    RETURNS (CorefEntityRecognizer): The configured component.
    """
    return CorefEntityRecognizer(
        nlp,
        name,
        span_mentions=span_mentions,
        matcher_key=matcher_key,
        matcher_values=matcher_values,
    )
|
|
|
|
|
|
class CorefEntityRecognizer(Pipe):
    """Rule-based component that collects candidate coreference mentions.

    Candidates are gathered from three sources, in this order: the named
    entities in `doc.ents`, token sequences found by a `Matcher` built from
    `matcher_key` / `matcher_values`, and `doc.noun_chunks` when available.
    Spans whose (start, end) offsets were already collected are skipped. The
    result is stored as a `SpanGroup` under `doc.spans[span_mentions]`.

    DOCS: https://spacy.io/api/coref_er (TODO)
    USAGE: https://spacy.io/usage (TODO)
    """

    def __init__(
        self,
        nlp: Language,
        name: str = "coref_er",
        *,
        span_mentions: str,
        matcher_key: str,
        matcher_values: List[str],
    ) -> None:
        """Initialize the entity recognizer for coreference mentions.

        nlp (Language): The shared nlp object.
        name (str): Instance name of the current pipeline component. Typically
            passed in automatically from the factory when the component is
            added.
        span_mentions (str): Key in doc.spans to store the coref mentions in.
        matcher_key (str): Field for the matcher to work on (e.g. "POS" or "TAG")
        matcher_values (List[str]): Values to match token sequences as
            plausible coref mentions

        DOCS: https://spacy.io/api/coref_er#init (TODO)
        """
        self.nlp = nlp
        self.name = name
        self.span_mentions = span_mentions
        self.matcher_key = matcher_key
        # Copy so that later changes to the caller's list (or to the
        # module-level default list) cannot leak into this component.
        self.matcher_values = list(matcher_values)
        self._build_matcher()

    def _build_matcher(self) -> None:
        """(Re)create `self.matcher` from `matcher_key` and `matcher_values`."""
        # TODO: allow to specify any matcher patterns instead?
        matcher = Matcher(self.nlp.vocab)
        for value in self.matcher_values:
            # One pattern per value: one or more consecutive tokens whose
            # `matcher_key` field equals `value`.
            matcher.add(
                f"{value}_SEQ",
                [[{self.matcher_key: value, "OP": "+"}]],
                greedy="LONGEST",
            )
        self.matcher = matcher

    def _get_cfg(self) -> Dict[str, Any]:
        """Return the settings that are written by to_bytes / to_disk."""
        return {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": list(self.matcher_values),
        }

    def _set_cfg(self, cfg: Dict[str, Any]) -> None:
        """Apply loaded settings (missing keys fall back to the module
        defaults) and rebuild the matcher so it reflects them."""
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = list(cfg.get("matcher_values", DEFAULT_MATCHER_VALUES))
        self._build_matcher()

    @staticmethod
    def _string_offset(span: Span):
        """Return a "start-end" string key identifying the span's token offsets."""
        return f"{span.start}-{span.end}"

    def _add_unseen(self, spans: List[Span], offsets: set, candidates: Iterable[Span]) -> None:
        """Append every candidate whose offsets are not yet in `offsets` to
        `spans`, recording its offsets as seen. Both containers are modified
        in place."""
        for candidate in candidates:
            key = self._string_offset(candidate)
            if key not in offsets:
                offsets.add(key)
                spans.append(candidate)

    def __call__(self, doc: Doc) -> Doc:
        """Find relevant coref mentions in the document and add them
        to the doc's relevant span container.

        doc (Doc): The Doc object in the pipeline.
        RETURNS (Doc): The Doc with added entities, if available.

        DOCS: https://spacy.io/api/coref_er#call (TODO)
        """
        error_handler = self.get_error_handler()
        try:
            spans: List[Span] = []
            offsets: set = set()

            # Add NER
            self._add_unseen(spans, offsets, doc.ents)

            # pronouns and proper nouns
            try:
                matches = self.matcher(doc, as_spans=True)
            except ValueError as err:
                raise ValueError(
                    f"Could not run the matcher for 'coref_er'. If {self.matcher_key} tags "
                    "are not available, change the 'matcher_key' in the config, "
                    "or set matcher_values to an empty list."
                ) from err
            self._add_unseen(spans, offsets, matches)

            # noun_chunks - only if implemented and parsing information is available
            try:
                # Materialize once so the chunks are not computed twice.
                noun_chunks = list(doc.noun_chunks)
            except (NotImplementedError, ValueError):
                noun_chunks = []
            self._add_unseen(spans, offsets, noun_chunks)

            self.set_annotations(doc, spans)
            return doc
        except Exception as e:
            # If the handler does not re-raise, the call falls through and
            # returns None.
            error_handler(self.name, self, [doc], e)

    def set_annotations(self, doc, spans):
        """Modify the document in place by storing the spans as a SpanGroup
        under `doc.spans[self.span_mentions]`.

        doc (Doc): The document to modify.
        spans (Iterable[Span]): The mention spans to store.
        RAISES (ValueError): If the key already exists in `doc.spans`.
        """
        # Check first, so no SpanGroup is created when the key is taken.
        if self.span_mentions in doc.spans:
            raise ValueError(
                f"Couldn't store the results of {self.name}, as the key "
                f"{self.span_mentions} already exists in 'doc.spans'."
            )
        doc.spans[self.span_mentions] = SpanGroup(
            doc, name=self.span_mentions, spans=spans
        )

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ):
        """Initialize the pipe for training. Currently a no-op.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref_er#initialize (TODO)
        """

    def from_bytes(
        self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a bytestring.

        bytes_data (bytes): The bytestring to load.
        exclude (Iterable[str]): Accepted for API compatibility; not used.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_bytes
        """
        self._set_cfg(srsly.msgpack_loads(bytes_data))
        return self

    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
        """Serialize the coreference entity recognizer to a bytestring.

        exclude (Iterable[str]): Accepted for API compatibility; not used.
        RETURNS (bytes): The serialized component.

        DOCS: https://spacy.io/api/coref_er#to_bytes (TODO)
        """
        # Write the same settings as to_disk, so that from_bytes restores the
        # matcher configuration and not only the span key.
        return srsly.msgpack_dumps(self._get_cfg())

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from disk.

        path (str / Path): The path to load from; the settings are read as
            JSON by the "cfg" deserializer.
        exclude (Iterable[str]): Accepted for API compatibility; not used.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {}
        deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
        from_disk(path, deserializers_cfg, {})
        self._set_cfg(cfg)
        return self

    def to_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> None:
        """Save the coreference entity recognizer to disk.

        path (str / Path): The path to save to; the settings are written as
            JSON by the "cfg" serializer.
        exclude (Iterable[str]): Accepted for API compatibility; not used.

        DOCS: https://spacy.io/api/coref_er#to_disk (TODO)
        """
        path = ensure_path(path)
        cfg = self._get_cfg()
        serializers = {"cfg": lambda p: srsly.write_json(p, cfg)}
        to_disk(path, serializers, {})

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        kwargs: Passed on to Scorer.score_spans. "getter", "attr",
            "include_label" and "allow_overlap" are only set here if the
            caller did not provide them.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

        DOCS: https://spacy.io/api/coref_er#score (TODO)
        """

        def mentions_getter(doc, span_key):
            return doc.spans[span_key]

        # This will work better once PR 7209 is merged
        kwargs.setdefault("getter", mentions_getter)
        kwargs.setdefault("attr", self.span_mentions)
        kwargs.setdefault("include_label", False)
        kwargs.setdefault("allow_overlap", True)
        return Scorer.score_spans(examples, **kwargs)
|