mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-20 18:54:21 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			391 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			391 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import List, Tuple
 | |
| import torch
 | |
| 
 | |
| from thinc.api import Model, chain, tuplify
 | |
| from thinc.api import PyTorchWrapper, ArgsKwargs
 | |
| from thinc.types import Floats2d, Ints1d, Ints2d
 | |
| from thinc.util import xp2torch, torch2xp
 | |
| 
 | |
| from ...tokens import Doc
 | |
| from ...util import registry
 | |
| from .coref_util import add_dummy, get_sentence_ids
 | |
| 
 | |
| 
 | |
| @registry.architectures("spacy.Coref.v1")
 | |
| def build_wl_coref_model(
 | |
|     tok2vec: Model[List[Doc], List[Floats2d]],
 | |
|     embedding_size: int = 20,
 | |
|     hidden_size: int = 1024,
 | |
|     n_hidden_layers: int = 1,  # TODO rename to "depth"?
 | |
|     dropout: float = 0.3,
 | |
|     # pairs to keep per mention after rough scoring
 | |
|     # TODO change to meaningful name
 | |
|     rough_k: int = 50,
 | |
|     # TODO is this not a training loop setting?
 | |
|     a_scoring_batch_size: int = 512,
 | |
|     # span predictor embeddings
 | |
|     sp_embedding_size: int = 64,
 | |
| ):
 | |
|     # TODO fix this
 | |
|     try:
 | |
|         dim = tok2vec.get_dim("nO")
 | |
|     except ValueError:
 | |
|         # happens with transformer listener
 | |
|         dim = 768
 | |
| 
 | |
|     with Model.define_operators({">>": chain}):
 | |
|         # TODO chain tok2vec with these models
 | |
|         coref_scorer = PyTorchWrapper(
 | |
|             CorefScorer(
 | |
|                 dim,
 | |
|                 embedding_size,
 | |
|                 hidden_size,
 | |
|                 n_hidden_layers,
 | |
|                 dropout,
 | |
|                 rough_k,
 | |
|                 a_scoring_batch_size,
 | |
|             ),
 | |
|             convert_inputs=convert_coref_scorer_inputs,
 | |
|             convert_outputs=convert_coref_scorer_outputs,
 | |
|         )
 | |
|         coref_model = tok2vec >> coref_scorer
 | |
|         # XXX just ignore this until the coref scorer is integrated
 | |
|         # span_predictor = PyTorchWrapper(
 | |
|         #    SpanPredictor(
 | |
|         # TODO this was hardcoded to 1024, check
 | |
|         #        hidden_size,
 | |
|         #        sp_embedding_size,
 | |
|         #    ),
 | |
|         #    convert_inputs=convert_span_predictor_inputs
 | |
|         # )
 | |
|     # TODO combine models so output is uniform (just one forward pass)
 | |
|     # It may be reasonable to have an option to disable span prediction,
 | |
|     # and just return words as spans.
 | |
|     return coref_model
 | |
| 
 | |
| 
 | |
| def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
 | |
|     # The input here is List[Floats2d], one for each doc
 | |
|     # just use the first
 | |
|     # TODO real batching
 | |
|     X = X[0]
 | |
|     word_features = xp2torch(X, requires_grad=is_train)
 | |
| 
 | |
|     def backprop(args: ArgsKwargs) -> List[Floats2d]:
 | |
|         # convert to xp and wrap in list
 | |
|         gradients = torch2xp(args.args[0])
 | |
|         return [gradients]
 | |
| 
 | |
|     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 | |
| 
 | |
| 
 | |
| def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
 | |
|     _, outputs = inputs_outputs
 | |
|     scores, indices = outputs
 | |
| 
 | |
|     def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
 | |
|         dY_t = xp2torch(dY[0])
 | |
|         return ArgsKwargs(
 | |
|             args=([scores],),
 | |
|             kwargs={"grad_tensors": [dY_t]},
 | |
|         )
 | |
| 
 | |
|     scores_xp = torch2xp(scores)
 | |
|     indices_xp = torch2xp(indices)
 | |
|     return (scores_xp, indices_xp), convert_for_torch_backward
 | |
| 
 | |
| 
 | |
