clean up unused imports + black formatting

svlandeg 2022-05-09 13:34:50 +02:00
parent 683f470852
commit 6b51258a58


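The reformatting below is what black produces for this file (together with dropping imports the module no longer uses). As a minimal, illustrative sketch of the rule at work — black is assumed to be installed, and the snippet is not part of the commit — its format_str API reproduces one of the call-site changes seen further down:

    import black

    # One of the over-long call sites from this file, written as a single line.
    src = (
        "top_scores, indices = torch.topk("
        "rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False)\n"
    )

    # black splits at the outermost parentheses and keeps the arguments on one
    # indented line when they fit within the default 88-character limit; when
    # they do not fit, it puts one argument per line and adds a trailing comma.
    print(black.format_str(src, mode=black.Mode()))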
@@ -1,21 +1,16 @@
-from dataclasses import dataclass
-import warnings
-from thinc.api import Model, Linear, Relu, Dropout
-from thinc.api import chain, noop, Embed, add, tuplify, concatenate
-from thinc.api import reduce_first, reduce_last, reduce_mean
-from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
-from typing import List, Callable, Tuple, Any
-from ...tokens import Doc
-from ...util import registry
-from ..extract_spans import extract_spans
+from typing import List, Tuple
 import torch
+from thinc.api import Model, chain, tuplify
+from thinc.api import PyTorchWrapper, ArgsKwargs
+from thinc.types import Floats2d, Ints1d, Ints2d
 from thinc.util import xp2torch, torch2xp
+from ...tokens import Doc
+from ...util import registry
 from .coref_util import add_dummy, get_sentence_ids
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
     tok2vec: Model[List[Doc], List[Floats2d]],
@@ -30,7 +25,7 @@ def build_wl_coref_model(
     a_scoring_batch_size: int = 512,
     # span predictor embeddings
     sp_embedding_size: int = 64,
 ):
     # TODO fix this
     try:
         dim = tok2vec.get_dim("nO")
@@ -48,10 +43,10 @@ def build_wl_coref_model(
                 n_hidden_layers,
                 dropout,
                 rough_k,
-                a_scoring_batch_size
+                a_scoring_batch_size,
             ),
             convert_inputs=convert_coref_scorer_inputs,
-            convert_outputs=convert_coref_scorer_outputs
+            convert_outputs=convert_coref_scorer_outputs,
         )
         coref_model = tok2vec >> coref_scorer
         # XXX just ignore this until the coref scorer is integrated
@@ -68,12 +63,13 @@ def build_wl_coref_model(
     # and just return words as spans.
     return coref_model
 @registry.architectures("spacy.SpanPredictor.v1")
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
     hidden_size: int = 1024,
     dist_emb_size: int = 64,
 ):
     # TODO fix this
     try:
         dim = tok2vec.get_dim("nO")
@@ -84,22 +80,16 @@ def build_span_predictor(
     with Model.define_operators({">>": chain, "&": tuplify}):
         span_predictor = PyTorchWrapper(
             SpanPredictor(dim, hidden_size, dist_emb_size),
-            convert_inputs=convert_span_predictor_inputs
+            convert_inputs=convert_span_predictor_inputs,
         )
         # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata(
-            "coref_head_clusters"
-        )
+        head_info = build_get_head_metadata("coref_head_clusters")
         model = (tok2vec & head_info) >> span_predictor
     return model
-def convert_coref_scorer_inputs(
-    model: Model,
-    X: List[Floats2d],
-    is_train: bool
-):
+def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -111,14 +101,10 @@ def convert_coref_scorer_inputs(
         gradients = torch2xp(args.args[0])
         return [gradients]
-    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
-def convert_coref_scorer_outputs(
-    model: Model,
-    inputs_outputs,
-    is_train: bool
-):
+def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
     _, outputs = inputs_outputs
     scores, indices = outputs
@@ -135,9 +121,7 @@ def convert_coref_scorer_outputs(
 def convert_span_predictor_inputs(
-    model: Model,
-    X: Tuple[Ints1d, Floats2d, Ints1d],
-    is_train: bool
+    model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we shoudl use the input is_train, but for these two it's not relevant
@@ -160,10 +144,9 @@ def convert_span_predictor_inputs(
 # TODO This probably belongs in the component, not the model.
-def predict_span_clusters(span_predictor: Model,
-                          sent_ids: Ints1d,
-                          words: Floats2d,
-                          clusters: List[Ints1d]):
+def predict_span_clusters(
+    span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
+):
     """
     Predicts span clusters based on the word clusters.
@@ -187,20 +170,15 @@ def predict_span_clusters(span_predictor: Model,
     ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
     head2span = {
-        head: (start, end)
-        for head, start, end in zip(heads_ids.tolist(), starts, ends)
+        head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
     }
-    return [[head2span[head] for head in cluster]
-            for cluster in clusters]
+    return [[head2span[head] for head in cluster] for cluster in clusters]
 # TODO add docstring for this, maybe move to utils.
 # This might belong in the component.
-def _clusterize(
-    model,
-    scores: Floats2d,
-    top_indices: Ints2d
-):
+def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
     xp = model.ops.xp
     antecedents = scores.argmax(axis=1) - 1
     not_dummy = antecedents >= 0
@@ -229,15 +207,14 @@ def _clusterize(
 def build_get_head_metadata(prefix):
     # TODO this name is awful, fix it
-    model = Model("HeadDataProvider",
-                  attrs={'prefix': prefix},
-                  forward=head_data_forward)
+    model = Model(
+        "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
+    )
     return model
 def head_data_forward(model, docs, is_train):
-    """A layer to generate the extra data needed for the span predictor.
-    """
+    """A layer to generate the extra data needed for the span predictor."""
     sent_ids = []
     head_ids = []
     prefix = model.attrs["prefix"]
@@ -271,6 +248,7 @@ class CorefScorer(torch.nn.Module):
         a_scorer (AnaphoricityScorer)
         sp (SpanPredictor)
     """
+
     def __init__(
         self,
         dim: int,  # tok2vec size
@@ -279,7 +257,7 @@ class CorefScorer(torch.nn.Module):
         n_layers: int,
         dropout_rate: float,
         roughk: int,
-        batch_size: int
+        batch_size: int,
     ):
         super().__init__()
         """
@@ -290,14 +268,11 @@ class CorefScorer(torch.nn.Module):
                 (useful for warm start)
         """
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
-        #TODO clean this up
+        # TODO clean this up
        bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
-            pair_emb,
-            hidden_size,
-            n_layers,
-            dropout_rate
+            pair_emb, hidden_size, n_layers, dropout_rate
         )
         self.lstm = torch.nn.LSTM(
             input_size=bert_emb,
@@ -305,17 +280,10 @@ class CorefScorer(torch.nn.Module):
             batch_first=True,
         )
         self.dropout = torch.nn.Dropout(dropout_rate)
-        self.rough_scorer = RoughScorer(
-            bert_emb,
-            dropout_rate,
-            roughk
-        )
+        self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
         self.batch_size = batch_size
-    def forward(
-        self,
-        word_features: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This is a massive method, but it made sense to me to not split it into
         several ones to let one see the data flow.
@@ -342,16 +310,18 @@ class CorefScorer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
-                all_mentions=words, mentions_batch=words_batch,
-                pw_batch=pw_batch, top_indices_batch=top_indices_batch,
-                top_rough_scores_batch=top_rough_scores_batch
+                all_mentions=words,
+                mentions_batch=words_batch,
+                pw_batch=pw_batch,
+                top_indices_batch=top_indices_batch,
+                top_rough_scores_batch=top_rough_scores_batch,
             )
             a_scores_lst.append(a_scores_batch)
@@ -360,33 +330,35 @@ class CorefScorer(torch.nn.Module):
 class AnaphoricityScorer(torch.nn.Module):
-    """ Calculates anaphoricity scores by passing the inputs into a FFNN """
-    def __init__(self,
-                 in_features: int,
-                 hidden_size,
-                 n_hidden_layers,
-                 dropout_rate):
+    """Calculates anaphoricity scores by passing the inputs into a FFNN"""
+
+    def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
         super().__init__()
         hidden_size = hidden_size
         if not n_hidden_layers:
             hidden_size = in_features
         layers = []
         for i in range(n_hidden_layers):
-            layers.extend([torch.nn.Linear(hidden_size if i else in_features,
-                                           hidden_size),
-                           torch.nn.LeakyReLU(),
-                           torch.nn.Dropout(dropout_rate)])
+            layers.extend(
+                [
+                    torch.nn.Linear(hidden_size if i else in_features, hidden_size),
+                    torch.nn.LeakyReLU(),
+                    torch.nn.Dropout(dropout_rate),
+                ]
+            )
         self.hidden = torch.nn.Sequential(*layers)
         self.out = torch.nn.Linear(hidden_size, out_features=1)
-    def forward(self, *,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
+    def forward(
+        self,
+        *,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
         all_mentions: torch.Tensor,
         mentions_batch: torch.Tensor,
         pw_batch: torch.Tensor,
         top_indices_batch: torch.Tensor,
         top_rough_scores_batch: torch.Tensor,
     ) -> torch.Tensor:
-        """ Builds a pairwise matrix, scores the pairs and returns the scores.
+        """Builds a pairwise matrix, scores the pairs and returns the scores.
         Args:
             all_mentions (torch.Tensor): [n_mentions, mention_emb]
@@ -401,7 +373,8 @@ class AnaphoricityScorer(torch.nn.Module):
         """
         # [batch_size, n_ants, pair_emb]
         pair_matrix = self._get_pair_matrix(
-            all_mentions, mentions_batch, pw_batch, top_indices_batch)
+            all_mentions, mentions_batch, pw_batch, top_indices_batch
+        )
         # [batch_size, n_ants]
         scores = top_rough_scores_batch + self._ffnn(pair_matrix)
@@ -423,7 +396,8 @@ class AnaphoricityScorer(torch.nn.Module):
         return x.squeeze(2)
     @staticmethod
-    def _get_pair_matrix(all_mentions: torch.Tensor,
+    def _get_pair_matrix(
+        all_mentions: torch.Tensor,
         mentions_batch: torch.Tensor,
         pw_batch: torch.Tensor,
         top_indices_batch: torch.Tensor,
@@ -464,12 +438,8 @@ class RoughScorer(torch.nn.Module):
     only top scoring candidates are considered on later steps to reduce
     computational complexity.
     """
-    def __init__(
-        self,
-        features: int,
-        dropout_rate: float,
-        rough_k: float
-    ):
+
+    def __init__(self, features: int, dropout_rate: float, rough_k: float):
         super().__init__()
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.bilinear = torch.nn.Linear(features, features)
@@ -478,7 +448,7 @@ class RoughScorer(torch.nn.Module):
     def forward(
         self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
-        mentions: torch.Tensor
+        mentions: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Returns rough anaphoricity scores for candidates, which consist of
@@ -493,9 +463,7 @@ class RoughScorer(torch.nn.Module):
         return self._prune(rough_scores)
-    def _prune(self,
-               rough_scores: torch.Tensor
-               ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Selects top-k rough antecedent scores for each mention.
@@ -507,9 +475,9 @@ class RoughScorer(torch.nn.Module):
             FloatTensor of shape [n_mentions, k], top rough scores
             LongTensor of shape [n_mentions, k], top indices
         """
-        top_scores, indices = torch.topk(rough_scores,
-                                         k=min(self.k, len(rough_scores)),
-                                         dim=1, sorted=False)
+        top_scores, indices = torch.topk(
+            rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
+        )
         return top_scores, indices
@@ -523,7 +491,7 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            #TODO seems weird the 256 isn't a parameter???
+            # TODO seems weird the 256 isn't a parameter???
             torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
@@ -531,15 +499,16 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Linear(256, dist_emb_size),
         )
         self.conv = torch.nn.Sequential(
-            torch.nn.Conv1d(64, 4, 3, 1, 1),
-            torch.nn.Conv1d(4, 2, 3, 1, 1)
+            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
         self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far
-    def forward(self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
+    def forward(
+        self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
         sent_id,
         words: torch.Tensor,
-        heads_ids: torch.Tensor) -> torch.Tensor:
+        heads_ids: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Calculates span start/end scores of words for each span head in
         heads_ids
@@ -557,27 +526,32 @@ class SpanPredictor(torch.nn.Module):
         if heads_ids.nelement() == 0:
             return torch.empty(size=(0,))
         # Obtain distance embedding indices, [n_heads, n_words]
-        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
+        relative_positions = heads_ids.unsqueeze(1) - torch.arange(
+            words.shape[0]
+        ).unsqueeze(0)
         # make all valid distances positive
         emb_ids = relative_positions + 63
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
         heads_ids = heads_ids.long()
-        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
+        same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
         # To save memory, only pass candidates from one sentence for each head
         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
         # for each candidate among the words in the same sentence as span_head
         # [n_heads, input_size * 2 + distance_emb_size]
         rows, cols = same_sent.nonzero(as_tuple=True)
-        pair_matrix = torch.cat((
+        pair_matrix = torch.cat(
+            (
                 words[heads_ids[rows]],
                 words[cols],
                 self.emb(emb_ids[rows, cols]),
-        ), dim=1)
+            ),
+            dim=1,
+        )
         lengths = same_sent.sum(dim=1)
         padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
-        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
+        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
         # This is necessary to allow the convolution layer to look at several
         # word scores
@@ -585,9 +559,11 @@ class SpanPredictor(torch.nn.Module):
         padded_pairs[padding_mask] = pair_matrix
         res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
-        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]
+        res = self.conv(res.permute(0, 2, 1)).permute(
+            0, 2, 1
+        )  # [n_heads, n_candidates, 2]
-        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
+        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
         scores[rows, cols] = res[padding_mask]
         # Make sure that start <= head <= end during inference
         if not self.training:
@@ -597,8 +573,8 @@ class SpanPredictor(torch.nn.Module):
             return scores + valid_positions
         return scores
 class DistancePairwiseEncoder(torch.nn.Module):
     def __init__(self, embedding_size, dropout_rate):
         super().__init__()
         emb_size = embedding_size
@@ -606,12 +582,12 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.shape = emb_size
-    def forward(self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
-        top_indices: torch.Tensor
+    def forward(
+        self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
+        top_indices: torch.Tensor,
     ) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
-        ).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)