diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 4e8e604d8..24b5500a2 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -1,14 +1,14 @@ from typing import List, Tuple import torch -from thinc.api import Model, chain, tuplify +from thinc.api import Model, chain from thinc.api import PyTorchWrapper, ArgsKwargs -from thinc.types import Floats2d, Ints1d, Ints2d +from thinc.types import Floats2d, Ints2d, Ints1d from thinc.util import xp2torch, torch2xp from ...tokens import Doc from ...util import registry -from .coref_util import add_dummy, get_sentence_ids +from .coref_util import add_dummy @registry.architectures("spacy.Coref.v1") @@ -19,7 +19,6 @@ def build_wl_coref_model( n_hidden_layers: int = 1, # TODO rename to "depth"? dropout: float = 0.3, # pairs to keep per mention after rough scoring - # TODO change to meaningful name rough_k: int = 50, # TODO is this not a training loop setting? a_scoring_batch_size: int = 512, @@ -34,7 +33,6 @@ def build_wl_coref_model( dim = 768 with Model.define_operators({">>": chain}): - # TODO chain tok2vec with these models coref_scorer = PyTorchWrapper( CorefScorer( dim, @@ -49,22 +47,14 @@ def build_wl_coref_model( convert_outputs=convert_coref_scorer_outputs, ) coref_model = tok2vec >> coref_scorer - # XXX just ignore this until the coref scorer is integrated - # span_predictor = PyTorchWrapper( - # SpanPredictor( - # TODO this was hardcoded to 1024, check - # hidden_size, - # sp_embedding_size, - # ), - # convert_inputs=convert_span_predictor_inputs - # ) - # TODO combine models so output is uniform (just one forward pass) - # It may be reasonable to have an option to disable span prediction, - # and just return words as spans. return coref_model -def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool): +def convert_coref_scorer_inputs( + model: Model, + X: List[Floats2d], + is_train: bool +): # The input here is List[Floats2d], one for each doc # just use the first # TODO real batching @@ -76,7 +66,7 @@ def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool) gradients = torch2xp(args.args[0]) return [gradients] - return ArgsKwargs(args=(word_features,), kwargs={}), backprop + return ArgsKwargs(args=(word_features, ), kwargs={}), backprop def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): @@ -95,46 +85,13 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): return (scores_xp, indices_xp), convert_for_torch_backward -# TODO add docstring for this, maybe move to utils. -# This might belong in the component. -def _clusterize(model, scores: Floats2d, top_indices: Ints2d): - xp = model.ops.xp - antecedents = scores.argmax(axis=1) - 1 - not_dummy = antecedents >= 0 - coref_span_heads = xp.arange(0, len(scores))[not_dummy] - antecedents = top_indices[coref_span_heads, antecedents[not_dummy]] - n_words = scores.shape[0] - nodes = [GraphNode(i) for i in range(n_words)] - for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()): - nodes[i].link(nodes[j]) - assert nodes[i] is not nodes[j] - - clusters = [] - for node in nodes: - if len(node.links) > 0 and not node.visited: - cluster = [] - stack = [node] - while stack: - current_node = stack.pop() - current_node.visited = True - cluster.append(current_node.id) - stack.extend(link for link in current_node.links if not link.visited) - assert len(cluster) > 1 - clusters.append(sorted(cluster)) - return sorted(clusters) - - class CorefScorer(torch.nn.Module): - """Combines all coref modules together to find coreferent spans. - - Attributes: - epochs_trained (int): number of epochs the model has been trained for - + """ + Combines all coref modules together to find coreferent token pairs. Submodules (in the order of their usage in the pipeline): - rough_scorer (RoughScorer) - pw (PairwiseEncoder) - a_scorer (AnaphoricityScorer) - sp (SpanPredictor) + - rough_scorer (RoughScorer) that prunes candidate pairs + - pw (DistancePairwiseEncoder) that computes pairwise features + - a_scorer (AnaphoricityScorer) produces the final scores """ def __init__( @@ -149,50 +106,64 @@ class CorefScorer(torch.nn.Module): ): super().__init__() """ - A newly created model is set to evaluation mode. - - Args: - epochs_trained (int): the number of epochs finished - (useful for warm start) + dim: Size of the input features. + dist_emb_size: Size of the distance embeddings. + hidden_size: Size of the coreference candidate embeddings. + n_layers: Numbers of layers in the AnaphoricityScorer. + dropout_rate: Dropout probability to apply across all modules. + roughk: Number of candidates the RoughScorer returns. + batch_size: Internal batch-size for the more expensive scorer. """ + self.dropout = torch.nn.Dropout(dropout_rate) + self.batch_size = batch_size + # Modules self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate) - # TODO clean this up - bert_emb = dim - pair_emb = bert_emb * 3 + self.pw.shape + pair_emb = dim * 3 + self.pw.shape + self.a_scorer = AnaphoricityScorer( + pair_emb, + hidden_size, + n_layers, + dropout_rate + ) + self.lstm = torch.nn.LSTM( + input_size=dim, + hidden_size=dim, + batch_first=True, + ) + self.rough_scorer = RoughScorer(dim, dropout_rate, roughk) + self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate) + pair_emb = dim * 3 + self.pw.shape self.a_scorer = AnaphoricityScorer( pair_emb, hidden_size, n_layers, dropout_rate ) - self.lstm = torch.nn.LSTM( - input_size=bert_emb, - hidden_size=bert_emb, - batch_first=True, - ) - self.dropout = torch.nn.Dropout(dropout_rate) - self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk) - self.batch_size = batch_size - def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, word_features: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - This is a massive method, but it made sense to me to not split it into - several ones to let one see the data flow. + 1. LSTM encodes the incoming word_features. + 2. The RoughScorer scores and prunes the candidates. + 3. The DistancePairwiseEncoder embeds the distances between pairs. + 4. The AnaphoricityScorer scores all pairs in mini-batches. - Args: - word_features: torch.Tensor containing word encodings - Returns: - coreference scores and top indices + word_features: torch.Tensor containing word encodings + + returns: + coref_scores: n_words x roughk floats. + top_indices: n_words x roughk integers. """ - # words [n_words, span_emb] - # cluster_ids [n_words] self.lstm.flatten_parameters() # XXX without this there's a warning word_features = torch.unsqueeze(word_features, dim=0) words, _ = self.lstm(word_features) words = words.squeeze() + # words: n_words x dim words = self.dropout(words) # Obtain bilinear scores and leave only top-k antecedents for each word - # top_rough_scores [n_words, n_ants] - # top_indices [n_words, n_ants] + # top_rough_scores: (n_words x roughk) + # top_indices: (n_words x roughk) top_rough_scores, top_indices = self.rough_scorer(words) - # Get pairwise features [n_words, n_ants, n_pw_features] + # Get pairwise features + # (n_words x roughk x n_pw_features) pw = self.pw(top_indices) batch_size = self.batch_size a_scores_lst: List[torch.Tensor] = [] @@ -272,13 +243,8 @@ class AnaphoricityScorer(torch.nn.Module): def _ffnn(self, x: torch.Tensor) -> torch.Tensor: """ - Calculates anaphoricity scores. - - Args: - x: tensor of shape [batch_size, n_ants, n_features] - - Returns: - tensor of shape [batch_size, n_ants] + x: tensor of shape (batch_size x roughk x n_features + returns: tensor of shape (batch_size x rough_k) """ x = self.out(self.hidden(x)) return x.squeeze(2) @@ -293,21 +259,18 @@ class AnaphoricityScorer(torch.nn.Module): """ Builds the matrix used as input for AnaphoricityScorer. - Args: - all_mentions (torch.Tensor): [n_mentions, mention_emb], - all the valid mentions of the document, - can be on a different device - mentions_batch (torch.Tensor): [batch_size, mention_emb], - the mentions of the current batch, - is expected to be on the current device - pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb], - pairwise features of the current batch, - is expected to be on the current device - top_indices_batch (torch.Tensor): [batch_size, n_ants], - indices of antecedents of each mention + all_mentions: (n_mentions x mention_emb), + all the valid mentions of the document, + can be on a different device + mentions_batch: (batch_size x mention_emb), + the mentions of the current batch. + pw_batch: (batch_size x roughk x pw_emb), + pairwise distance features of the current batch. + top_indices_batch: (batch_size x n_ants), + indices of antecedents of each mention Returns: - torch.Tensor: [batch_size, n_ants, pair_emb] + out: pairwise features (batch_size x n_ants x pair_emb) """ emb_size = mentions_batch.shape[1] n_ants = pw_batch.shape[1] @@ -322,16 +285,15 @@ class AnaphoricityScorer(torch.nn.Module): class RoughScorer(torch.nn.Module): """ - Is needed to give a roughly estimate of the anaphoricity of two candidates, - only top scoring candidates are considered on later steps to reduce - computational complexity. + Cheaper module that gives a rough estimate of the anaphoricity of two + candidates, only top scoring candidates are considered on later + steps to reduce computational cost. """ def __init__(self, features: int, dropout_rate: float, rough_k: float): super().__init__() self.dropout = torch.nn.Dropout(dropout_rate) self.bilinear = torch.nn.Linear(features, features) - self.k = rough_k def forward( @@ -348,29 +310,27 @@ class RoughScorer(torch.nn.Module): pair_mask = torch.log((pair_mask > 0).to(torch.float)) bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T) rough_scores = pair_mask + bilinear_scores - - return self._prune(rough_scores) - - def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Selects top-k rough antecedent scores for each mention. - - Args: - rough_scores: tensor of shape [n_mentions, n_mentions], containing - rough antecedent scores of each mention-antecedent pair. - - Returns: - FloatTensor of shape [n_mentions, k], top rough scores - LongTensor of shape [n_mentions, k], top indices - """ top_scores, indices = torch.topk( rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False ) + return top_scores, indices class DistancePairwiseEncoder(torch.nn.Module): def __init__(self, embedding_size, dropout_rate): + """ + Takes the top_indices indicating, which is a ranked + list for each word and its most likely corresponding + anaphora candidates. For each of these pairs it looks + up a distance embedding from a table, where the distance + corresponds to the log-distance. + + embedding_size: int, + Dimensionality of the distance-embeddings table. + dropout_rate: float, + Dropout probability. + """ super().__init__() emb_size = embedding_size self.distance_emb = torch.nn.Embedding(9, emb_size) @@ -378,11 +338,12 @@ class DistancePairwiseEncoder(torch.nn.Module): self.shape = emb_size def forward( - self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch - top_indices: torch.Tensor, + self, + top_indices: torch.Tensor ) -> torch.Tensor: word_ids = torch.arange(0, top_indices.size(0)) - distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1) + distance = (word_ids.unsqueeze(1) - word_ids[top_indices] + ).clamp_min_(min=1) log_distance = distance.to(torch.float).log2().floor_() log_distance = log_distance.clamp_max_(max=6).to(torch.long) distance = torch.where(distance < 5, distance - 1, log_distance + 2) diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 779aa8c1e..b990b4019 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -3,7 +3,7 @@ import torch from thinc.api import Model, chain, tuplify from thinc.api import PyTorchWrapper, ArgsKwargs -from thinc.types import Floats2d, Ints1d, Ints2d +from thinc.types import Floats2d, Ints1d from thinc.util import xp2torch, torch2xp from ...tokens import Doc @@ -40,10 +40,9 @@ def convert_span_predictor_inputs( model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool ): tok2vec, (sent_ids, head_ids) = X - # Normally we shoudl use the input is_train, but for these two it's not relevant + # Normally we should use the input is_train, but for these two it's not relevant def backprop(args: ArgsKwargs) -> List[Floats2d]: - # convert to xp and wrap in list gradients = torch2xp(args.args[1]) return [[gradients], None] @@ -55,7 +54,6 @@ def convert_span_predictor_inputs( head_ids = xp2torch(head_ids[0], requires_grad=False) argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) - # TODO actually support backprop return argskwargs, backprop @@ -66,15 +64,13 @@ def predict_span_clusters( """ Predicts span clusters based on the word clusters. - Args: - doc (Doc): the document data - words (torch.Tensor): [n_words, emb_size] matrix containing - embeddings for each of the words in the text - clusters (List[List[int]]): a list of clusters where each cluster - is a list of word indices + span_predictor: a SpanPredictor instance + sent_ids: For each word indicates, which sentence it appears in. + words: Features for words. + clusters: Clusters inferred by the CorefScorer. Returns: - List[List[Span]]: span clusters + List[List[Tuple[int, int]]: span clusters """ if not clusters: return [] @@ -141,29 +137,29 @@ class SpanPredictor(torch.nn.Module): # this use of dist_emb_size looks wrong but it was 64...? torch.nn.Linear(256, dist_emb_size), ) + # TODO make the Convs also parametrizeable self.conv = torch.nn.Sequential( torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1) ) + # TODO make embeddings size a parameter self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far def forward( - self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch + self, sent_id, words: torch.Tensor, heads_ids: torch.Tensor, ) -> torch.Tensor: """ - Calculates span start/end scores of words for each span head in - heads_ids + Calculates span start/end scores of words for each span + for each head. - Args: - doc (Doc): the document data - words (torch.Tensor): contextual embeddings for each word in the - document, [n_words, emb_size] - heads_ids (torch.Tensor): word indices of span heads + sent_id: Sentence id of each word. + words: features for each word in the document. + heads_ids: word indices of span heads Returns: - torch.Tensor: span start/end scores, [n_heads, n_words, 2] + torch.Tensor: span start/end scores, (n_heads x n_words x 2) """ # If we don't receive heads, return empty if heads_ids.nelement() == 0: @@ -176,13 +172,13 @@ class SpanPredictor(torch.nn.Module): emb_ids = relative_positions + 63 # "too_far" emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127 - # Obtain "same sentence" boolean mask, [n_heads, n_words] + # Obtain "same sentence" boolean mask: (n_heads x n_words) heads_ids = heads_ids.long() same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0) # To save memory, only pass candidates from one sentence for each head # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb # for each candidate among the words in the same sentence as span_head - # [n_heads, input_size * 2 + distance_emb_size] + # (n_heads x input_size * 2 x distance_emb_size) rows, cols = same_sent.nonzero(as_tuple=True) pair_matrix = torch.cat( ( @@ -194,17 +190,17 @@ class SpanPredictor(torch.nn.Module): ) lengths = same_sent.sum(dim=1) padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0) - padding_mask = padding_mask < lengths.unsqueeze(1) # [n_heads, max_sent_len] - # [n_heads, max_sent_len, input_size * 2 + distance_emb_size] + # (n_heads x max_sent_len) + padding_mask = padding_mask < lengths.unsqueeze(1) + # (n_heads x max_sent_len x input_size * 2 + distance_emb_size) # This is necessary to allow the convolution layer to look at several # word scores padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1]) padded_pairs[padding_mask] = pair_matrix - - res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output] + res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output) res = self.conv(res.permute(0, 2, 1)).permute( 0, 2, 1 - ) # [n_heads, n_candidates, 2] + ) # (n_heads x n_candidates, 2) scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf")) scores[rows, cols] = res[padding_mask] diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index dcc4434ca..5237788cc 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -350,9 +350,7 @@ class CoreferenceResolver(TrainablePipe): def score(self, examples, **kwargs): """Score a batch of examples using LEA. - For details on how LEA works and why to use it see the paper: - Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric Moosavi and Strube, 2016 https://api.semanticscholar.org/CorpusID:17606580