Mirror of https://github.com/explosion/spaCy.git

small refactor and docs
commit e512874c80 (parent 33f4f90ff0)
@@ -1,14 +1,14 @@
from typing import List, Tuple
import torch

from thinc.api import Model, chain, tuplify
from thinc.api import Model, chain
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.types import Floats2d, Ints2d
from thinc.util import xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry
from .coref_util import add_dummy, get_sentence_ids
from .coref_util import add_dummy


@registry.architectures("spacy.Coref.v1")
@@ -19,7 +19,6 @@ def build_wl_coref_model(
    n_hidden_layers: int = 1,  # TODO rename to "depth"?
    dropout: float = 0.3,
    # pairs to keep per mention after rough scoring
    # TODO change to meaningful name
    rough_k: int = 50,
    # TODO is this not a training loop setting?
    a_scoring_batch_size: int = 512,
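As an aside (not part of the diff): the architecture registered in the first hunk can be resolved by its string name through spaCy's registry. Only the name "spacy.Coref.v1" comes from this file; the lookup below is an illustrative sketch.

from spacy import registry

# Hedged sketch, not part of the commit: look up the registered factory by name.
build_fn = registry.architectures.get("spacy.Coref.v1")  # -> build_wl_coref_model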
@@ -34,7 +33,6 @@ def build_wl_coref_model(
        dim = 768

    with Model.define_operators({">>": chain}):
        # TODO chain tok2vec with these models
        coref_scorer = PyTorchWrapper(
            CorefScorer(
                dim,
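A minimal sketch (not part of the diff) of what the define_operators block above enables: inside the with-block, ">>" composes Thinc models exactly like chain(). The Linear layers here are stand-ins, not anything from the coref model.

from thinc.api import Linear, Model, chain

with Model.define_operators({">>": chain}):
    # ">>" is just syntactic sugar for chain(...)
    composed = Linear(4, 8) >> Linear(2, 4)  # same as chain(Linear(4, 8), Linear(2, 4))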
@@ -49,18 +47,6 @@ def build_wl_coref_model(
            convert_outputs=convert_coref_scorer_outputs,
        )
        coref_model = tok2vec >> coref_scorer
        # XXX just ignore this until the coref scorer is integrated
        # span_predictor = PyTorchWrapper(
        #    SpanPredictor(
        # TODO this was hardcoded to 1024, check
        #        hidden_size,
        #        sp_embedding_size,
        #    ),
        #    convert_inputs=convert_span_predictor_inputs
        # )
    # TODO combine models so output is uniform (just one forward pass)
    # It may be reasonable to have an option to disable span prediction,
    # and just return words as spans.
    return coref_model
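A minimal sketch (not part of the diff) of the wrapping pattern used in this hunk: a PyTorch module is wrapped with Thinc's PyTorchWrapper and chained after an upstream encoder. The Linear module below is a placeholder for CorefScorer, and the commented tok2vec line assumes an encoder model is in scope.

import torch
from thinc.api import PyTorchWrapper

torch_module = torch.nn.Linear(768, 768)  # placeholder standing in for CorefScorer(dim, ...)
wrapped = PyTorchWrapper(torch_module)    # convert_inputs/convert_outputs left at defaults here
# with Model.define_operators({">>": chain}):
#     coref_model = tok2vec >> wrapped    # i.e. chain(tok2vec, wrapped)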
@@ -95,46 +81,13 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
    return (scores_xp, indices_xp), convert_for_torch_backward


# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
    xp = model.ops.xp
    antecedents = scores.argmax(axis=1) - 1
    not_dummy = antecedents >= 0
    coref_span_heads = xp.arange(0, len(scores))[not_dummy]
    antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
    n_words = scores.shape[0]
    nodes = [GraphNode(i) for i in range(n_words)]
    for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
        nodes[i].link(nodes[j])
        assert nodes[i] is not nodes[j]

    clusters = []
    for node in nodes:
        if len(node.links) > 0 and not node.visited:
            cluster = []
            stack = [node]
            while stack:
                current_node = stack.pop()
                current_node.visited = True
                cluster.append(current_node.id)
                stack.extend(link for link in current_node.links if not link.visited)
            assert len(cluster) > 1
            clusters.append(sorted(cluster))
    return sorted(clusters)


class CorefScorer(torch.nn.Module):
    """Combines all coref modules together to find coreferent spans.

    Attributes:
        epochs_trained (int): number of epochs the model has been trained for

    """
    Combines all coref modules together to find coreferent token pairs.
    Submodules (in the order of their usage in the pipeline):
        rough_scorer (RoughScorer)
        pw (PairwiseEncoder)
        a_scorer (AnaphoricityScorer)
        sp (SpanPredictor)
        - rough_scorer (RoughScorer) that prunes candidate pairs
        - pw (DistancePairwiseEncoder) that computes pairwise features
        - a_scorer (AnaphoricityScorer) produces the final scores
    """

    def __init__(
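Before the next hunk, a toy illustration (not part of the diff, with made-up numbers) of what _clusterize above computes: column 0 of scores is the dummy "no antecedent" option, and the remaining columns index into top_indices.

import numpy as np

scores = np.array([
    [0.9, 0.0, 0.0],  # word 0: dummy wins, so no antecedent
    [0.1, 0.8, 0.0],  # word 1: best non-dummy column is 0 -> antecedent top_indices[1, 0] == 0
    [0.2, 0.1, 0.7],  # word 2: best non-dummy column is 1 -> antecedent top_indices[2, 1] == 1
])
top_indices = np.array([[0, 0], [0, 0], [0, 1]])
antecedents = scores.argmax(axis=1) - 1  # [-1, 0, 1]; -1 means "dummy", i.e. no link
# word 1 links to word 0 and word 2 links to word 1, so one cluster comes out: [0, 1, 2]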
@@ -149,50 +102,54 @@ class CorefScorer(torch.nn.Module):
    ):
        super().__init__()
        """
        A newly created model is set to evaluation mode.

        Args:
            epochs_trained (int): the number of epochs finished
                (useful for warm start)
        dim: Size of the input features.
        dist_emb_size: Size of the distance embeddings.
        hidden_size: Size of the coreference candidate embeddings.
        n_layers: Number of layers in the AnaphoricityScorer.
        dropout_rate: Dropout probability to apply across all modules.
        roughk: Number of candidates the RoughScorer returns.
        batch_size: Internal batch-size for the more expensive AnaphoricityScorer.
        """
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.batch_size = batch_size
        # Modules
        self.lstm = torch.nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            batch_first=True,
        )
        self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
        # TODO clean this up
        bert_emb = dim
        pair_emb = bert_emb * 3 + self.pw.shape
        pair_emb = dim * 3 + self.pw.shape
        self.a_scorer = AnaphoricityScorer(
            pair_emb, hidden_size, n_layers, dropout_rate
        )
        self.lstm = torch.nn.LSTM(
            input_size=bert_emb,
            hidden_size=bert_emb,
            batch_first=True,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
        self.batch_size = batch_size

    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This is a massive method, but it made sense to me to not split it into
        several ones to let one see the data flow.
        1. LSTM encodes the incoming word_features.
        2. The RoughScorer scores and prunes the candidates.
        3. The DistancePairwiseEncoder embeds the distance between remaining pairs.
        4. The AnaphoricityScorer scores all pairs in mini-batches.

        Args:
            word_features: torch.Tensor containing word encodings
        Returns:
            coreference scores and top indices
        word_features: torch.Tensor containing word encodings

        returns:
            coref_scores: n_words x roughk floats.
            top_indices: n_words x roughk integers.
        """
        # words           [n_words, span_emb]
        # cluster_ids     [n_words]
        self.lstm.flatten_parameters()  # XXX without this there's a warning
        word_features = torch.unsqueeze(word_features, dim=0)
        words, _ = self.lstm(word_features)
        words = words.squeeze()
        # words: n_words x dim
        words = self.dropout(words)
        # Obtain bilinear scores and leave only top-k antecedents for each word
        # top_rough_scores  [n_words, n_ants]
        # top_indices       [n_words, n_ants]
        # top_rough_scores: (n_words x roughk)
        # top_indices: (n_words x roughk)
        top_rough_scores, top_indices = self.rough_scorer(words)
        # Get pairwise features [n_words, n_ants, n_pw_features]
        # Get pairwise features
        # (n_words x roughk x n_pw_features)
        pw = self.pw(top_indices)
        batch_size = self.batch_size
        a_scores_lst: List[torch.Tensor] = []
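A toy sketch (not part of the diff) of step 1 of the forward pass above: torch's LSTM expects a batch dimension, hence the unsqueeze before the call and the squeeze afterwards.

import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
word_features = torch.randn(5, 8)          # (n_words, dim)
out, _ = lstm(word_features.unsqueeze(0))  # add a batch dim -> (1, n_words, dim)
words = out.squeeze(0)                     # back to (n_words, dim)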
@@ -272,13 +229,8 @@ class AnaphoricityScorer(torch.nn.Module):

    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates anaphoricity scores.

        Args:
            x: tensor of shape [batch_size, n_ants, n_features]

        Returns:
            tensor of shape [batch_size, n_ants]
        x: tensor of shape (batch_size x roughk x n_features)
        returns: tensor of shape (batch_size x roughk)
        """
        x = self.out(self.hidden(x))
        return x.squeeze(2)
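A toy sketch (not part of the diff) of the shape flow in _ffnn; the hidden and out layers below are placeholders, since their definitions are outside this hunk.

import torch

hidden = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())  # placeholder for self.hidden
out = torch.nn.Linear(8, 1)                                            # placeholder for self.out
x = torch.randn(4, 3, 16)           # (batch_size, roughk, n_features)
scores = out(hidden(x)).squeeze(2)  # (4, 3, 1) -> (4, 3)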
@@ -293,21 +245,18 @@ class AnaphoricityScorer(torch.nn.Module):
        """
        Builds the matrix used as input for AnaphoricityScorer.

        Args:
            all_mentions (torch.Tensor): [n_mentions, mention_emb],
                all the valid mentions of the document,
                can be on a different device
            mentions_batch (torch.Tensor): [batch_size, mention_emb],
                the mentions of the current batch,
                is expected to be on the current device
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
                pairwise features of the current batch,
                is expected to be on the current device
            top_indices_batch (torch.Tensor): [batch_size, n_ants],
                indices of antecedents of each mention
        all_mentions: (n_mentions x mention_emb),
            all the valid mentions of the document,
            can be on a different device
        mentions_batch: (batch_size x mention_emb),
            the mentions of the current batch.
        pw_batch: (batch_size x roughk x pw_emb),
            pairwise distance features of the current batch.
        top_indices_batch: (batch_size x n_ants),
            indices of antecedents of each mention

        Returns:
            torch.Tensor: [batch_size, n_ants, pair_emb]
            out: pairwise features (batch_size x n_ants x pair_emb)
        """
        emb_size = mentions_batch.shape[1]
        n_ants = pw_batch.shape[1]
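A hedged sketch (not part of the diff) of one plausible way the pair matrix is assembled, chosen so the sizes match pair_emb = dim * 3 + self.pw.shape from __init__: mention embedding, antecedent embedding, their elementwise product, and the pairwise features, concatenated along the last axis. The actual implementation is outside this hunk.

import torch

dim, n_ants, pw_emb = 8, 3, 4
mention = torch.randn(n_ants, dim)      # the current mention, repeated for each candidate
antecedents = torch.randn(n_ants, dim)  # its candidate antecedents
pw = torch.randn(n_ants, pw_emb)        # pairwise (distance) features
pair = torch.cat((mention, antecedents, mention * antecedents, pw), dim=1)
assert pair.shape[1] == dim * 3 + pw_emb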
@@ -322,16 +271,15 @@ class AnaphoricityScorer(torch.nn.Module):

class RoughScorer(torch.nn.Module):
    """
    Is needed to give a roughly estimate of the anaphoricity of two candidates,
    only top scoring candidates are considered on later steps to reduce
    computational complexity.
    Cheaper module that gives a rough estimate of the anaphoricity of two
    candidates; only the top scoring candidates are considered in later
    steps to reduce computational cost.
    """

    def __init__(self, features: int, dropout_rate: float, rough_k: float):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.bilinear = torch.nn.Linear(features, features)

        self.k = rough_k

    def forward(
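A toy sketch (not part of the diff) of the rough bilinear score computed in the next hunk: project the word vectors once with a linear layer, then score every pair with a dot product, giving an n_words x n_words matrix from which only the top k columns per row are kept.

import torch

words = torch.randn(5, 8)                   # (n_words, dim)
bilinear = torch.nn.Linear(8, 8)
rough_scores = bilinear(words).mm(words.T)  # (n_words, n_words)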
@@ -348,21 +296,6 @@ class RoughScorer(torch.nn.Module):
        pair_mask = torch.log((pair_mask > 0).to(torch.float))
        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
        rough_scores = pair_mask + bilinear_scores

        return self._prune(rough_scores)

    def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Selects top-k rough antecedent scores for each mention.

        Args:
            rough_scores: tensor of shape [n_mentions, n_mentions], containing
                rough antecedent scores of each mention-antecedent pair.

        Returns:
            FloatTensor of shape [n_mentions, k], top rough scores
            LongTensor of shape [n_mentions, k], top indices
        """
        top_scores, indices = torch.topk(
            rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
        )
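A toy sketch (not part of the diff) of the masking plus pruning above: a word may only take an earlier word as antecedent, so disallowed pairs get a log(0) = -inf mask and can never survive the top-k selection. The mask construction below is illustrative; only the log and topk lines appear in the hunk.

import torch

n_words, k = 4, 2
idx = torch.arange(n_words)
pair_mask = idx.unsqueeze(1) - idx.unsqueeze(0)         # i - j for every word pair
pair_mask = torch.log((pair_mask > 0).to(torch.float))  # 0 where j < i, -inf elsewhere
rough_scores = pair_mask + torch.randn(n_words, n_words)
top_scores, indices = torch.topk(rough_scores, k=min(k, n_words), dim=1, sorted=False)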
@@ -371,6 +304,18 @@ class RoughScorer(torch.nn.Module):

class DistancePairwiseEncoder(torch.nn.Module):
    def __init__(self, embedding_size, dropout_rate):
        """
        Takes the top_indices, a ranked list of the most likely
        corresponding anaphora candidates for each word. For each
        of these pairs it looks up a distance embedding from a
        table, where the distance corresponds to the log-distance.

        embedding_size: int,
            Dimensionality of the distance-embeddings table.
        dropout_rate: float,
            Dropout probability.
        """
        super().__init__()
        emb_size = embedding_size
        self.distance_emb = torch.nn.Embedding(9, emb_size)
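A hedged toy sketch (not part of the diff) of the idea in the docstring above: the distance between a word and a candidate is bucketed on a roughly logarithmic scale into one of the 9 rows of the embedding table. The exact bucketing used by the model is outside this hunk.

import torch

distance = torch.tensor([1, 2, 3, 7, 40, 300])                # token distances
bucket = distance.float().log2().floor().clamp(max=8).long()  # illustrative buckets in 0..8
emb = torch.nn.Embedding(9, 20)
features = emb(bucket)                                        # (6, 20)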
@@ -378,7 +323,7 @@ class DistancePairwiseEncoder(torch.nn.Module):
        self.shape = emb_size

    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        self,
        top_indices: torch.Tensor,
    ) -> torch.Tensor:
        word_ids = torch.arange(0, top_indices.size(0))