	This follows the pattern used in the Biaffine Parser, which uses an init function to get the size only after the tok2vec is available. This works at first, but serialization fails with an error.
from typing import List, Tuple

from thinc.api import Model, chain, get_width
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d
from thinc.util import torch, xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry


EPSILON = 1e-7


@registry.architectures("spacy.Coref.v1")
def build_wl_coref_model(
    tok2vec: Model[List[Doc], List[Floats2d]],
    distance_embedding_size: int = 20,
    hidden_size: int = 1024,
    depth: int = 1,
    dropout: float = 0.3,
    # pairs to keep per mention after rough scoring
    antecedent_limit: int = 50,
    antecedent_batch_size: int = 512,
    tok2vec_size: int = 768,  # tok2vec size
):
    # TODO add model return types

    nI = None

    with Model.define_operators({">>": chain}):
        coref_clusterer = Model(
            "coref_clusterer",
            forward=coref_forward,
            init=coref_init,
            dims={"nI": nI},
            attrs={
                "distance_embedding_size": distance_embedding_size,
                "hidden_size": hidden_size,
                "depth": depth,
                "dropout": dropout,
                "antecedent_limit": antecedent_limit,
                "antecedent_batch_size": antecedent_batch_size,
            },
        )

        coref_model = tok2vec >> coref_clusterer
    return coref_model

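# Illustrative config usage of the registered architecture (a sketch only; the
# component name "coref" and the tok2vec architecture below are assumptions,
# not taken from this file):
#
#     [components.coref.model]
#     @architectures = "spacy.Coref.v1"
#     tok2vec_size = 768
#     distance_embedding_size = 20
#     hidden_size = 1024
#     depth = 1
#     dropout = 0.3
#     antecedent_limit = 50
#     antecedent_batch_size = 512
#
#     [components.coref.model.tok2vec]
#     @architectures = "spacy.Tok2Vec.v2"
#     # ... embed/encode sub-layers ...
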
def coref_init(model: Model, X=None, Y=None):
    if model.layers:
        return

    if X is not None and model.has_dim("nI") is None:
        model.set_dim("nI", get_width(X))

    hidden_size = model.attrs["hidden_size"]
    depth = model.attrs["depth"]
    dropout = model.attrs["dropout"]
    antecedent_limit = model.attrs["antecedent_limit"]
    antecedent_batch_size = model.attrs["antecedent_batch_size"]
    distance_embedding_size = model.attrs["distance_embedding_size"]

    PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
    model._layers = [
        PyTorchWrapper(
            CorefClusterer(
                model.get_dim("nI"),
                distance_embedding_size,
                hidden_size,
                depth,
                dropout,
                antecedent_limit,
                antecedent_batch_size,
            ),
            convert_inputs=convert_coref_clusterer_inputs,
            convert_outputs=convert_coref_clusterer_outputs,
        )
        # TODO maybe we need mixed precision and grad scaling?
    ]

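# Note on the deferred initialization above (a sketch of the usual thinc flow,
# not something this file asserts): the PyTorch sub-network is only built once
# coref_init runs with example data, e.g. via model.initialize(X=...) on the
# chained model; that is the point where "nI" can be read off the width of the
# tok2vec output.
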
def coref_forward(model: Model, X, is_train: bool):
    return model.layers[0](X, is_train)

def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
    # The input here is List[Floats2d], one array per doc; just use the first.
    # TODO real batching
    X = X[0]
    word_features = xp2torch(X, requires_grad=is_train)

    # TODO fix or remove type annotations
    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
        # convert to xp and wrap in list
        gradients = torch2xp(args.args[0])
        return [gradients]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop

def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
    _, outputs = inputs_outputs
    scores, indices = outputs

    def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
        dY_t = xp2torch(dY[0])
        return ArgsKwargs(
            args=([scores],),
            kwargs={"grad_tensors": [dY_t]},
        )

    scores_xp = torch2xp(scores)
    indices_xp = torch2xp(indices)
    return (scores_xp, indices_xp), convert_for_torch_backward

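# How the two converters fit together (a rough sketch of how thinc's
# PyTorchWrapper.v2 generally behaves, stated as an assumption rather than a
# guarantee): on the backward pass the wrapper calls convert_for_torch_backward
# with the gradient of the scores, uses the returned ArgsKwargs to drive
# torch.autograd.backward, and then hands the resulting gradient of
# word_features to the backprop callback defined in
# convert_coref_clusterer_inputs, which converts it back to an xp array.
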
class CorefClusterer(torch.nn.Module):
    """
    Combines all coref modules together to find coreferent token pairs.
    Submodules (in the order of their usage in the pipeline):
        - rough_scorer (RoughScorer) that prunes candidate pairs
        - pw (DistancePairwiseEncoder) that computes pairwise features
        - a_scorer (AnaphoricityScorer) that produces the final scores
    """

    def __init__(
        self,
        dim: int,
        dist_emb_size: int,
        hidden_size: int,
        n_layers: int,
        dropout: float,
        roughk: int,
        batch_size: int,
    ):
        """
        dim: Size of the input features.
        dist_emb_size: Size of the distance embeddings.
        hidden_size: Size of the coreference candidate embeddings.
        n_layers: Number of layers in the AnaphoricityScorer.
        dropout: Dropout probability to apply across all modules.
        roughk: Number of candidates the RoughScorer returns.
        batch_size: Internal batch size for the more expensive scorer.
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.batch_size = batch_size
        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)

        # a pair embedding concatenates mention, candidate antecedent, their
        # elementwise product, and the distance features: dim * 3 + dist_emb_size
        pair_emb = dim * 3 + self.pw.shape
        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
        self.lstm = torch.nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            batch_first=True,
        )

        self.rough_scorer = RoughScorer(dim, dropout, roughk)

    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        1. LSTM encodes the incoming word_features.
        2. The RoughScorer scores and prunes the candidates.
        3. The DistancePairwiseEncoder embeds the distances between pairs.
        4. The AnaphoricityScorer scores all pairs in mini-batches.

        word_features: torch.Tensor containing word encodings

        returns:
            coref_scores: n_words x (roughk + 1) floats, including the leading
                dummy (no antecedent) column added by the AnaphoricityScorer.
            top_indices: n_words x roughk integers.
        """
        self.lstm.flatten_parameters()  # XXX without this there's a warning
        word_features = torch.unsqueeze(word_features, dim=0)
        words, _ = self.lstm(word_features)
        words = words.squeeze()
        # words: n_words x dim
        words = self.dropout(words)
        # Obtain bilinear scores and leave only top-k antecedents for each word
        # top_rough_scores: (n_words x roughk)
        # top_indices: (n_words x roughk)
        top_rough_scores, top_indices = self.rough_scorer(words)
        # Get pairwise features
        # (n_words x roughk x n_pw_features)
        pw = self.pw(top_indices)
        batch_size = self.batch_size
        a_scores_lst: List[torch.Tensor] = []

        for i in range(0, len(words), batch_size):
            pw_batch = pw[i : i + batch_size]
            words_batch = words[i : i + batch_size]
            top_indices_batch = top_indices[i : i + batch_size]
            top_rough_scores_batch = top_rough_scores[i : i + batch_size]

            # a_scores_batch    [batch_size, n_ants + 1]
            a_scores_batch = self.a_scorer(
                all_mentions=words,
                mentions_batch=words_batch,
                pw_batch=pw_batch,
                top_indices_batch=top_indices_batch,
                top_rough_scores_batch=top_rough_scores_batch,
            )
            a_scores_lst.append(a_scores_batch)

        coref_scores = torch.cat(a_scores_lst, dim=0)
        return coref_scores, top_indices

# Note this function is kept here to keep a torch dep out of coref_util.
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
    shape: List[int] = list(tensor.shape)
    shape[1] = 1
    if not eps:
        dummy = torch.zeros(shape, **kwargs)  # type: ignore
    else:
        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
    output = torch.cat((dummy, tensor), dim=1)
    return output

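# Shape example for add_dummy (illustrative): a score matrix of shape
# (n_words, roughk) comes back as (n_words, roughk + 1). The prepended column
# serves as the "no antecedent" dummy score, and eps=True fills it with
# EPSILON instead of exact zeros (as used in AnaphoricityScorer.forward below).
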
class AnaphoricityScorer(torch.nn.Module):
    """Calculates anaphoricity scores by passing the inputs into an FFNN."""

    def __init__(self, in_features: int, hidden_size, depth, dropout):
        super().__init__()
        if not depth:
            hidden_size = in_features
        layers = []
        for i in range(depth):
            layers.extend(
                [
                    torch.nn.Linear(hidden_size if i else in_features, hidden_size),
                    torch.nn.LeakyReLU(),
                    torch.nn.Dropout(dropout),
                ]
            )
        self.hidden = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(hidden_size, out_features=1)

    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        *,
        all_mentions: torch.Tensor,
        mentions_batch: torch.Tensor,
        pw_batch: torch.Tensor,
        top_indices_batch: torch.Tensor,
        top_rough_scores_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Builds a pairwise matrix, scores the pairs and returns the scores.

        Args:
            all_mentions (torch.Tensor): [n_mentions, mention_emb]
            mentions_batch (torch.Tensor): [batch_size, mention_emb]
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
            top_indices_batch (torch.Tensor): [batch_size, n_ants]
            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]

        Returns:
            torch.Tensor [batch_size, n_ants + 1]
                anaphoricity scores for the pairs + a dummy column
        """
        # [batch_size, n_ants, pair_emb]
        pair_matrix = self._get_pair_matrix(
            all_mentions, mentions_batch, pw_batch, top_indices_batch
        )

        # [batch_size, n_ants]
        scores = top_rough_scores_batch + self._ffnn(pair_matrix)
        scores = add_dummy(scores, eps=True)

        return scores

    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: tensor of shape (batch_size x roughk x n_features)
        returns: tensor of shape (batch_size x antecedent_limit)
        """
        x = self.out(self.hidden(x))
        return x.squeeze(2)

    @staticmethod
    def _get_pair_matrix(
        all_mentions: torch.Tensor,
        mentions_batch: torch.Tensor,
        pw_batch: torch.Tensor,
        top_indices_batch: torch.Tensor,
    ) -> torch.Tensor:
        """
        Builds the matrix used as input for AnaphoricityScorer.

        all_mentions: (n_mentions x mention_emb),
            all the valid mentions of the document,
            can be on a different device
        mentions_batch: (batch_size x mention_emb),
            the mentions of the current batch.
        pw_batch: (batch_size x roughk x pw_emb),
            pairwise distance features of the current batch.
        top_indices_batch: (batch_size x n_ants),
            indices of antecedents of each mention

        Returns:
            out: pairwise features (batch_size x n_ants x pair_emb)
        """
        emb_size = mentions_batch.shape[1]
        n_ants = pw_batch.shape[1]

        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
        b_mentions = all_mentions[top_indices_batch]
        similarity = a_mentions * b_mentions

        out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
        return out

class RoughScorer(torch.nn.Module):
    """
    Cheaper module that gives a rough estimate of the anaphoricity of two
    candidates; only the top-scoring candidates are considered in later
    steps to reduce computational cost.
    """

    def __init__(self, features: int, dropout: float, antecedent_limit: int):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.bilinear = torch.nn.Linear(features, features)
        self.k = antecedent_limit

    def forward(
        self,  # type: ignore
        mentions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns rough anaphoricity scores for candidates, which consist of
        the bilinear output of the current model summed with mention scores.
        """
        # [n_mentions, n_mentions]
        pair_mask = torch.arange(mentions.shape[0])
        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
        pair_mask = torch.log((pair_mask > 0).to(torch.float))
        pair_mask = pair_mask.to(mentions.device)
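        # Illustration of the mask (a restatement of the arithmetic above, not
        # extra behaviour): pair_mask[i][j] = i - j, so (pair_mask > 0) keeps
        # only candidates j that precede word i; log() then maps allowed pairs
        # to 0.0 and disallowed ones (self and following words) to -inf, e.g.
        # for 4 words:
        #     [[-inf, -inf, -inf, -inf],
        #      [ 0.0, -inf, -inf, -inf],
        #      [ 0.0,  0.0, -inf, -inf],
        #      [ 0.0,  0.0,  0.0, -inf]]
        # Adding -inf to the scores pushes those pairs out of the top-k below.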
        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
        rough_scores = pair_mask + bilinear_scores
        top_scores, indices = torch.topk(
            rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
        )

        return top_scores, indices

class DistancePairwiseEncoder(torch.nn.Module):
    def __init__(self, distance_embedding_size, dropout):
        """
        Takes top_indices, which contains, for each word, a ranked list of
        indices of its most likely antecedent candidates. For each of these
        pairs it looks up a distance embedding from a table, where the table
        entry corresponds to the (bucketed) log-distance.

        distance_embedding_size: int,
            Dimensionality of the distance-embeddings table.
        dropout: float,
            Dropout probability.
        """
        super().__init__()
        emb_size = distance_embedding_size
        self.distance_emb = torch.nn.Embedding(9, emb_size)
        self.dropout = torch.nn.Dropout(dropout)
        self.shape = emb_size

    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
        word_ids = torch.arange(0, top_indices.size(0))
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
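        # Worked example of the bucketing above (derived from the code; the
        # bucket numbers are stated purely as an illustration): distances 1-4
        # map to buckets 0-3; 5-7 -> bucket 4, 8-15 -> 5, 16-31 -> 6,
        # 32-63 -> 7, and 64+ -> 8 (log2 is capped at 6), which is why the
        # embedding table has exactly 9 rows.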
        distance = distance.to(top_indices.device)
        distance = self.distance_emb(distance)
        return self.dropout(distance)