mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	* initial coref_er pipe
	* matcher more flexible
	* base coref component without actual model
	* initial setup of coref_er.score
	* rename to include_label
	* preliminary score_clusters method
	* apply scoring in coref component
	* IO fix
	* return None loss for now
	* rename to CoreferenceResolver
	* some preliminary unit tests
	* use registry as callable
		
			
				
	
	
		
			228 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			228 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Optional, Union, Iterable, Callable, List, Dict, Any
 | 
						|
from pathlib import Path
 | 
						|
import srsly
 | 
						|
 | 
						|
from .pipe import Pipe
 | 
						|
from ..scorer import Scorer
 | 
						|
from ..training import Example
 | 
						|
from ..language import Language
 | 
						|
from ..tokens import Doc, Span, SpanGroup
 | 
						|
from ..matcher import Matcher
 | 
						|
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
 | 
						|
 | 
						|
 | 
						|
# Default key under which the collected mentions are stored in doc.spans.
DEFAULT_MENTIONS = "coref_mentions"
# Default token attribute the matcher patterns operate on.
DEFAULT_MATCHER_KEY = "POS"
# Default attribute values whose token sequences count as candidate mentions.
DEFAULT_MATCHER_VALUES = ["PROPN", "PRON"]
 | 
						|
 | 
						|
 | 
						|
@Language.factory(
    "coref_er",
    assigns=["doc.spans"],
    requires=["doc.ents", "token.ent_iob", "token.ent_type", "token.pos"],
    default_config={
        "span_mentions": DEFAULT_MENTIONS,
        "matcher_key": DEFAULT_MATCHER_KEY,
        "matcher_values": DEFAULT_MATCHER_VALUES,
    },
    default_score_weights={
        "coref_mentions_f": None,
        "coref_mentions_p": None,
        # the mentions data needs to be consistently annotated for
        # precision rates to make sense
        "coref_mentions_r": 1.0,
    },
)
def make_coref_er(
    nlp: Language,
    name: str,
    span_mentions: str,
    matcher_key: str,
    matcher_values: List[str],
):
    """Construct a CorefEntityRecognizer component.

    nlp (Language): The shared nlp object.
    name (str): Instance name of the component in the pipeline.
    span_mentions (str): Key in doc.spans to store the coref mentions in.
    matcher_key (str): Token attribute for the matcher patterns (e.g. "POS").
    matcher_values (List[str]): Attribute values whose token sequences count
        as plausible coref mentions.
    """
    return CorefEntityRecognizer(
        nlp,
        name,
        span_mentions=span_mentions,
        matcher_key=matcher_key,
        matcher_values=matcher_values,
    )
 | 
						|
 | 
						|
 | 
						|
