small refactor and docs

kadarakos 2022-05-10 16:40:31 +00:00
parent 33f4f90ff0
commit e512874c80


@@ -1,14 +1,14 @@
from typing import List, Tuple
import torch
from thinc.api import Model, chain, tuplify
from thinc.api import Model, chain
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.types import Floats2d, Ints2d
from thinc.util import xp2torch, torch2xp
from ...tokens import Doc
from ...util import registry
from .coref_util import add_dummy, get_sentence_ids
from .coref_util import add_dummy
@registry.architectures("spacy.Coref.v1")
@@ -19,7 +19,6 @@ def build_wl_coref_model(
n_hidden_layers: int = 1, # TODO rename to "depth"?
dropout: float = 0.3,
# pairs to keep per mention after rough scoring
# TODO change to meaningful name
rough_k: int = 50,
# TODO is this not a training loop setting?
a_scoring_batch_size: int = 512,
@@ -34,7 +33,6 @@ def build_wl_coref_model(
dim = 768
with Model.define_operators({">>": chain}):
# TODO chain tok2vec with these models
coref_scorer = PyTorchWrapper(
CorefScorer(
dim,
@@ -49,18 +47,6 @@ def build_wl_coref_model(
convert_outputs=convert_coref_scorer_outputs,
)
coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
# span_predictor = PyTorchWrapper(
# SpanPredictor(
# TODO this was hardcoded to 1024, check
# hidden_size,
# sp_embedding_size,
# ),
# convert_inputs=convert_span_predictor_inputs
# )
# TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction,
# and just return words as spans.
return coref_model
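# Usage sketch (assumptions: spaCy's public registry API; parameters as in
# build_wl_coref_model above). The architecture registered above can be
# retrieved and called to build the tok2vec >> coref_scorer chain returned here:
#
# from spacy.util import registry
# make_coref = registry.architectures.get("spacy.Coref.v1")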
@@ -95,46 +81,13 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
return (scores_xp, indices_xp), convert_for_torch_backward
# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
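"""
Docstring sketch for the TODO above, inferred from the code below:
turns the final coreference scores and antecedent indices into clusters.
Words whose best-scoring antecedent is not the dummy get linked to that
antecedent, and each connected component of the resulting graph becomes
a cluster. Returns a sorted list of clusters of sorted token indices.
"""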
xp = model.ops.xp
antecedents = scores.argmax(axis=1) - 1
not_dummy = antecedents >= 0
coref_span_heads = xp.arange(0, len(scores))[not_dummy]
antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
n_words = scores.shape[0]
nodes = [GraphNode(i) for i in range(n_words)]
for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
nodes[i].link(nodes[j])
assert nodes[i] is not nodes[j]
clusters = []
for node in nodes:
if len(node.links) > 0 and not node.visited:
cluster = []
stack = [node]
while stack:
current_node = stack.pop()
current_node.visited = True
cluster.append(current_node.id)
stack.extend(link for link in current_node.links if not link.visited)
assert len(cluster) > 1
clusters.append(sorted(cluster))
return sorted(clusters)
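# Hand-worked example (illustrative): for three words where only word 2
# resolves to a real antecedent and top_indices maps that choice to word 0,
# nodes 0 and 2 get linked and the result is a single cluster: [[0, 2]].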
class CorefScorer(torch.nn.Module):
"""Combines all coref modules together to find coreferent spans.
Attributes:
epochs_trained (int): number of epochs the model has been trained for
"""
Combines all coref modules together to find coreferent token pairs.
Submodules (in the order of their usage in the pipeline):
rough_scorer (RoughScorer)
pw (PairwiseEncoder)
a_scorer (AnaphoricityScorer)
sp (SpanPredictor)
- rough_scorer (RoughScorer) that prunes candidate pairs
- pw (DistancePairwiseEncoder) that computes pairwise features
- a_scorer (AnaphoricityScorer) that produces the final scores
"""
def __init__(
@@ -149,50 +102,54 @@ class CorefScorer(torch.nn.Module):
):
super().__init__()
"""
A newly created model is set to evaluation mode.
Args:
epochs_trained (int): the number of epochs finished
(useful for warm start)
dim: Size of the input features.
dist_emb_size: Size of the distance embeddings.
hidden_size: Size of the coreference candidate embeddings.
n_layers: Number of layers in the AnaphoricityScorer.
dropout_rate: Dropout probability to apply across all modules.
roughk: Number of candidates the RoughScorer returns.
batch_size: Internal batch-size for the more expensive AnaphoricityScorer.
"""
self.dropout = torch.nn.Dropout(dropout_rate)
self.batch_size = batch_size
# Modules
self.lstm = torch.nn.LSTM(
input_size=dim,
hidden_size=dim,
batch_first=True,
)
self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
# TODO clean this up
bert_emb = dim
pair_emb = bert_emb * 3 + self.pw.shape
pair_emb = dim * 3 + self.pw.shape
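# dim * 3 presumably covers the concatenation of mention, antecedent and
# their elementwise product built in _get_pair_matrix; self.pw.shape adds
# the width of the distance features.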
self.a_scorer = AnaphoricityScorer(
pair_emb, hidden_size, n_layers, dropout_rate
)
self.lstm = torch.nn.LSTM(
input_size=bert_emb,
hidden_size=bert_emb,
batch_first=True,
)
self.dropout = torch.nn.Dropout(dropout_rate)
self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
self.batch_size = batch_size
def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a massive method, but it made sense to me to not split it into
several ones to let one see the data flow.
1. LSTM encodes the incoming word_features.
2. The RoughScorer scores and prunes the candidates.
3. The DistancePairwiseEncoder embeds the distance between remaining pairs.
4. The AnaphoricityScorer scores all pairs in mini-batches.
Args:
word_features: torch.Tensor containing word encodings
Returns:
coreference scores and top indices
Returns:
coref_scores: n_words x roughk floats.
top_indices: n_words x roughk integers.
"""
# words [n_words, span_emb]
# cluster_ids [n_words]
self.lstm.flatten_parameters() # XXX without this there's a warning
word_features = torch.unsqueeze(word_features, dim=0)
words, _ = self.lstm(word_features)
words = words.squeeze()
# words: n_words x dim
words = self.dropout(words)
# Obtain bilinear scores and leave only top-k antecedents for each word
# top_rough_scores [n_words, n_ants]
# top_indices [n_words, n_ants]
# top_rough_scores: (n_words x roughk)
# top_indices: (n_words x roughk)
top_rough_scores, top_indices = self.rough_scorer(words)
# Get pairwise features [n_words, n_ants, n_pw_features]
# Get pairwise features
# (n_words x roughk x n_pw_features)
pw = self.pw(top_indices)
batch_size = self.batch_size
a_scores_lst: List[torch.Tensor] = []
@@ -272,13 +229,8 @@ class AnaphoricityScorer(torch.nn.Module):
def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculates anaphoricity scores.
Args:
x: tensor of shape [batch_size, n_ants, n_features]
Returns:
tensor of shape [batch_size, n_ants]
x: tensor of shape (batch_size x roughk x n_features)
Returns: tensor of shape (batch_size x roughk)
"""
x = self.out(self.hidden(x))
return x.squeeze(2)
@@ -293,21 +245,18 @@ class AnaphoricityScorer(torch.nn.Module):
"""
Builds the matrix used as input for AnaphoricityScorer.
Args:
all_mentions (torch.Tensor): [n_mentions, mention_emb],
all_mentions: (n_mentions x mention_emb),
all the valid mentions of the document,
can be on a different device
mentions_batch (torch.Tensor): [batch_size, mention_emb],
the mentions of the current batch,
is expected to be on the current device
pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
pairwise features of the current batch,
is expected to be on the current device
top_indices_batch (torch.Tensor): [batch_size, n_ants],
mentions_batch: (batch_size x mention_emb),
the mentions of the current batch.
pw_batch: (batch_size x roughk x pw_emb),
pairwise distance features of the current batch.
top_indices_batch: (batch_size x n_ants),
indices of antecedents of each mention
Returns:
torch.Tensor: [batch_size, n_ants, pair_emb]
out: pairwise features (batch_size x n_ants x pair_emb)
"""
emb_size = mentions_batch.shape[1]
n_ants = pw_batch.shape[1]
@@ -322,16 +271,15 @@ class AnaphoricityScorer(torch.nn.Module):
class RoughScorer(torch.nn.Module):
"""
Is needed to give a roughly estimate of the anaphoricity of two candidates,
only top scoring candidates are considered on later steps to reduce
computational complexity.
Cheaper module that gives a rough estimate of the anaphoricity of two
candidates; only the top-scoring candidates are considered in later
steps to reduce computational cost.
"""
def __init__(self, features: int, dropout_rate: float, rough_k: int):
super().__init__()
self.dropout = torch.nn.Dropout(dropout_rate)
self.bilinear = torch.nn.Linear(features, features)
self.k = rough_k
def forward(
@@ -348,21 +296,6 @@ class RoughScorer(torch.nn.Module):
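# torch.log of the 0/1 mask gives 0 for valid pairs and -inf for invalid
# ones, so adding it to the bilinear scores removes invalid antecedents
# from the top-k; bilinear(mentions).mm(mentions.T) scores all pairs at once.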
pair_mask = torch.log((pair_mask > 0).to(torch.float))
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
rough_scores = pair_mask + bilinear_scores
return self._prune(rough_scores)
def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Selects top-k rough antecedent scores for each mention.
Args:
rough_scores: tensor of shape [n_mentions, n_mentions], containing
rough antecedent scores of each mention-antecedent pair.
Returns:
FloatTensor of shape [n_mentions, k], top rough scores
LongTensor of shape [n_mentions, k], top indices
"""
top_scores, indices = torch.topk(
rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
)
@@ -371,6 +304,18 @@ class RoughScorer(torch.nn.Module):
class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):
"""
Takes top_indices, which is a ranked list of the most likely
anaphora candidates for each word. For each of these pairs it
looks up a distance embedding from a table, where the distance
corresponds to the log-distance between the two words.
embedding_size: int,
Dimensionality of the distance-embeddings table.
dropout_rate: float,
Dropout probability.
"""
super().__init__()
emb_size = embedding_size
self.distance_emb = torch.nn.Embedding(9, emb_size)
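# 9 buckets: presumably short distances are kept exact and longer ones
# grouped on a log2 scale, per the log-distance scheme described above.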
@@ -378,7 +323,7 @@ class DistancePairwiseEncoder(torch.nn.Module):
self.shape = emb_size
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
self,
top_indices: torch.Tensor,
) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0))