Split span predictor model into its own file

Paul O'Leary McCann 2022-05-10 19:08:21 +09:00
parent f852c5cea4
commit 41fc092674
3 changed files with 216 additions and 200 deletions


@@ -1,4 +1,5 @@
from .coref import * # noqa
from .span_predictor import * # noqa
from .entity_linker import * # noqa
from .multi_task import * # noqa
from .parser import * # noqa


@@ -64,30 +64,6 @@ def build_wl_coref_model(
return coref_model
@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
tok2vec: Model[List[Doc], List[Floats2d]],
hidden_size: int = 1024,
dist_emb_size: int = 64,
):
# TODO fix this
try:
dim = tok2vec.get_dim("nO")
except ValueError:
# happens with transformer listener
dim = 768
with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = PyTorchWrapper(
SpanPredictor(dim, hidden_size, dist_emb_size),
convert_inputs=convert_span_predictor_inputs,
)
# TODO use proper parameter for prefix
head_info = build_get_head_metadata("coref_head_clusters")
model = (tok2vec & head_info) >> span_predictor
return model
def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
# The input here is List[Floats2d], one for each doc
@@ -120,61 +96,6 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
return (scores_xp, indices_xp), convert_for_torch_backward
def convert_span_predictor_inputs(
model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
):
tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but it's not relevant for sent_ids and head_ids
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[1])
return [[gradients], None]
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
if not head_ids[0].size:
head_ids = torch.empty(size=(0,))
else:
head_ids = xp2torch(head_ids[0], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
# TODO actually support backprop
return argskwargs, backprop
# TODO This probably belongs in the component, not the model.
def predict_span_clusters(
span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
):
"""
Predicts span clusters based on the word clusters.
Args:
doc (Doc): the document data
words (torch.Tensor): [n_words, emb_size] matrix containing
embeddings for each of the words in the text
clusters (List[List[int]]): a list of clusters where each cluster
is a list of word indices
Returns:
List[List[Span]]: span clusters
"""
if not clusters:
return []
xp = span_predictor.ops.xp
heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
scores = span_predictor.predict((sent_ids, words, heads_ids))
starts = scores[:, :, 0].argmax(axis=1).tolist()
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
head2span = {
head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
}
return [[head2span[head] for head in cluster] for cluster in clusters]
# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
@@ -205,36 +126,6 @@ def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
return sorted(clusters)
def build_get_head_metadata(prefix):
# TODO this name is awful, fix it
model = Model(
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
)
return model
def head_data_forward(model, docs, is_train):
"""A layer to generate the extra data needed for the span predictor."""
sent_ids = []
head_ids = []
prefix = model.attrs["prefix"]
for doc in docs:
sids = model.ops.asarray2i(get_sentence_ids(doc))
sent_ids.append(sids)
heads = []
for key, sg in doc.spans.items():
if not key.startswith(prefix):
continue
for span in sg:
# TODO warn if spans are more than one token
heads.append(span[0].i)
heads = model.ops.asarray2i(heads)
head_ids.append(heads)
# each of these is a list with one entry per doc
# backprop is just a placeholder
# TODO it would probably be better to have a list of tuples than two lists of arrays
return (sent_ids, head_ids), lambda x: []
class CorefScorer(torch.nn.Module):
"""Combines all coref modules together to find coreferent spans.
@@ -481,97 +372,6 @@ class RoughScorer(torch.nn.Module):
return top_scores, indices
class SpanPredictor(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
super().__init__()
# input size = single token size
# 64 = probably distance emb size
# TODO check that dist_emb_size use is correct
self.ffnn = torch.nn.Sequential(
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
# TODO seems weird the 256 isn't a parameter???
torch.nn.Linear(hidden_size, 256),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
# this use of dist_emb_size looks wrong but it was 64...?
torch.nn.Linear(256, dist_emb_size),
)
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
)
self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
sent_id,
words: torch.Tensor,
heads_ids: torch.Tensor,
) -> torch.Tensor:
"""
Calculates span start/end scores of words for each span head in
heads_ids
Args:
doc (Doc): the document data
words (torch.Tensor): contextual embeddings for each word in the
document, [n_words, emb_size]
heads_ids (torch.Tensor): word indices of span heads
Returns:
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
"""
# If we don't receive heads, return empty
if heads_ids.nelement() == 0:
return torch.empty(size=(0,))
# Obtain distance embedding indices, [n_heads, n_words]
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
words.shape[0]
).unsqueeze(0)
# make all valid distances positive
emb_ids = relative_positions + 63
# "too_far"
emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127
# Obtain "same sentence" boolean mask, [n_heads, n_words]
heads_ids = heads_ids.long()
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
# To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head
# [n_heads, input_size * 2 + distance_emb_size]
rows, cols = same_sent.nonzero(as_tuple=True)
pair_matrix = torch.cat(
(
words[heads_ids[rows]],
words[cols],
self.emb(emb_ids[rows, cols]),
),
dim=1,
)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = padding_mask < lengths.unsqueeze(1) # [n_heads, max_sent_len]
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several
# word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
padded_pairs[padding_mask] = pair_matrix
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(
0, 2, 1
) # [n_heads, n_candidates, 2]
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference
if not self.training:
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
valid_ends = torch.log((relative_positions <= 0).to(torch.float))
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
return scores + valid_positions
return scores
class DistancePairwiseEncoder(torch.nn.Module):


@@ -0,0 +1,215 @@
from typing import List, Tuple
import torch
from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.util import xp2torch, torch2xp
from ...tokens import Doc
from ...util import registry
from .coref_util import get_sentence_ids
@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
tok2vec: Model[List[Doc], List[Floats2d]],
hidden_size: int = 1024,
dist_emb_size: int = 64,
):
# TODO fix this
try:
dim = tok2vec.get_dim("nO")
except ValueError:
# happens with transformer listener
dim = 768
with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = PyTorchWrapper(
SpanPredictor(dim, hidden_size, dist_emb_size),
convert_inputs=convert_span_predictor_inputs,
)
# TODO use proper parameter for prefix
head_info = build_get_head_metadata("coref_head_clusters")
model = (tok2vec & head_info) >> span_predictor
return model
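# --- editorial sketch, not part of this commit -----------------------------
# A minimal illustration of resolving the architecture registered above.
# Building the full model also needs a real tok2vec layer, which is omitted
# here; any Model[List[Doc], List[Floats2d]] would do.
from spacy.util import registry as spacy_registry

make_span_predictor = spacy_registry.architectures.get("spacy.SpanPredictor.v1")
# model = make_span_predictor(tok2vec, hidden_size=1024, dist_emb_size=64)
# ----------------------------------------------------------------------------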
def convert_span_predictor_inputs(
model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
):
tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but it's not relevant for sent_ids and head_ids
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[1])
return [[gradients], None]
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
if not head_ids[0].size:
head_ids = torch.empty(size=(0,))
else:
head_ids = xp2torch(head_ids[0], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
# TODO actually support backprop
return argskwargs, backprop
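# --- editorial sketch, not part of this commit -----------------------------
# The nested input structure expected above, shown with one five-token dummy
# doc and an assumed embedding width of 64. The model argument is not used
# by this converter, so None stands in for it here.
import numpy

dummy_X = (
    [numpy.zeros((5, 64), dtype="float32")],  # tok2vec output, one per doc
    (
        [numpy.zeros(5, dtype="int32")],  # sentence id for each token
        [numpy.asarray([2], dtype="int32")],  # head token indices
    ),
)
dummy_args, dummy_backprop = convert_span_predictor_inputs(None, dummy_X, False)
# ----------------------------------------------------------------------------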
# TODO This probably belongs in the component, not the model.
def predict_span_clusters(
span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
):
"""
Predicts span clusters based on the word clusters.
Args:
doc (Doc): the document data
words (torch.Tensor): [n_words, emb_size] matrix containing
embeddings for each of the words in the text
clusters (List[List[int]]): a list of clusters where each cluster
is a list of word indices
Returns:
List[List[Span]]: span clusters
"""
if not clusters:
return []
xp = span_predictor.ops.xp
heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
scores = span_predictor.predict((sent_ids, words, heads_ids))
starts = scores[:, :, 0].argmax(axis=1).tolist()
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
head2span = {
head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
}
return [[head2span[head] for head in cluster] for cluster in clusters]
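# --- editorial sketch, not part of this commit -----------------------------
# A toy version of the head -> (start, end) expansion performed above. The
# +1 on the argmax makes the end exclusive, so head 3 with argmax start 2
# and argmax end 4 yields the span (2, 5).
demo_heads, demo_starts, demo_ends = [3, 7], [2, 7], [5, 8]
demo_head2span = dict(zip(demo_heads, zip(demo_starts, demo_ends)))
demo_clusters = [[3, 7]]
demo_spans = [[demo_head2span[h] for h in c] for c in demo_clusters]
assert demo_spans == [[(2, 5), (7, 8)]]
# ----------------------------------------------------------------------------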
# TODO this should maybe have a different name from the component
class SpanPredictor(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
super().__init__()
# input size = single token size
# 64 = probably distance emb size
# TODO check that dist_emb_size use is correct
self.ffnn = torch.nn.Sequential(
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
# TODO seems weird the 256 isn't a parameter???
torch.nn.Linear(hidden_size, 256),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
# this use of dist_emb_size looks wrong but it was 64...?
torch.nn.Linear(256, dist_emb_size),
)
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
)
self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
sent_id,
words: torch.Tensor,
heads_ids: torch.Tensor,
) -> torch.Tensor:
"""
Calculates span start/end scores of words for each span head in
heads_ids
Args:
doc (Doc): the document data
words (torch.Tensor): contextual embeddings for each word in the
document, [n_words, emb_size]
heads_ids (torch.Tensor): word indices of span heads
Returns:
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
"""
# If we don't receive heads, return empty
if heads_ids.nelement() == 0:
return torch.empty(size=(0,))
# Obtain distance embedding indices, [n_heads, n_words]
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
words.shape[0]
).unsqueeze(0)
# make all valid distances positive
emb_ids = relative_positions + 63
# "too_far"
emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127
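# e.g. a head at index 5 scoring a candidate word at index 70: the
# relative position is 5 - 70 = -65, giving emb_id -2, which is clamped
# to the "too_far" bucket 127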
# Obtain "same sentence" boolean mask, [n_heads, n_words]
heads_ids = heads_ids.long()
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
# To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head
# [n_heads, input_size * 2 + distance_emb_size]
rows, cols = same_sent.nonzero(as_tuple=True)
pair_matrix = torch.cat(
(
words[heads_ids[rows]],
words[cols],
self.emb(emb_ids[rows, cols]),
),
dim=1,
)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = padding_mask < lengths.unsqueeze(1) # [n_heads, max_sent_len]
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several
# word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
padded_pairs[padding_mask] = pair_matrix
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(
0, 2, 1
) # [n_heads, n_candidates, 2]
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference
if not self.training:
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
valid_ends = torch.log((relative_positions <= 0).to(torch.float))
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
return scores + valid_positions
return scores
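# --- editorial sketch, not part of this commit -----------------------------
# A shape smoke test under assumed sizes: 10 words, 64-dim embeddings, two
# heads, a single sentence. dist_emb_size must stay 64 here to match the
# hardcoded Conv1d input channels above.
predictor = SpanPredictor(input_size=64, hidden_size=1024, dist_emb_size=64)
predictor.eval()  # exercises the start <= head <= end masking branch
demo_sent_id = torch.zeros(10, dtype=torch.long)
demo_words = torch.randn(10, 64)
demo_scores = predictor(demo_sent_id, demo_words, torch.tensor([2, 7]))
assert demo_scores.shape == (2, 10, 2)  # [n_heads, n_words, 2]
# ----------------------------------------------------------------------------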
def build_get_head_metadata(prefix):
# TODO this name is awful, fix it
model = Model(
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
)
return model
def head_data_forward(model, docs, is_train):
"""A layer to generate the extra data needed for the span predictor."""
sent_ids = []
head_ids = []
prefix = model.attrs["prefix"]
for doc in docs:
sids = model.ops.asarray2i(get_sentence_ids(doc))
sent_ids.append(sids)
heads = []
for key, sg in doc.spans.items():
if not key.startswith(prefix):
continue
for span in sg:
# TODO warn if spans are more than one token
heads.append(span[0].i)
heads = model.ops.asarray2i(heads)
head_ids.append(heads)
# each of these is a list with one entry per doc
# backprop is just a placeholder
# TODO it would probably be better to have a list of tuples than two lists of arrays
return (sent_ids, head_ids), lambda x: []
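# --- editorial sketch, not part of this commit -----------------------------
# Hypothetical end-to-end use of the metadata layer. The span key, text and
# head choices are invented; the sentencizer is added only so that sentence
# ids can be derived by get_sentence_ids.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")
doc = nlp("Sarah told us she would be late.")
doc.spans["coref_head_clusters_1"] = [doc[0:1], doc[3:4]]
head_info = build_get_head_metadata("coref_head_clusters")
(demo_sent_ids, demo_head_ids), _ = head_info([doc], is_train=False)
# demo_sent_ids[0]: one sentence id per token; demo_head_ids[0]: [0, 3]
# ----------------------------------------------------------------------------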