mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Remove coref_er.py
The intent of this was that it would be a component pipeline that used entities as input, but that's now covered by the get_mentions function as a pipeline arg.
This commit is contained in:
parent
f6652c9252
commit
0942a0b51b
|
@ -1,227 +0,0 @@
|
|||
from typing import Optional, Union, Iterable, Callable, List, Dict, Any
|
||||
from pathlib import Path
|
||||
import srsly
|
||||
|
||||
from .pipe import Pipe
|
||||
from ..scorer import Scorer
|
||||
from ..training import Example
|
||||
from ..language import Language
|
||||
from ..tokens import Doc, Span, SpanGroup
|
||||
from ..matcher import Matcher
|
||||
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
|
||||
|
||||
|
||||
# Default key in doc.spans under which candidate coref mentions are stored.
DEFAULT_MENTIONS = "coref_mentions"
# Default token attribute the Matcher operates on (e.g. "POS" or "TAG").
DEFAULT_MATCHER_KEY = "POS"
# Default attribute values whose token sequences count as plausible mentions.
DEFAULT_MATCHER_VALUES = ["PROPN", "PRON"]
|
||||
|
||||
|
||||
@Language.factory(
    "coref_er",
    # NOTE: was `f"doc.spans"` — an f-string with no placeholders; plain literal.
    assigns=["doc.spans"],
    requires=["doc.ents", "token.ent_iob", "token.ent_type", "token.pos"],
    default_config={
        "span_mentions": DEFAULT_MENTIONS,
        "matcher_key": DEFAULT_MATCHER_KEY,
        "matcher_values": DEFAULT_MATCHER_VALUES,
    },
    default_score_weights={
        "coref_mentions_f": None,
        "coref_mentions_p": None,
        # the mentions data needs to be consistently annotated for precision
        # rates to make sense, so only recall is weighted
        "coref_mentions_r": 1.0,
    },
)
def make_coref_er(
    nlp: Language,
    name: str,
    span_mentions: str,
    matcher_key: str,
    matcher_values: List[str],
) -> "CorefEntityRecognizer":
    """Factory entry point: build a CorefEntityRecognizer from config.

    nlp (Language): The shared nlp object.
    name (str): Instance name of the component in the pipeline.
    span_mentions (str): Key in doc.spans to store the coref mentions in.
    matcher_key (str): Token attribute for the matcher (e.g. "POS" or "TAG").
    matcher_values (List[str]): Attribute values treated as plausible mentions.
    RETURNS (CorefEntityRecognizer): The constructed component.
    """
    return CorefEntityRecognizer(
        nlp,
        name,
        span_mentions=span_mentions,
        matcher_key=matcher_key,
        matcher_values=matcher_values,
    )