| # TODO add docstring for this, maybe move to utils.
 | |
| # This might belong in the component.
 | |
| def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
 | |
|     xp = model.ops.xp
 | |
|     antecedents = scores.argmax(axis=1) - 1
 | |
|     not_dummy = antecedents >= 0
 | |
|     coref_span_heads = xp.arange(0, len(scores))[not_dummy]
 | |
|     antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
 | |
|     n_words = scores.shape[0]
 | |
|     nodes = [GraphNode(i) for i in range(n_words)]
 | |
|     for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
 | |
|         nodes[i].link(nodes[j])
 | |
|         assert nodes[i] is not nodes[j]
 | |
| 
 | |
|     clusters = []
 | |
|     for node in nodes:
 | |
|         if len(node.links) > 0 and not node.visited:
 | |
|             cluster = []
 | |
|             stack = [node]
 | |
|             while stack:
 | |
|                 current_node = stack.pop()
 | |
|                 current_node.visited = True
 | |
|                 cluster.append(current_node.id)
 | |
|                 stack.extend(link for link in current_node.links if not link.visited)
 | |
|             assert len(cluster) > 1
 | |
|             clusters.append(sorted(cluster))
 | |
|     return sorted(clusters)
 | |
| 
 | |
| 
 | |
| class CorefScorer(torch.nn.Module):
 | |
|     """Combines all coref modules together to find coreferent spans.
 | |
| 
 | |
|     Attributes:
 | |
|         epochs_trained (int): number of epochs the model has been trained for
 | |
| 
 | |
|     Submodules (in the order of their usage in the pipeline):
 | |
|         rough_scorer (RoughScorer)
 | |
|         pw (PairwiseEncoder)
 | |
|         a_scorer (AnaphoricityScorer)
 | |
|         sp (SpanPredictor)
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         dim: int,  # tok2vec size
 | |
|         dist_emb_size: int,
 | |
|         hidden_size: int,
 | |
|         n_layers: int,
 | |
|         dropout_rate: float,
 | |
|         roughk: int,
 | |
|         batch_size: int,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         """
 | |
|         A newly created model is set to evaluation mode.
 | |
| 
 | |
|         Args:
 | |
|             epochs_trained (int): the number of epochs finished
 | |
|                 (useful for warm start)
 | |
|         """
 | |
|         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
 | |
|         # TODO clean this up
 | |
|         bert_emb = dim
 | |
|         pair_emb = bert_emb * 3 + self.pw.shape
 | |
|         self.a_scorer = AnaphoricityScorer(
 | |
|             pair_emb, hidden_size, n_layers, dropout_rate
 | |
|         )
 | |
|         self.lstm = torch.nn.LSTM(
 | |
|             input_size=bert_emb,
 | |
|             hidden_size=bert_emb,
 | |
|             batch_first=True,
 | |
|         )
 | |
|         self.dropout = torch.nn.Dropout(dropout_rate)
 | |
|         self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
 | |
|         self.batch_size = batch_size
 | |
| 
 | |
|     def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 | |
|         """
 | |
|         This is a massive method, but it made sense to me to not split it into
 | |
|         several ones to let one see the data flow.
 | |
| 
 | |
|         Args:
 | |
|             word_features: torch.Tensor containing word encodings
 | |
|         Returns:
 | |
|             coreference scores and top indices
 | |
