Native coref component (#7243)

* initial coref_er pipe
* matcher more flexible
* base coref component without actual model
* initial setup of coref_er.score
* rename to include_label
* preliminary score_clusters method
* apply scoring in coref component
* IO fix
* return None loss for now
* rename to CoreferenceResolver
* some preliminary unit tests
* use registry as callable

parent dd99872bb0
commit e0c45c669a
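For orientation, a minimal sketch of how the two new components compose, mirroring the unit tests added in this commit (the Doc is given POS tags directly, since the default "coref_er" patterns match on POS, and "coref" is still a stub that turns each candidate mention into its own cluster):

    from spacy.lang.en import English
    from spacy.tokens import Doc

    nlp = English()
    coref_er = nlp.add_pipe("coref_er")  # proposes mentions into doc.spans["coref_mentions"]
    coref = nlp.add_pipe("coref")        # groups mentions under "coref_clusters*" keys
    # the trainable component must be initialized with examples first,
    # e.g. coref.initialize(lambda: examples)
    words = ["Laura", "walked", "her", "dog", "."]
    pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
    doc = coref(coref_er(Doc(nlp.vocab, words=words, pos=pos)))
    print([(name, list(group)) for name, group in doc.spans.items()])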
spacy/ml/models/__init__.py
@@ -1,3 +1,4 @@
+from .coref import *
 from .entity_linker import *  # noqa
 from .multi_task import *  # noqa
 from .parser import *  # noqa

spacy/ml/models/coref.py (new file, 18 lines)
from typing import List
from thinc.api import Model
from thinc.types import Floats2d

from ...util import registry
from ...tokens import Doc


@registry.architectures("spacy.Coref.v0")
def build_coref_model(
    tok2vec: Model[List[Doc], List[Floats2d]]
) -> Model:
    """Build a coref resolution model, using a provided token-to-vector component.
    TODO.

    tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
    """
    return tok2vec
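The architecture does nothing yet beyond passing the tok2vec layer through, but it is wired into the registry like any other spaCy model. A quick way to confirm the registration resolves (assuming spaCy with this commit installed):

    from spacy.util import registry

    # looks up the factory registered by the decorator above
    build_coref_model = registry.architectures.get("spacy.Coref.v0")
    # the callable can then be composed with any tok2vec layer from the registry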
spacy/pipeline/__init__.py
@@ -1,4 +1,6 @@
 from .attributeruler import AttributeRuler
+from .coref import CoreferenceResolver
+from .coref_er import CorefEntityRecognizer
 from .dep_parser import DependencyParser
 from .entity_linker import EntityLinker
 from .ner import EntityRecognizer

spacy/pipeline/coref.py (new file, 288 lines)
from typing import Iterable, Tuple, Optional, Dict, Callable, Any

from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
from itertools import islice

from .trainable_pipe import TrainablePipe
from .coref_er import DEFAULT_MENTIONS
from ..language import Language
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from ..tokens import Doc
from ..vocab import Vocab


default_config = """
[model]
@architectures = "spacy.Coref.v0"

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false

[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
"""
DEFAULT_MODEL = Config().from_str(default_config)["model"]

DEFAULT_CLUSTERS_PREFIX = "coref_clusters"


@Language.factory(
    "coref",
    assigns=["doc.spans"],
    requires=["doc.spans"],
    default_config={
        "model": DEFAULT_MODEL,
        "span_mentions": DEFAULT_MENTIONS,
        "span_cluster_prefix": DEFAULT_CLUSTERS_PREFIX,
    },
    default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
)
def make_coref(
    nlp: Language,
    name: str,
    model,
    span_mentions: str,
    span_cluster_prefix: str,
) -> "CoreferenceResolver":
    """Create a CoreferenceResolver component. TODO

    model (Model[List[Doc], List[Floats2d]]): A model instance that predicts ...
    """
    return CoreferenceResolver(
        nlp.vocab,
        model,
        name,
        span_mentions=span_mentions,
        span_cluster_prefix=span_cluster_prefix,
    )


class CoreferenceResolver(TrainablePipe):
    """Pipeline component for coreference resolution.

    DOCS: https://spacy.io/api/coref (TODO)
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "coref",
        *,
        span_mentions: str,
        span_cluster_prefix: str,
    ) -> None:
        """Initialize a coreference resolution component.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        span_mentions (str): Key in doc.spans where the candidate coref mentions
            are stored.
        span_cluster_prefix (str): Prefix for the keys in doc.spans to store the
            coref clusters in.

        DOCS: https://spacy.io/api/coref#init (TODO)
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.span_mentions = span_mentions
        self.span_cluster_prefix = span_cluster_prefix
        self._rehearsal_model = None
        self.cfg = {}

    def predict(self, docs: Iterable[Doc]):
        """Apply the pipeline's model to a batch of docs, without modifying them.
        TODO: write actual algorithm

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: The model's prediction for each document.

        DOCS: https://spacy.io/api/coref#predict (TODO)
        """
        clusters_by_doc = []
        for doc in docs:
            clusters = []
            for span in doc.spans[self.span_mentions]:
                clusters.append([span])
            clusters_by_doc.append(clusters)
        return clusters_by_doc

    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
        """Modify a batch of Doc objects, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        clusters_by_doc: The span clusters, produced by CoreferenceResolver.predict.

        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
        """
        if len(docs) != len(clusters_by_doc):
            raise ValueError("Found coref clusters incompatible with the "
                             "documents provided to the 'coref' component. "
                             "This is likely a bug in spaCy.")
        for doc, clusters in zip(docs, clusters_by_doc):
            index = 0
            for cluster in clusters:
                key = self.span_cluster_prefix + str(index)
                if key in doc.spans:
                    raise ValueError(f"Couldn't store the results of {self.name}, as the key "
                                     f"{key} already exists in 'doc.spans'.")
                doc.spans[key] = cluster
                index += 1

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/coref#update (TODO)
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        validate_examples(examples, "CoreferenceResolver.update")
        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
            # Handle cases where there are no tokens in any docs.
            return losses
        set_dropout_rate(self.model, drop)
        scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
        # TODO below
        # loss, d_scores = self.get_loss(examples, scores)
        # bp_scores(d_scores)
        if sgd is not None:
            self.finish_update(sgd)
        # losses[self.name] += loss
        return losses

    def rehearse(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
        teach the current model to make predictions similar to an initial model,
        to try to address the "catastrophic forgetting" problem. This feature is
        experimental.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/coref#rehearse (TODO)
        """
        if losses is not None:
            losses.setdefault(self.name, 0.0)
        if self._rehearsal_model is None:
            return losses
        validate_examples(examples, "CoreferenceResolver.rehearse")
        docs = [eg.predicted for eg in examples]
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.
            return losses
        set_dropout_rate(self.model, drop)
        scores, bp_scores = self.model.begin_update(docs)
        # TODO below
        target = self._rehearsal_model(examples)
        gradient = scores - target
        bp_scores(gradient)
        if sgd is not None:
            self.finish_update(sgd)
        if losses is not None:
            losses[self.name] += (gradient ** 2).sum()
        return losses

    def add_label(self, label: str) -> int:
        """Technically this method should be implemented from TrainablePipe,
        but it is not relevant for the coref component.
        """
        raise NotImplementedError(
            Errors.E931.format(
                parent="CoreferenceResolver", method="add_label", name=self.name
            )
        )

    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples (Iterable[Example]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/coref#get_loss (TODO)
        """
        validate_examples(examples, "CoreferenceResolver.get_loss")
        # TODO
        return None

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ) -> None:
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref#initialize (TODO)
        """
        validate_get_examples(get_examples, "CoreferenceResolver.initialize")
        subbatch = list(islice(get_examples(), 10))
        doc_sample = [eg.reference for eg in subbatch]
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=doc_sample)

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_clusters.

        DOCS: https://spacy.io/api/coref#score (TODO)
        """
        def clusters_getter(doc, span_key):
            return [spans for name, spans in doc.spans.items() if name.startswith(span_key)]
        validate_examples(examples, "CoreferenceResolver.score")
        kwargs.setdefault("getter", clusters_getter)
        kwargs.setdefault("attr", self.span_cluster_prefix)
        kwargs.setdefault("include_label", False)
        return Scorer.score_clusters(examples, **kwargs)
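Given the stub predict above, every candidate mention becomes a singleton cluster, and set_annotations writes one SpanGroup per cluster. A sketch of reading the annotations back out, assuming a doc that has passed through both components:

    # "coref_clusters" is the default span_cluster_prefix
    for name, group in doc.spans.items():
        if name.startswith("coref_clusters"):
            print(name, [span.text for span in group])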
spacy/pipeline/coref_er.py (new file, 227 lines)

from typing import Optional, Union, Iterable, Callable, List, Dict, Any
from pathlib import Path
import srsly

from .pipe import Pipe
from ..scorer import Scorer
from ..training import Example
from ..language import Language
from ..tokens import Doc, Span, SpanGroup
from ..matcher import Matcher
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList


DEFAULT_MENTIONS = "coref_mentions"
DEFAULT_MATCHER_KEY = "POS"
DEFAULT_MATCHER_VALUES = ["PROPN", "PRON"]


@Language.factory(
    "coref_er",
    assigns=["doc.spans"],
    requires=["doc.ents", "token.ent_iob", "token.ent_type", "token.pos"],
    default_config={
        "span_mentions": DEFAULT_MENTIONS,
        "matcher_key": DEFAULT_MATCHER_KEY,
        "matcher_values": DEFAULT_MATCHER_VALUES,
    },
    default_score_weights={
        "coref_mentions_f": None,
        "coref_mentions_p": None,
        # the mentions data needs to be consistently annotated for precision
        # rates to make sense
        "coref_mentions_r": 1.0,
    },
)
def make_coref_er(
    nlp: Language,
    name: str,
    span_mentions: str,
    matcher_key: str,
    matcher_values: List[str],
):
    return CorefEntityRecognizer(
        nlp,
        name,
        span_mentions=span_mentions,
        matcher_key=matcher_key,
        matcher_values=matcher_values,
    )


class CorefEntityRecognizer(Pipe):
    """TODO.

    DOCS: https://spacy.io/api/coref_er (TODO)
    USAGE: https://spacy.io/usage (TODO)
    """

    def __init__(
        self,
        nlp: Language,
        name: str = "coref_er",
        *,
        span_mentions: str,
        matcher_key: str,
        matcher_values: List[str],
    ) -> None:
        """Initialize the entity recognizer for coreference mentions. TODO

        nlp (Language): The shared nlp object.
        name (str): Instance name of the current pipeline component. Typically
            passed in automatically from the factory when the component is
            added.
        span_mentions (str): Key in doc.spans to store the coref mentions in.
        matcher_key (str): Field for the matcher to work on (e.g. "POS" or "TAG").
        matcher_values (List[str]): Values to match token sequences as
            plausible coref mentions.

        DOCS: https://spacy.io/api/coref_er#init (TODO)
        """
        self.nlp = nlp
        self.name = name
        self.span_mentions = span_mentions
        self.matcher_key = matcher_key
        self.matcher_values = matcher_values
        self.matcher = Matcher(nlp.vocab)
        # TODO: allow to specify any matcher patterns instead?
        for value in matcher_values:
            self.matcher.add(
                f"{value}_SEQ", [[{matcher_key: value, "OP": "+"}]], greedy="LONGEST"
            )

    @staticmethod
    def _string_offset(span: Span):
        return f"{span.start}-{span.end}"

    def __call__(self, doc: Doc) -> Doc:
        """Find relevant coref mentions in the document and add them
        to the doc's relevant span container.

        doc (Doc): The Doc object in the pipeline.
        RETURNS (Doc): The Doc with added entities, if available.

        DOCS: https://spacy.io/api/coref_er#call (TODO)
        """
        error_handler = self.get_error_handler()
        try:
            # Add NER
            spans = list(doc.ents)
            offsets = set()
            offsets.update([self._string_offset(e) for e in doc.ents])

            # pronouns and proper nouns
            try:
                matches = self.matcher(doc, as_spans=True)
            except ValueError:
                raise ValueError(f"Could not run the matcher for 'coref_er'. If {self.matcher_key} tags "
                                 "are not available, change the 'matcher_key' in the config, "
                                 "or set matcher_values to an empty list.")
            spans.extend([m for m in matches if self._string_offset(m) not in offsets])
            offsets.update([self._string_offset(m) for m in matches])

            # noun_chunks - only if implemented and parsing information is available
            try:
                spans.extend(
                    [nc for nc in doc.noun_chunks if self._string_offset(nc) not in offsets]
                )
                offsets.update([self._string_offset(nc) for nc in doc.noun_chunks])
            except (NotImplementedError, ValueError):
                pass

            self.set_annotations(doc, spans)
            return doc
        except Exception as e:
            error_handler(self.name, self, [doc], e)

    def set_annotations(self, doc, spans):
        """Modify the document in place."""
        group = SpanGroup(doc, name=self.span_mentions, spans=spans)
        if self.span_mentions in doc.spans:
            raise ValueError(f"Couldn't store the results of {self.name}, as the key "
                             f"{self.span_mentions} already exists in 'doc.spans'.")
        doc.spans[self.span_mentions] = group

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ):
        """Initialize the pipe for training.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/coref_er#initialize (TODO)
        """
        pass

    def from_bytes(
        self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a bytestring.

        bytes_data (bytes): The bytestring to load.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_bytes
        """
        cfg = srsly.msgpack_loads(bytes_data)
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        return self

    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
        """Serialize the coreference entity recognizer to a bytestring.

        RETURNS (bytes): The serialized component.

        DOCS: https://spacy.io/api/coref_er#to_bytes (TODO)
        """
        serial = {"span_mentions": self.span_mentions}
        return srsly.msgpack_dumps(serial)

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "CorefEntityRecognizer":
        """Load the coreference entity recognizer from a directory.

        path (str / Path): The path to load from.
        RETURNS (CorefEntityRecognizer): The loaded coreference entity recognizer.

        DOCS: https://spacy.io/api/coref_er#from_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {}
        deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
        from_disk(path, deserializers_cfg, {})
        self.span_mentions = cfg.get("span_mentions", DEFAULT_MENTIONS)
        self.matcher_key = cfg.get("matcher_key", DEFAULT_MATCHER_KEY)
        self.matcher_values = cfg.get("matcher_values", DEFAULT_MATCHER_VALUES)
        return self

    def to_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> None:
        """Save the coreference entity recognizer to a directory.

        path (str / Path): The path to save to.

        DOCS: https://spacy.io/api/coref_er#to_disk (TODO)
        """
        path = ensure_path(path)
        cfg = {
            "span_mentions": self.span_mentions,
            "matcher_key": self.matcher_key,
            "matcher_values": self.matcher_values,
        }
        serializers = {"cfg": lambda p: srsly.write_json(p, cfg)}
        to_disk(path, serializers, {})

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

        DOCS: https://spacy.io/api/coref_er#score (TODO)
        """
        def mentions_getter(doc, span_key):
            return doc.spans[span_key]
        # This will work better once PR 7209 is merged
        kwargs.setdefault("getter", mentions_getter)
        kwargs.setdefault("attr", self.span_mentions)
        kwargs.setdefault("include_label", False)
        kwargs.setdefault("allow_overlap", True)
        return Scorer.score_spans(examples, **kwargs)
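The mention proposer's default patterns reduce to greedy longest token runs of the same POS value. A standalone equivalent of what __init__ builds, for experimentation (en_core_web_sm is an assumed model that provides POS tags):

    import spacy
    from spacy.matcher import Matcher

    nlp = spacy.load("en_core_web_sm")
    matcher = Matcher(nlp.vocab)
    for value in ("PROPN", "PRON"):  # DEFAULT_MATCHER_VALUES
        matcher.add(f"{value}_SEQ", [[{"POS": value, "OP": "+"}]], greedy="LONGEST")
    doc = nlp("John Smith told Laura that he was running late.")
    print(matcher(doc, as_spans=True))  # roughly: [John Smith, Laura, he]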
spacy/pipeline/entityruler.py
@@ -56,8 +56,7 @@ class EntityRuler(Pipe):
     """The EntityRuler lets you add spans to the `Doc.ents` using token-based
     rules or exact phrase matches. It can be combined with the statistical
     `EntityRecognizer` to boost accuracy, or used on its own to implement a
-    purely rule-based entity recognition system. After initialization, the
-    component is typically added to the pipeline using `nlp.add_pipe`.
+    purely rule-based entity recognition system.
 
     DOCS: https://spacy.io/api/entityruler
     USAGE: https://spacy.io/usage/rule-based-matching#entityruler
(modified pipeline module; file name not captured in this view)
@@ -1,8 +1,8 @@
-from itertools import islice
 from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
 import numpy
+from itertools import islice
 
 from .trainable_pipe import TrainablePipe
 from ..language import Language

spacy/scorer.py (117 lines changed)
@@ -341,7 +341,7 @@ class Scorer:
             for label in labels:
                 if label not in score_per_type:
                     score_per_type[label] = PRFScore()
-            # Find all predidate labels, for all and per type
+            # Find all instances, for all and per type
             gold_spans = set()
             pred_spans = set()
             for span in getter(gold_doc, attr):
@@ -373,6 +373,114 @@ class Scorer:
                 f"{attr}_per_type": None,
             }
 
+    @staticmethod
+    def score_clusters(
+        examples: Iterable[Example],
+        attr: str,
+        *,
+        getter: Callable[[Doc, str], Iterable[Iterable[Span]]] = getattr,
+        has_annotation: Optional[Callable[[Doc], bool]] = None,
+        include_label: bool = True,
+        **cfg,
+    ) -> Dict[str, Any]:
+        """Returns PRF scores for clustered spans.
+
+        examples (Iterable[Example]): Examples to score.
+        attr (str): The attribute to score.
+        getter (Callable[[Doc, str], Iterable[Iterable[Span]]]): Defaults to getattr.
+            If provided, getter(doc, attr) should return the lists of spans for the
+            individual doc.
+        has_annotation (Optional[Callable[[Doc], bool]]): Should return whether a `Doc`
+            has annotation for this `attr`. Docs without annotation are skipped for
+            scoring purposes.
+        include_label (bool): Whether or not to include label information in
+            the evaluation. If set to 'False', two spans will be considered
+            equal if their start and end match, irrespective of their label.
+
+        RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
+            the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
+
+        DOCS: https://spacy.io/api/scorer#score_clusters (TODO)
+        """
+        # Note: the current implementation just scores binary pairs on whether
+        # they are in the same cluster or not.
+        # TODO: look at different cluster/coreference scoring techniques.
+        score = PRFScore()
+        score_per_type = dict()
+        for example in examples:
+            pred_doc = example.predicted
+            gold_doc = example.reference
+            # Option to skip docs without annotation
+            if has_annotation is not None:
+                if not has_annotation(gold_doc):
+                    continue
+            # Find all labels in gold and doc
+            gold_clusters = list(getter(gold_doc, attr))
+            pred_clusters = list(getter(pred_doc, attr))
+            labels = set(
+                [span.label_ for span_list in gold_clusters for span in span_list]
+                + [span.label_ for span_list in pred_clusters for span in span_list]
+            )
+            # Set up all labels for per type scoring and prepare gold per type
+            for label in labels:
+                if label not in score_per_type:
+                    score_per_type[label] = PRFScore()
+            # Find all instances, for all and per type
+            gold_instances = set()
+            gold_per_type = {label: set() for label in labels}
+            for gold_cluster in gold_clusters:
+                for span1 in gold_cluster:
+                    for span2 in gold_cluster:
+                        # only record pairs where span1 comes before span2
+                        if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
+                            if include_label:
+                                gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
+                            else:
+                                gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
+                            gold_instances.add(gold_rel)
+                            if span1.label_ == span2.label_:
+                                gold_per_type[span1.label_].add(gold_rel)
+            pred_instances = set()
+            pred_per_type = {label: set() for label in labels}
+            for pred_cluster in pred_clusters:
+                for span1 in pred_cluster:
+                    for span2 in pred_cluster:
+                        if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
+                            if include_label:
+                                pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
+                            else:
+                                pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
+                            pred_instances.add(pred_rel)
+                            if span1.label_ == span2.label_:
+                                pred_per_type[span1.label_].add(pred_rel)
+            # Scores per label
+            if include_label:
+                for k, v in score_per_type.items():
+                    if k in pred_per_type:
+                        v.score_set(pred_per_type[k], gold_per_type[k])
+            # Score for all labels
+            score.score_set(pred_instances, gold_instances)
+        # Assemble final result
+        final_scores = {
+            f"{attr}_p": None,
+            f"{attr}_r": None,
+            f"{attr}_f": None,
+        }
+        if include_label:
+            final_scores[f"{attr}_per_type"] = None
+        if len(score) > 0:
+            final_scores[f"{attr}_p"] = score.precision
+            final_scores[f"{attr}_r"] = score.recall
+            final_scores[f"{attr}_f"] = score.fscore
+        return final_scores
+
     @staticmethod
     def score_cats(
         examples: Iterable[Example],
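The scoring reduction above, in miniature: every ordered pair of mentions inside one cluster becomes a single binary instance, and PRF is computed over the resulting pair sets. A toy version with (start, end) tuples standing in for spans:

    from itertools import combinations

    def links(cluster):
        # cluster: mention offsets; sorting yields ordered pairs,
        # matching the "span1 comes before span2" rule above
        return set(combinations(sorted(cluster), 2))

    gold = links([(0, 2), (5, 6), (9, 10)])  # 3 mentions -> 3 links
    pred = links([(0, 2), (5, 6)])           # 2 mentions -> 1 link
    tp = len(gold & pred)                    # 1
    precision = tp / len(pred)               # 1.0
    recall = tp / len(gold)                  # ~0.33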
@@ -722,12 +830,7 @@ def get_ner_prf(examples: Iterable[Example]) -> Dict[str, Any]:
             "ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
         }
     else:
-        return {
-            "ents_p": None,
-            "ents_r": None,
-            "ents_f": None,
-            "ents_per_type": None,
-        }
+        return {"ents_p": None, "ents_r": None, "ents_f": None, "ents_per_type": None}
 
 
 # The following implementation of roc_auc_score() is adapted from

spacy/tests/pipeline/test_coref.py (new file, 180 lines)
import pytest
import spacy
from spacy.matcher import PhraseMatcher
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.tokens import Doc
from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
from spacy.pipeline.coref_er import DEFAULT_MENTIONS


# fmt: off
TRAIN_DATA = [
    (
        "John Smith told Laura that he was running late and asked her whether she could pick up their kids.",
        {
            "spans": {
                DEFAULT_MENTIONS: [
                    (0, 10, "MENTION"),
                    (16, 21, "MENTION"),
                    (27, 29, "MENTION"),
                    (57, 60, "MENTION"),
                    (69, 72, "MENTION"),
                    (87, 92, "MENTION"),
                    (87, 97, "MENTION"),
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_1": [
                    (0, 10, "MENTION"),     # John
                    (27, 29, "MENTION"),
                    (87, 92, "MENTION"),    # 'their' refers to John and Laura
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_2": [
                    (16, 21, "MENTION"),     # Laura
                    (57, 60, "MENTION"),
                    (69, 72, "MENTION"),
                    (87, 92, "MENTION"),     # 'their' refers to John and Laura
                ],
            }
        },
    ),
    (
        "Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.",
        {
            "spans": {
                DEFAULT_MENTIONS: [
                    (5, 6, "MENTION"),
                    (40, 42, "MENTION"),
                    (52, 54, "MENTION"),
                    (95, 103, "MENTION"),
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_1": [
                    (5, 6, "MENTION"),      # I
                    (40, 42, "MENTION"),
                ],
                f"{DEFAULT_CLUSTERS_PREFIX}_2": [
                    (52, 54, "MENTION"),     # SMS
                    (95, 103, "MENTION"),
                ]
            }
        },
    ),
]
# fmt: on


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def examples(nlp):
    examples = []
    for text, annot in TRAIN_DATA:
        # eg = Example.from_dict(nlp.make_doc(text), annot)
        # if PR #7197 is merged, replace below with above line
        ref_doc = nlp.make_doc(text)
        for key, span_list in annot["spans"].items():
            spans = []
            for span_tuple in span_list:
                start_char = span_tuple[0]
                end_char = span_tuple[1]
                label = span_tuple[2]
                span = ref_doc.char_span(start_char, end_char, label=label)
                spans.append(span)
            ref_doc.spans[key] = spans
        eg = Example(nlp.make_doc(text), ref_doc)
        examples.append(eg)
    return examples


def test_coref_er_no_POS(nlp):
    doc = nlp("The police woman talked to him.")
    coref_er = nlp.add_pipe("coref_er", last=True)
    with pytest.raises(ValueError):
        coref_er(doc)


def test_coref_er_with_POS(nlp):
    words = ["The", "police", "woman", "talked", "to", "him", "."]
    pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
    doc = Doc(nlp.vocab, words=words, pos=pos)
    coref_er = nlp.add_pipe("coref_er", last=True)
    coref_er(doc)
    assert len(doc.spans[coref_er.span_mentions]) == 1
    mention = doc.spans[coref_er.span_mentions][0]
    assert (mention.text, mention.start, mention.end) == ("him", 5, 6)


def test_coref_er_custom_POS(nlp):
    words = ["The", "police", "woman", "talked", "to", "him", "."]
    pos = ["DET", "NOUN", "NOUN", "VERB", "ADP", "PRON", "PUNCT"]
    doc = Doc(nlp.vocab, words=words, pos=pos)
    config = {"matcher_key": "POS", "matcher_values": ["NOUN"]}
    coref_er = nlp.add_pipe("coref_er", last=True, config=config)
    coref_er(doc)
    assert len(doc.spans[coref_er.span_mentions]) == 1
    mention = doc.spans[coref_er.span_mentions][0]
    assert (mention.text, mention.start, mention.end) == ("police woman", 1, 3)


def test_coref_clusters(nlp, examples):
    coref_er = nlp.add_pipe("coref_er", last=True)
    coref = nlp.add_pipe("coref", last=True)
    coref.initialize(lambda: examples)
    words = ["Laura", "walked", "her", "dog", "."]
    pos = ["PROPN", "VERB", "PRON", "NOUN", "PUNCT"]
    doc = Doc(nlp.vocab, words=words, pos=pos)
    coref_er(doc)
    coref(doc)
    assert len(doc.spans[coref_er.span_mentions]) > 0
    found_clusters = 0
    for name, spans in doc.spans.items():
        if name.startswith(coref.span_cluster_prefix):
            found_clusters += 1
    assert found_clusters > 0


def test_coref_er_score(nlp, examples):
    config = {"matcher_key": "POS", "matcher_values": []}
    coref_er = nlp.add_pipe("coref_er", last=True, config=config)
    coref = nlp.add_pipe("coref", last=True)
    coref.initialize(lambda: examples)
    mentions_key = coref_er.span_mentions
    cluster_prefix_key = coref.span_cluster_prefix
    matcher = PhraseMatcher(nlp.vocab)
    terms_1 = ["Laura", "her", "she"]
    terms_2 = ["it", "this SMS"]
    matcher.add("A", [nlp.make_doc(text) for text in terms_1])
    matcher.add("B", [nlp.make_doc(text) for text in terms_2])
    for eg in examples:
        pred = eg.predicted
        matches = matcher(pred, as_spans=True)
        pred.set_ents(matches)
        coref_er(pred)
        coref(pred)
        eg.predicted = pred
        # TODO: if #7209 is merged, experiment with 'include_label'
        scores = coref_er.score([eg])
        assert f"{mentions_key}_f" in scores
        scores = coref.score([eg])
        assert f"{cluster_prefix_key}_f" in scores


def test_coref_serialization(nlp):
    # Test that the coref component can be serialized
    config_er = {"matcher_key": "TAG", "matcher_values": ["NN"]}
    nlp.add_pipe("coref_er", last=True, config=config_er)
    nlp.add_pipe("coref", last=True)
    assert "coref_er" in nlp.pipe_names
    assert "coref" in nlp.pipe_names

    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
        assert "coref_er" in nlp2.pipe_names
        assert "coref" in nlp2.pipe_names
        coref_er_2 = nlp2.get_pipe("coref_er")
        assert coref_er_2.matcher_key == "TAG"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Why inherit from UserDict instead of dict here?
 | 
					# Why inherit from UserDict instead of dict here?
 | 
				
			||||||
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
 | 
					# Well, the 'dict' class doesn't necessarily delegate everything nicely,
 | 
				
			||||||
# for performance reasons. The UserDict is slower by better behaved.
 | 
					# for performance reasons. The UserDict is slower but better behaved.
 | 
				
			||||||
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww
 | 
					# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww
 | 
				
			||||||
class SpanGroups(UserDict):
 | 
					class SpanGroups(UserDict):
 | 
				
			||||||
    """A dict-like proxy held by the Doc, to control access to span groups."""
 | 
					    """A dict-like proxy held by the Doc, to control access to span groups."""
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||