|
||||
|
||||
|
||||
class CorefEntityRecognizer(Pipe):
    """Rule-based recognizer of plausible coreference mentions.

    Collects spans from three sources: named entities (doc.ents), token
    sequences matched on `matcher_key` values (e.g. pronouns/proper nouns),
    and noun chunks where the parser provides them. Deduplicated spans are
    stored as a SpanGroup under `doc.spans[span_mentions]`.

    DOCS: https://spacy.io/api/coref_er (TODO)
    USAGE: https://spacy.io/usage (TODO)
    """

    def __init__(
        self,
        nlp: Language,
        name: str = "coref_er",
        *,
        span_mentions: str,
        matcher_key: str,
        matcher_values: List[str],
    ) -> None:
        """Initialize the entity recognizer for coreference mentions.

        nlp (Language): The shared nlp object.
        name (str): Instance name of the current pipeline component. Typically
            passed in automatically from the factory when the component is
            added.
        span_mentions (str): Key in doc.spans to store the coref mentions in.
        matcher_key (str): Field for the matcher to work on (e.g. "POS" or "TAG")
        matcher_values (List[str]): Values to match token sequences as
            plausible coref mentions

        DOCS: https://spacy.io/api/coref_er#init (TODO)
        """
        self.nlp = nlp
        self.name = name
        self.span_mentions = span_mentions
        self.matcher_key = matcher_key
        self.matcher_values = matcher_values
        self._build_matcher()

    def _build_matcher(self) -> None:
        """(Re)create self.matcher from matcher_key/matcher_values.

        Extracted so that deserialization (from_bytes/from_disk) can rebuild
        the patterns after loading a new config — previously the matcher kept
        its construction-time patterns and loaded configs had no effect.
        """
        # TODO: allow to specify any matcher patterns instead?
        self.matcher = Matcher(self.nlp.vocab)
        for value in self.matcher_values:
            self.matcher.add(
                f"{value}_SEQ", [[{self.matcher_key: value, "OP": "+"}]], greedy="LONGEST"
            )

    @staticmethod
    def _string_offset(span: Span) -> str:
        """Return a "start-end" key used to deduplicate spans by offsets."""
        return f"{span.start}-{span.end}"

    def __call__(self, doc: Doc) -> Doc:
        """Find relevant coref mentions in the document and add them
        to the doc's relevant span container.

        doc (Doc): The Doc object in the pipeline.
        RETURNS (Doc): The Doc with added entities, if available.

        DOCS: https://spacy.io/api/coref_er#call (TODO)
        """
        error_handler = self.get_error_handler()
        try:
            # Add NER entities first; track seen offsets for deduplication
            spans = list(doc.ents)
            offsets = set()
            offsets.update([self._string_offset(e) for e in doc.ents])

            # pronouns and proper nouns
            try:
                matches = self.matcher(doc, as_spans=True)
            except ValueError as err:
                # chain the original error so the root cause stays visible
                raise ValueError(f"Could not run the matcher for 'coref_er'. If {self.matcher_key} tags "
                                 "are not available, change the 'matcher_key' in the config, "
                                 "or set matcher_values to an empty list.") from err
            spans.extend([m for m in matches if self._string_offset(m) not in offsets])
            offsets.update([self._string_offset(m) for m in matches])

            # noun_chunks - only if implemented and parsing information is available
            try:
                spans.extend(
                    [nc for nc in doc.noun_chunks if self._string_offset(nc) not in offsets]
                )
                offsets.update([self._string_offset(nc) for nc in doc.noun_chunks])
            except (NotImplementedError, ValueError):
                pass

            self.set_annotations(doc, spans)
            return doc
        except Exception as e:
            # delegate to the pipeline's configured error handler
            error_handler(self.name, self, [doc], e)

    def set_annotations(self, doc: Doc, spans: Iterable[Span]) -> None:
        """Modify the document in place by storing the mentions as a
        SpanGroup; refuses to overwrite an existing key in doc.spans."""
        group = SpanGroup(doc, name=self.span_mentions, spans=spans)
        if self.span_mentions in doc.spans:
            raise ValueError(f"Couldn't store the results of {self.name}, as the key "
                             f"{self.span_mentions} already exists in 'doc.spans'.")
        doc.spans[self.span_mentions] = group

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ):
        """Initialize the pipe for training.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref_er#initialize (TODO)
        """
        # rule-based component: nothing to learn
        pass

    def from_bytes(
        self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a bytestring.

        bytes_data (bytes): The bytestring to load.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_bytes
        """
        cfg = srsly.msgpack_loads(bytes_data)
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        # rebuild patterns so the loaded config actually takes effect
        self._build_matcher()
        return self

    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
        """Serialize the coreference entity recognizer to a bytestring.

        RETURNS (bytes): The serialized component.

        DOCS: https://spacy.io/api/coref_er#to_bytes (TODO)
        """
        # serialize all three config fields, mirroring from_bytes/to_disk
        # (previously only span_mentions was written, breaking round-trips)
        serial = {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": self.matcher_values,
        }
        return srsly.msgpack_dumps(serial)

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a file.

        path (str / Path): The JSONL file to load.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer .

        DOCS: https://spacy.io/api/coref_er#from_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {}
        deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
        from_disk(path, deserializers_cfg, {})
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        # rebuild patterns so the loaded config actually takes effect
        self._build_matcher()
        return self

    def to_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> None:
        """Save the coreference entity recognizer to a directory.

        path (str / Path): The JSONL file to save.

        DOCS: https://spacy.io/api/coref_er#to_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": self.matcher_values,
        }
        serializers = {"cfg": lambda p: srsly.write_json(p, cfg)}
        to_disk(path, serializers, {})

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_coref.

        DOCS: https://spacy.io/api/coref_er#score (TODO)
        """
        def mentions_getter(doc, span_key):
            return doc.spans[span_key]
        # This will work better once PR 7209 is merged
        kwargs.setdefault("getter", mentions_getter)
        kwargs.setdefault("attr", self.span_mentions)
        kwargs.setdefault("include_label", False)
        kwargs.setdefault("allow_overlap", True)
        return Scorer.score_spans(examples, **kwargs)
|
Loading…
Reference in New Issue
Block a user