Merge pull request #10782 from kadarakos/feature/coref

Feature/coref
Paul O'Leary McCann 2022-05-11 17:02:21 +09:00 committed by GitHub
commit 57165f9631
3 changed files with 111 additions and 156 deletions

Changed file 1 of 3

@@ -1,14 +1,14 @@
 from typing import List, Tuple
 import torch
-from thinc.api import Model, chain, tuplify
+from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import Floats2d, Ints2d, Ints1d
 from thinc.util import xp2torch, torch2xp

 from ...tokens import Doc
 from ...util import registry
-from .coref_util import add_dummy, get_sentence_ids
+from .coref_util import add_dummy


 @registry.architectures("spacy.Coref.v1")
@@ -19,7 +19,6 @@ def build_wl_coref_model(
     n_hidden_layers: int = 1,  # TODO rename to "depth"?
     dropout: float = 0.3,
     # pairs to keep per mention after rough scoring
-    # TODO change to meaningful name
     rough_k: int = 50,
     # TODO is this not a training loop setting?
     a_scoring_batch_size: int = 512,
@@ -34,7 +33,6 @@ def build_wl_coref_model(
     dim = 768

     with Model.define_operators({">>": chain}):
-        # TODO chain tok2vec with these models
         coref_scorer = PyTorchWrapper(
             CorefScorer(
                 dim,
@@ -49,22 +47,14 @@ def build_wl_coref_model(
             convert_outputs=convert_coref_scorer_outputs,
         )
         coref_model = tok2vec >> coref_scorer
-        # XXX just ignore this until the coref scorer is integrated
-        # span_predictor = PyTorchWrapper(
-        #     SpanPredictor(
-        #         # TODO this was hardcoded to 1024, check
-        #         hidden_size,
-        #         sp_embedding_size,
-        #     ),
-        #     convert_inputs=convert_span_predictor_inputs
-        # )
-        # TODO combine models so output is uniform (just one forward pass)
-        # It may be reasonable to have an option to disable span prediction,
-        # and just return words as spans.
     return coref_model


-def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_scorer_inputs(
+    model: Model,
+    X: List[Floats2d],
+    is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -76,7 +66,7 @@ def convert_coref_scorer_inputs(
         gradients = torch2xp(args.args[0])
         return [gradients]

-    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop


 def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
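For readers unfamiliar with thinc's wrapper API, the two conversion functions above follow the standard `PyTorchWrapper` contract: `convert_inputs` returns an `ArgsKwargs` for the torch module plus a callback that maps torch gradients back to `xp` arrays. The sketch below shows that wiring on a stand-in `torch.nn.Linear`; the stand-in module and names are illustrative, not part of this PR.

```python
import torch
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.util import xp2torch, torch2xp


def convert_inputs(model, X, is_train: bool):
    # X is a list of 2d arrays, one per doc; only the first is used,
    # mirroring convert_coref_scorer_inputs above.
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # Convert the torch gradient of the first positional arg back to xp.
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


# Hypothetical stand-in for CorefScorer, just to show where the hook plugs in.
wrapped = PyTorchWrapper(torch.nn.Linear(768, 768), convert_inputs=convert_inputs)
# coref_model = tok2vec >> wrapped   # composed with chain, as in build_wl_coref_model
```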
@@ -95,46 +85,13 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
     return (scores_xp, indices_xp), convert_for_torch_backward


-# TODO add docstring for this, maybe move to utils.
-# This might belong in the component.
-def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
-    xp = model.ops.xp
-    antecedents = scores.argmax(axis=1) - 1
-    not_dummy = antecedents >= 0
-    coref_span_heads = xp.arange(0, len(scores))[not_dummy]
-    antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
-    n_words = scores.shape[0]
-    nodes = [GraphNode(i) for i in range(n_words)]
-    for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
-        nodes[i].link(nodes[j])
-        assert nodes[i] is not nodes[j]
-
-    clusters = []
-    for node in nodes:
-        if len(node.links) > 0 and not node.visited:
-            cluster = []
-            stack = [node]
-            while stack:
-                current_node = stack.pop()
-                current_node.visited = True
-                cluster.append(current_node.id)
-                stack.extend(link for link in current_node.links if not link.visited)
-            assert len(cluster) > 1
-            clusters.append(sorted(cluster))
-    return sorted(clusters)
-
-
 class CorefScorer(torch.nn.Module):
-    """Combines all coref modules together to find coreferent spans.
-
-    Attributes:
-        epochs_trained (int): number of epochs the model has been trained for
-
+    """
+    Combines all coref modules together to find coreferent token pairs.
     Submodules (in the order of their usage in the pipeline):
-        rough_scorer (RoughScorer)
-        pw (PairwiseEncoder)
-        a_scorer (AnaphoricityScorer)
-        sp (SpanPredictor)
+    - rough_scorer (RoughScorer) that prunes candidate pairs
+    - pw (DistancePairwiseEncoder) that computes pairwise features
+    - a_scorer (AnaphoricityScorer) produces the final scores
     """

     def __init__(
@@ -149,50 +106,64 @@ class CorefScorer(torch.nn.Module):
     ):
         super().__init__()
         """
-        A newly created model is set to evaluation mode.
-
-        Args:
-            epochs_trained (int): the number of epochs finished
-                (useful for warm start)
+        dim: Size of the input features.
+        dist_emb_size: Size of the distance embeddings.
+        hidden_size: Size of the coreference candidate embeddings.
+        n_layers: Numbers of layers in the AnaphoricityScorer.
+        dropout_rate: Dropout probability to apply across all modules.
+        roughk: Number of candidates the RoughScorer returns.
+        batch_size: Internal batch-size for the more expensive scorer.
         """
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.batch_size = batch_size
+        # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
-        # TODO clean this up
-        bert_emb = dim
-        pair_emb = bert_emb * 3 + self.pw.shape
+        pair_emb = dim * 3 + self.pw.shape
+        self.a_scorer = AnaphoricityScorer(
+            pair_emb,
+            hidden_size,
+            n_layers,
+            dropout_rate
+        )
+        self.lstm = torch.nn.LSTM(
+            input_size=dim,
+            hidden_size=dim,
+            batch_first=True,
+        )
+        self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
+        pair_emb = dim * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
             pair_emb, hidden_size, n_layers, dropout_rate
         )
-        self.lstm = torch.nn.LSTM(
-            input_size=bert_emb,
-            hidden_size=bert_emb,
-            batch_first=True,
-        )
-        self.dropout = torch.nn.Dropout(dropout_rate)
-        self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
-        self.batch_size = batch_size

-    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, word_features: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        This is a massive method, but it made sense to me to not split it into
-        several ones to let one see the data flow.
+        1. LSTM encodes the incoming word_features.
+        2. The RoughScorer scores and prunes the candidates.
+        3. The DistancePairwiseEncoder embeds the distances between pairs.
+        4. The AnaphoricityScorer scores all pairs in mini-batches.

-        Args:
-            word_features: torch.Tensor containing word encodings
+        word_features: torch.Tensor containing word encodings

-        Returns:
-            coreference scores and top indices
+        returns:
+            coref_scores: n_words x roughk floats.
+            top_indices: n_words x roughk integers.
         """
-        # words           [n_words, span_emb]
-        # cluster_ids     [n_words]
         self.lstm.flatten_parameters()  # XXX without this there's a warning
         word_features = torch.unsqueeze(word_features, dim=0)
         words, _ = self.lstm(word_features)
         words = words.squeeze()
+        # words: n_words x dim
         words = self.dropout(words)
         # Obtain bilinear scores and leave only top-k antecedents for each word
-        # top_rough_scores  [n_words, n_ants]
-        # top_indices       [n_words, n_ants]
+        # top_rough_scores: (n_words x roughk)
+        # top_indices: (n_words x roughk)
         top_rough_scores, top_indices = self.rough_scorer(words)
-        # Get pairwise features [n_words, n_ants, n_pw_features]
+        # Get pairwise features
+        # (n_words x roughk x n_pw_features)
         pw = self.pw(top_indices)
         batch_size = self.batch_size
         a_scores_lst: List[torch.Tensor] = []
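The four numbered steps in the new forward() docstring describe a coarse-to-fine scorer. As a rough, self-contained illustration of the same flow (toy weights and made-up sizes, not the module above): bilinear rough scores are pruned to `roughk` antecedents per word, pair embeddings are built for the survivors, and a small feed-forward layer refines the kept scores.

```python
import torch


def coarse_to_fine_scores(words: torch.Tensor, roughk: int = 3):
    """Tiny stand-in for the forward pass above: bilinear pruning followed by
    an exact pairwise scorer over the surviving candidates (toy weights)."""
    n_words, dim = words.shape
    # 2. rough bilinear scores, keep roughk antecedents per word
    rough = words @ words.T
    mask = torch.log((torch.arange(n_words).unsqueeze(1) > torch.arange(n_words)).float())
    top_rough, top_idx = torch.topk(rough + mask, k=min(roughk, n_words), dim=1)
    # 3./4. build (mention, antecedent, elementwise product) pair embeddings
    ants = words[top_idx]                          # n_words x roughk x dim
    ments = words.unsqueeze(1).expand_as(ants)     # n_words x roughk x dim
    pair_emb = torch.cat((ments, ants, ments * ants), dim=-1)
    ffnn = torch.nn.Linear(dim * 3, 1)
    return ffnn(pair_emb).squeeze(-1) + top_rough, top_idx  # n_words x roughk


scores, idx = coarse_to_fine_scores(torch.randn(10, 8))
print(scores.shape, idx.shape)  # torch.Size([10, 3]) twice
```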
@@ -272,13 +243,8 @@ class AnaphoricityScorer(torch.nn.Module):
     def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Calculates anaphoricity scores.
-
-        Args:
-            x: tensor of shape [batch_size, n_ants, n_features]
-
-        Returns:
-            tensor of shape [batch_size, n_ants]
+        x: tensor of shape (batch_size x roughk x n_features)
+        returns: tensor of shape (batch_size x rough_k)
         """
         x = self.out(self.hidden(x))
         return x.squeeze(2)
@@ -293,21 +259,18 @@ class AnaphoricityScorer(torch.nn.Module):
         """
         Builds the matrix used as input for AnaphoricityScorer.

-        Args:
-            all_mentions (torch.Tensor): [n_mentions, mention_emb],
-                all the valid mentions of the document,
-                can be on a different device
-            mentions_batch (torch.Tensor): [batch_size, mention_emb],
-                the mentions of the current batch,
-                is expected to be on the current device
-            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
-                pairwise features of the current batch,
-                is expected to be on the current device
-            top_indices_batch (torch.Tensor): [batch_size, n_ants],
-                indices of antecedents of each mention
+        all_mentions: (n_mentions x mention_emb),
+            all the valid mentions of the document,
+            can be on a different device
+        mentions_batch: (batch_size x mention_emb),
+            the mentions of the current batch.
+        pw_batch: (batch_size x roughk x pw_emb),
+            pairwise distance features of the current batch.
+        top_indices_batch: (batch_size x n_ants),
+            indices of antecedents of each mention

         Returns:
-            torch.Tensor: [batch_size, n_ants, pair_emb]
+            out: pairwise features (batch_size x n_ants x pair_emb)
         """
         emb_size = mentions_batch.shape[1]
         n_ants = pw_batch.shape[1]
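For concreteness, the pair matrix described above can be built by gathering antecedent embeddings via `top_indices_batch`, broadcasting the mention embeddings, and concatenating an elementwise-product interaction term plus the distance features. The function below is a plausible sketch of that construction (not necessarily the exact body of `_get_pair_matrix`); it also shows where the `pair_emb = dim * 3 + self.pw.shape` width in `CorefScorer.__init__` comes from.

```python
import torch


def get_pair_matrix(all_mentions, mentions_batch, pw_batch, top_indices_batch):
    """Sketch of the (batch_size x n_ants x pair_emb) input described above:
    head emb, antecedent emb, their product, plus distance features."""
    emb_size = mentions_batch.shape[1]
    n_ants = pw_batch.shape[1]
    a_mentions = all_mentions[top_indices_batch]                    # antecedent embeddings
    b_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
    similarity = a_mentions * b_mentions                            # elementwise interaction
    return torch.cat((b_mentions, a_mentions, similarity, pw_batch), dim=2)


out = get_pair_matrix(torch.randn(20, 8), torch.randn(4, 8),
                      torch.randn(4, 5, 3), torch.randint(0, 20, (4, 5)))
print(out.shape)  # torch.Size([4, 5, 27]): 8 * 3 + 3 = pair_emb
```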
@@ -322,16 +285,15 @@ class AnaphoricityScorer(torch.nn.Module):
 class RoughScorer(torch.nn.Module):
     """
-    Is needed to give a roughly estimate of the anaphoricity of two candidates,
-    only top scoring candidates are considered on later steps to reduce
-    computational complexity.
+    Cheaper module that gives a rough estimate of the anaphoricity of two
+    candidates, only top scoring candidates are considered on later
+    steps to reduce computational cost.
     """

     def __init__(self, features: int, dropout_rate: float, rough_k: float):
         super().__init__()
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.bilinear = torch.nn.Linear(features, features)
         self.k = rough_k

     def forward(
@@ -348,29 +310,27 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
-        return self._prune(rough_scores)
-
-    def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Selects top-k rough antecedent scores for each mention.
-
-        Args:
-            rough_scores: tensor of shape [n_mentions, n_mentions], containing
-                rough antecedent scores of each mention-antecedent pair.
-
-        Returns:
-            FloatTensor of shape [n_mentions, k], top rough scores
-            LongTensor of shape [n_mentions, k], top indices
-        """
         top_scores, indices = torch.topk(
             rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
         )
         return top_scores, indices


 class DistancePairwiseEncoder(torch.nn.Module):
     def __init__(self, embedding_size, dropout_rate):
+        """
+        Takes the top_indices indicating, which is a ranked
+        list for each word and its most likely corresponding
+        anaphora candidates. For each of these pairs it looks
+        up a distance embedding from a table, where the distance
+        corresponds to the log-distance.
+
+        embedding_size: int,
+            Dimensionality of the distance-embeddings table.
+        dropout_rate: float,
+            Dropout probability.
+        """
         super().__init__()
         emb_size = embedding_size
         self.distance_emb = torch.nn.Embedding(9, emb_size)
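One detail worth calling out in the RoughScorer code above is the `torch.log((pair_mask > 0).to(torch.float))` idiom: a boolean "is a valid antecedent" mask becomes an additive bias of 0 or -inf, so invalid pairs can never beat real candidates in `torch.topk`. A minimal standalone demonstration (sizes made up):

```python
import torch

# The log-of-boolean idiom used by RoughScorer: True -> 0.0, False -> -inf.
n = 6
valid = torch.arange(n).unsqueeze(1) > torch.arange(n).unsqueeze(0)  # only earlier words
bias = torch.log(valid.to(torch.float))
scores = torch.randn(n, n) + bias
top_scores, top_idx = torch.topk(scores, k=min(3, n), dim=1, sorted=False)
print(top_scores[1])  # row 1 has one real antecedent; the other picks stay at -inf
```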
@@ -378,11 +338,12 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.shape = emb_size

     def forward(
-        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
-        top_indices: torch.Tensor,
+        self,
+        top_indices: torch.Tensor
     ) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
+                    ).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
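The forward pass above maps each (word, candidate) offset to one of the 9 buckets backing `torch.nn.Embedding(9, emb_size)`: distances 1 to 4 keep their own bucket, and longer distances share log2 buckets capped at 6. A small standalone check of that mapping (same arithmetic as the lines above, written out of place rather than with the in-place ops):

```python
import torch


def distance_bucket(distance: torch.Tensor) -> torch.Tensor:
    """Map raw word distances to the 9 log-scale buckets used above (illustrative)."""
    distance = distance.clamp_min(1)
    log_distance = distance.to(torch.float).log2().floor_()
    log_distance = log_distance.clamp_max(6).to(torch.long)
    # distances 1-4 get their own bucket (0-3); larger ones share log2 buckets (4-8)
    return torch.where(distance < 5, distance - 1, log_distance + 2)


d = torch.tensor([1, 2, 4, 5, 8, 16, 64, 1000])
print(distance_bucket(d))  # tensor([0, 1, 3, 4, 5, 6, 8, 8])
```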

Changed file 2 of 3

@@ -3,7 +3,7 @@ import torch
 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import Floats2d, Ints1d
 from thinc.util import xp2torch, torch2xp

 from ...tokens import Doc
@@ -40,10 +40,9 @@ def convert_span_predictor_inputs(
     model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
 ):
     tok2vec, (sent_ids, head_ids) = X
-    # Normally we shoudl use the input is_train, but for these two it's not relevant
+    # Normally we should use the input is_train, but for these two it's not relevant
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
-        # convert to xp and wrap in list
         gradients = torch2xp(args.args[1])
         return [[gradients], None]
@@ -55,7 +54,6 @@ def convert_span_predictor_inputs(
     head_ids = xp2torch(head_ids[0], requires_grad=False)

     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
-    # TODO actually support backprop
     return argskwargs, backprop
@@ -66,15 +64,13 @@ def predict_span_clusters(
     """
     Predicts span clusters based on the word clusters.

-    Args:
-        doc (Doc): the document data
-        words (torch.Tensor): [n_words, emb_size] matrix containing
-            embeddings for each of the words in the text
-        clusters (List[List[int]]): a list of clusters where each cluster
-            is a list of word indices
+    span_predictor: a SpanPredictor instance
+    sent_ids: For each word indicates, which sentence it appears in.
+    words: Features for words.
+    clusters: Clusters inferred by the CorefScorer.

     Returns:
-        List[List[Span]]: span clusters
+        List[List[Tuple[int, int]]]: span clusters
     """
     if not clusters:
         return []
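The rest of this function is not shown in the hunk, but given the (n_heads x n_words x 2) scores that SpanPredictor.forward produces, one plausible way to turn them into spans is an argmax over the start and end score columns per head, roughly as sketched below. The real helper additionally restricts candidates to the head's sentence and maps token offsets back onto the Doc, so treat this purely as an illustration.

```python
from typing import List, Tuple
import torch


def heads_to_spans(scores: torch.Tensor) -> List[Tuple[int, int]]:
    """Toy conversion of per-head start/end scores into (start, end) offsets.

    `scores` is assumed to be (n_heads x n_words x 2); this is a sketch, not
    the exact body of predict_span_clusters."""
    starts = scores[:, :, 0].argmax(dim=1)      # best start token per head
    ends = scores[:, :, 1].argmax(dim=1) + 1    # best end token, exclusive
    return list(zip(starts.tolist(), ends.tolist()))


print(heads_to_spans(torch.randn(3, 10, 2)))  # e.g. [(2, 8), (0, 4), (7, 10)]
```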
@@ -141,29 +137,29 @@ class SpanPredictor(torch.nn.Module):
             # this use of dist_emb_size looks wrong but it was 64...?
             torch.nn.Linear(256, dist_emb_size),
         )
-        # TODO make the Convs also parametrizeable
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        # TODO make embeddings size a parameter
         self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

     def forward(
-        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+        self,
         sent_id,
         words: torch.Tensor,
         heads_ids: torch.Tensor,
     ) -> torch.Tensor:
         """
-        Calculates span start/end scores of words for each span head in
-        heads_ids
-
-        Args:
-            doc (Doc): the document data
-            words (torch.Tensor): contextual embeddings for each word in the
-                document, [n_words, emb_size]
-            heads_ids (torch.Tensor): word indices of span heads
+        Calculates span start/end scores of words for each span
+        for each head.
+
+        sent_id: Sentence id of each word.
+        words: features for each word in the document.
+        heads_ids: word indices of span heads

         Returns:
-            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
+            torch.Tensor: span start/end scores, (n_heads x n_words x 2)
         """
         # If we don't receive heads, return empty
         if heads_ids.nelement() == 0:
@@ -176,13 +172,13 @@ class SpanPredictor(torch.nn.Module):
         emb_ids = relative_positions + 63
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
-        # Obtain "same sentence" boolean mask, [n_heads, n_words]
+        # Obtain "same sentence" boolean mask: (n_heads x n_words)
+        heads_ids = heads_ids.long()
         same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
         # To save memory, only pass candidates from one sentence for each head
         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
         # for each candidate among the words in the same sentence as span_head
-        # [n_heads, input_size * 2 + distance_emb_size]
+        # (n_heads x input_size * 2 x distance_emb_size)
         rows, cols = same_sent.nonzero(as_tuple=True)
         pair_matrix = torch.cat(
             (
@@ -194,17 +190,17 @@ class SpanPredictor(torch.nn.Module):
         )
         lengths = same_sent.sum(dim=1)
         padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
-        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
-        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
+        # (n_heads x max_sent_len)
+        padding_mask = padding_mask < lengths.unsqueeze(1)
+        # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
         # This is necessary to allow the convolution layer to look at several
         # word scores
         padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
         padded_pairs[padding_mask] = pair_matrix
-        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
+        # (n_heads x n_candidates x last_layer_output)
+        res = self.ffnn(padded_pairs)
         res = self.conv(res.permute(0, 2, 1)).permute(
             0, 2, 1
-        )  # [n_heads, n_candidates, 2]
+        )  # (n_heads x n_candidates x 2)
         scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
         scores[rows, cols] = res[padding_mask]
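The padding-mask dance above is easy to get lost in: ragged per-head candidate lists are packed into a rectangular (n_heads x max_sent_len x feat) tensor, scored, and scattered back with -inf wherever a head had no candidate. A toy version of the same pattern (stand-in scoring function, made-up sizes):

```python
import torch

# Ragged candidate lists per head, packed with a boolean mask, then scattered back.
lengths = torch.tensor([2, 4, 1])                 # candidates per head
feats = torch.randn(int(lengths.sum()), 8)        # flattened candidate features
padding_mask = torch.arange(int(lengths.max())).unsqueeze(0) < lengths.unsqueeze(1)
padded = torch.zeros(*padding_mask.shape, feats.shape[-1])
padded[padding_mask] = feats                      # (3 x 4 x 8)
scored = padded.sum(dim=-1)                       # stand-in for ffnn + conv
out = torch.full(padding_mask.shape, float("-inf"))
out[padding_mask] = scored[padding_mask]          # -inf where there was no candidate
print(out)
```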

Changed file 3 of 3

@@ -350,9 +350,7 @@ class CoreferenceResolver(TrainablePipe):
     def score(self, examples, **kwargs):
         """Score a batch of examples using LEA.

-        For details on how LEA works and why to use it see the paper:
-
         Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
         Moosavi and Strube, 2016
         https://api.semanticscholar.org/CorpusID:17606580
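For reference, LEA weights each cluster by its size and scores the fraction of its coreference links that the other side reproduces. The snippet below is a minimal, illustrative implementation of that definition over clusters of mention ids; it is not the scorer the pipe actually calls, and the singleton handling (self-links) follows the paper's convention as I understand it.

```python
from itertools import combinations
from typing import List, Set, Tuple


def lea(key: List[Set[int]], response: List[Set[int]]) -> Tuple[float, float]:
    """Minimal LEA recall/precision over clusters of mention ids (illustrative)."""

    def links(cluster):
        # Singletons contribute a self-link so they are not ignored entirely.
        pairs = {frozenset(p) for p in combinations(sorted(cluster), 2)}
        return pairs or {frozenset(cluster)}

    def score(sys, gold):
        gold_links = {link for g in gold for link in links(g)}
        num = sum(len(c) * len(links(c) & gold_links) / len(links(c)) for c in sys)
        den = sum(len(c) for c in sys)
        return num / den if den else 0.0

    return score(key, response), score(response, key)  # recall, precision


print(lea([{0, 1, 2}, {3, 4}], [{0, 1}, {3, 4}]))  # (0.6, 1.0)
```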