|         """
 | |
|         # words           [n_words, span_emb]
 | |
|         # cluster_ids     [n_words]
 | |
|         self.lstm.flatten_parameters()  # XXX without this there's a warning
 | |
|         word_features = torch.unsqueeze(word_features, dim=0)
 | |
|         words, _ = self.lstm(word_features)
 | |
|         words = words.squeeze()
 | |
|         words = self.dropout(words)
 | |
|         # Obtain bilinear scores and leave only top-k antecedents for each word
 | |
|         # top_rough_scores  [n_words, n_ants]
 | |
|         # top_indices       [n_words, n_ants]
 | |
|         top_rough_scores, top_indices = self.rough_scorer(words)
 | |
|         # Get pairwise features [n_words, n_ants, n_pw_features]
 | |
|         pw = self.pw(top_indices)
 | |
|         batch_size = self.batch_size
 | |
|         a_scores_lst: List[torch.Tensor] = []
 | |
| 
 | |
|         for i in range(0, len(words), batch_size):
 | |
|             pw_batch = pw[i : i + batch_size]
 | |
|             words_batch = words[i : i + batch_size]
 | |
|             top_indices_batch = top_indices[i : i + batch_size]
 | |
|             top_rough_scores_batch = top_rough_scores[i : i + batch_size]
 | |
| 
 | |
|             # a_scores_batch    [batch_size, n_ants]
 | |
|             a_scores_batch = self.a_scorer(
 | |
|                 all_mentions=words,
 | |
|                 mentions_batch=words_batch,
 | |
|                 pw_batch=pw_batch,
 | |
|                 top_indices_batch=top_indices_batch,
 | |
|                 top_rough_scores_batch=top_rough_scores_batch,
 | |
|             )
 | |
|             a_scores_lst.append(a_scores_batch)
 | |
| 
 | |
|         coref_scores = torch.cat(a_scores_lst, dim=0)
 | |
|         return coref_scores, top_indices
 | |
| 
 | |
| 
 | |
| class AnaphoricityScorer(torch.nn.Module):
 | |
|     """Calculates anaphoricity scores by passing the inputs into a FFNN"""
 | |
| 
 | |
|     def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
 | |
|         super().__init__()
 | |
|         hidden_size = hidden_size
 | |
|         if not n_hidden_layers:
 | |
|             hidden_size = in_features
 | |
|         layers = []
 | |
|         for i in range(n_hidden_layers):
 | |
|             layers.extend(
 | |
|                 [
 | |
|                     torch.nn.Linear(hidden_size if i else in_features, hidden_size),
 | |
|                     torch.nn.LeakyReLU(),
 | |
|                     torch.nn.Dropout(dropout_rate),
 | |
|                 ]
 | |
|             )
 | |
|         self.hidden = torch.nn.Sequential(*layers)
 | |
|         self.out = torch.nn.Linear(hidden_size, out_features=1)
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
 | |
|         all_mentions: torch.Tensor,
 | |
|         mentions_batch: torch.Tensor,
 | |
|         pw_batch: torch.Tensor,
 | |
|         top_indices_batch: torch.Tensor,
 | |
|         top_rough_scores_batch: torch.Tensor,
 | |
|     ) -> torch.Tensor:
 | |
|         """Builds a pairwise matrix, scores the pairs and returns the scores.
 | |
| 
 | |
|         Args:
 | |
|             all_mentions (torch.Tensor): [n_mentions, mention_emb]
 | |
|             mentions_batch (torch.Tensor): [batch_size, mention_emb]
 | |
|             pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
 | |
|             top_indices_batch (torch.Tensor): [batch_size, n_ants]
 | |
|             top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]
 | |
| 
 | |
|         Returns:
 | |
|             torch.Tensor [batch_size, n_ants + 1]
 | |
|                 anaphoricity scores for the pairs + a dummy column
 | |
|         """
 | |
|         # [batch_size, n_ants, pair_emb]
 | |
|         pair_matrix = self._get_pair_matrix(
 | |
|             all_mentions, mentions_batch, pw_batch, top_indices_batch
 | |
|         )
 | |
| 
 | |
|         # [batch_size, n_ants]
 | |
|         scores = top_rough_scores_batch + self._ffnn(pair_matrix)
 | |
|         scores = add_dummy(scores, eps=True)
 | |
| 
 | |
|         return scores
 | |
| 
 | |
|     def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
 | |
|         """
 | |
|         Calculates anaphoricity scores.
 | |
| 
 | |
|         Args:
 | |
|             x: tensor of shape [batch_size, n_ants, n_features]
 | |
| 
 | |
|         Returns:
 | |
|             tensor of shape [batch_size, n_ants]
 | |
|         """
 | |
|         x = self.out(self.hidden(x))
 | |
|         return x.squeeze(2)
 | |
| 
 | |
|     @staticmethod
 | |
|     def _get_pair_matrix(
 | |
|         all_mentions: torch.Tensor,
 | |
|         mentions_batch: torch.Tensor,
 | |
|         pw_batch: torch.Tensor,
 | |
|         top_indices_batch: torch.Tensor,
 | |
|     ) -> torch.Tensor:
 | |
|         """
 | |