class CorefEntityRecognizer(Pipe):
    """Rule-based component that collects plausible coreference mentions -
    named entities, pronoun/proper-noun token sequences and (when available)
    noun chunks - into a dedicated span group on the Doc. TODO.

    DOCS: https://spacy.io/api/coref_er (TODO)
    USAGE: https://spacy.io/usage (TODO)
    """

    def __init__(
        self,
        nlp: Language,
        name: str = "coref_er",
        *,
        span_mentions: str,
        matcher_key: str,
        matcher_values: List[str],
    ) -> None:
        """Initialize the entity recognizer for coreference mentions. TODO

        nlp (Language): The shared nlp object.
        name (str): Instance name of the current pipeline component. Typically
            passed in automatically from the factory when the component is
            added.
        span_mentions (str): Key in doc.spans to store the coref mentions in.
        matcher_key (str): Field for the matcher to work on (e.g. "POS" or "TAG")
        matcher_values (List[str]): Values to match token sequences as
            plausible coref mentions

        DOCS: https://spacy.io/api/coref_er#init (TODO)
        """
        self.nlp = nlp
        self.name = name
        self.span_mentions = span_mentions
        self.matcher_key = matcher_key
        self.matcher_values = matcher_values
        self._build_matcher()

    def _build_matcher(self) -> None:
        """(Re)create the token matcher from the current configuration.

        Called from __init__ and after deserialization so that restored
        matcher_key/matcher_values settings actually take effect.
        """
        # TODO: allow to specify any matcher patterns instead?
        self.matcher = Matcher(self.nlp.vocab)
        for value in self.matcher_values:
            # Greedily match the longest run of tokens carrying this value.
            self.matcher.add(
                f"{value}_SEQ", [[{self.matcher_key: value, "OP": "+"}]], greedy="LONGEST"
            )

    @staticmethod
    def _string_offset(span: Span) -> str:
        """Return a "start-end" key used to deduplicate overlapping candidates."""
        return f"{span.start}-{span.end}"

    def __call__(self, doc: Doc) -> Doc:
        """Find relevant coref mentions in the document and add them
        to the doc's relevant span container.

        doc (Doc): The Doc object in the pipeline.
        RETURNS (Doc): The Doc with added entities, if available.

        DOCS: https://spacy.io/api/coref_er#call (TODO)
        """
        error_handler = self.get_error_handler()
        try:
            # Start from the named entities already on the doc.
            spans = list(doc.ents)
            offsets = set()
            offsets.update([self._string_offset(e) for e in doc.ents])

            # Pronoun / proper-noun sequences (per matcher_values config).
            try:
                matches = self.matcher(doc, as_spans=True)
            except ValueError:
                raise ValueError(f"Could not run the matcher for 'coref_er'. If {self.matcher_key} tags "
                                 "are not available, change the 'matcher_key' in the config, "
                                 "or set matcher_values to an empty list.")
            spans.extend([m for m in matches if self._string_offset(m) not in offsets])
            offsets.update([self._string_offset(m) for m in matches])

            # noun_chunks - only if implemented and parsing information is available
            try:
                spans.extend(
                    [nc for nc in doc.noun_chunks if self._string_offset(nc) not in offsets]
                )
                offsets.update([self._string_offset(nc) for nc in doc.noun_chunks])
            except (NotImplementedError, ValueError):
                pass

            self.set_annotations(doc, spans)
            return doc
        except Exception as e:
            error_handler(self.name, self, [doc], e)

    def set_annotations(self, doc: Doc, spans) -> None:
        """Modify the document in place, storing the collected mentions
        under self.span_mentions. Raises if the key already exists so that
        earlier results are never silently overwritten.
        """
        if self.span_mentions in doc.spans:
            raise ValueError(f"Couldn't store the results of {self.name}, as the key "
                             f"{self.span_mentions} already exists in 'doc.spans'.")
        doc.spans[self.span_mentions] = SpanGroup(
            doc, name=self.span_mentions, spans=spans
        )

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ):
        """Initialize the pipe for training.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref_er#initialize (TODO)
        """
        # Rule-based component: there is nothing to learn from examples.
        pass

    def from_bytes(
        self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a bytestring.

        bytes_data (bytes): The bytestring to load.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_bytes
        """
        cfg = srsly.msgpack_loads(bytes_data)
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        # Rebuild the matcher so the restored settings are actually used.
        self._build_matcher()
        return self

    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
        """Serialize the coreference entity recognizer to a bytestring.

        RETURNS (bytes): The serialized component.

        DOCS: https://spacy.io/api/coref_er#to_bytes (TODO)
        """
        # Serialize the full config (matcher_key/matcher_values included) so
        # a to_bytes/from_bytes round-trip preserves all settings, consistent
        # with to_disk/from_disk.
        serial = {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": self.matcher_values,
        }
        return srsly.msgpack_dumps(serial)

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a file.

        path (str / Path): The JSONL file to load.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {}
        deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
        from_disk(path, deserializers_cfg, {})
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        # Rebuild the matcher so the restored settings are actually used.
        self._build_matcher()
        return self

    def to_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> None:
        """Save the coreference entity recognizer to a directory.

        path (str / Path): The JSONL file to save.

        DOCS: https://spacy.io/api/coref_er#to_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": self.matcher_values,
        }
        serializers = {"cfg": lambda p: srsly.write_json(p, cfg)}
        to_disk(path, serializers, {})

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

        DOCS: https://spacy.io/api/coref_er#score (TODO)
        """
        def mentions_getter(doc, span_key):
            return doc.spans[span_key]

        # This will work better once PR 7209 is merged
        kwargs.setdefault("getter", mentions_getter)
        kwargs.setdefault("attr", self.span_mentions)
        kwargs.setdefault("include_label", False)
        kwargs.setdefault("allow_overlap", True)
        return Scorer.score_spans(examples, **kwargs)
 |