|         Builds the matrix used as input for AnaphoricityScorer.
 | |
| 
 | |
|         Args:
 | |
|             all_mentions (torch.Tensor): [n_mentions, mention_emb],
 | |
|                 all the valid mentions of the document,
 | |
|                 can be on a different device
 | |
|             mentions_batch (torch.Tensor): [batch_size, mention_emb],
 | |
|                 the mentions of the current batch,
 | |
|                 is expected to be on the current device
 | |
|             pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
 | |
|                 pairwise features of the current batch,
 | |
|                 is expected to be on the current device
 | |
|             top_indices_batch (torch.Tensor): [batch_size, n_ants],
 | |
|                 indices of antecedents of each mention
 | |
| 
 | |
|         Returns:
 | |
|             torch.Tensor: [batch_size, n_ants, pair_emb]
 | |
|         """
 | |
|         emb_size = mentions_batch.shape[1]
 | |
|         n_ants = pw_batch.shape[1]
 | |
| 
 | |
|         a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
 | |
|         b_mentions = all_mentions[top_indices_batch]
 | |
|         similarity = a_mentions * b_mentions
 | |
| 
 | |
|         out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
 | |
|         return out
 | |
| 
 | |
| 
 | |
| class RoughScorer(torch.nn.Module):
 | |
|     """
 | |
|     Is needed to give a roughly estimate of the anaphoricity of two candidates,
 | |
|     only top scoring candidates are considered on later steps to reduce
 | |
|     computational complexity.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, features: int, dropout_rate: float, rough_k: float):
 | |
|         super().__init__()
 | |
|         self.dropout = torch.nn.Dropout(dropout_rate)
 | |
|         self.bilinear = torch.nn.Linear(features, features)
 | |
| 
 | |
|         self.k = rough_k
 | |
| 
 | |
|     def forward(
 | |
|         self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
 | |
|         mentions: torch.Tensor,
 | |
|     ) -> Tuple[torch.Tensor, torch.Tensor]:
 | |
|         """
 | |
|         Returns rough anaphoricity scores for candidates, which consist of
 | |
|         the bilinear output of the current model summed with mention scores.
 | |
|         """
 | |
|         # [n_mentions, n_mentions]
 | |
|         pair_mask = torch.arange(mentions.shape[0])
 | |
|         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
 | |
|         pair_mask = torch.log((pair_mask > 0).to(torch.float))
 | |
|         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
 | |
|         rough_scores = pair_mask + bilinear_scores
 | |
| 
 | |
|         return self._prune(rough_scores)
 | |
| 
 | |
|     def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 | |
|         """
 | |
|         Selects top-k rough antecedent scores for each mention.
 | |
| 
 | |
|         Args:
 | |
|             rough_scores: tensor of shape [n_mentions, n_mentions], containing
 | |
|                 rough antecedent scores of each mention-antecedent pair.
 | |
| 
 | |
|         Returns:
 | |
|             FloatTensor of shape [n_mentions, k], top rough scores
 | |
|             LongTensor of shape [n_mentions, k], top indices
 | |
|         """
 | |
|         top_scores, indices = torch.topk(
 | |
|             rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
 | |
|         )
 | |
|         return top_scores, indices
 | |
| 
 | |
| 
 | |
| class DistancePairwiseEncoder(torch.nn.Module):
 | |
|     def __init__(self, embedding_size, dropout_rate):
 | |
|         super().__init__()
 | |
|         emb_size = embedding_size
 | |
|         self.distance_emb = torch.nn.Embedding(9, emb_size)
 | |
|         self.dropout = torch.nn.Dropout(dropout_rate)
 | |
|         self.shape = emb_size
 | |
| 
 | |
|     def forward(
 | |
|         self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
 | |
|         top_indices: torch.Tensor,
 | |
|     ) -> torch.Tensor:
 | |
|         word_ids = torch.arange(0, top_indices.size(0))
 | |
|         distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
 | |
|         log_distance = distance.to(torch.float).log2().floor_()
 | |
|         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
 | |
|         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
 | |
|         distance = self.distance_emb(distance)
 | |
|         return self.dropout(distance)
 |