diff --git a/spacy/ml/models/__init__.py b/spacy/ml/models/__init__.py
index 85497559c..608f36393 100644
--- a/spacy/ml/models/__init__.py
+++ b/spacy/ml/models/__init__.py
@@ -1,4 +1,4 @@
-from .coref import *
+from .coref import *  # noqa
 from .entity_linker import *  # noqa
 from .multi_task import *  # noqa
 from .parser import *  # noqa
diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 511e44476..2e291aa2b 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -448,3 +448,434 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
         return dX
 
     return pw_prod, backward
+
+
+# XXX here down is wl-coref
+from typing import List, Tuple
+
+import torch
+
+# TODO rename this to coref_util
+from . import coref_util_wl as utils
+
+
+# TODO rename to plain coref
+@registry.architectures("spacy.WLCoref.v1")
+def build_wl_coref_model(
+    # TODO add other hyperparams
+    tok2vec: Model[List[Doc], List[Floats2d]],
+):
+    # TODO change to use passed-in values instead of this config
+    config = utils._load_config("/dev/null")
+    with Model.define_operators({">>": chain}):
+        coref_scorer, span_predictor = configure_pytorch_modules(config)
+        # TODO chain tok2vec with these models
+        coref_scorer = PyTorchWrapper(
+            CorefScorer(
+                config.device,
+                config.embedding_size,
+                config.hidden_size,
+                config.n_hidden_layers,
+                config.dropout_rate,
+                config.rough_k,
+                config.a_scoring_batch_size,
+            ),
+            convert_inputs=convert_coref_scorer_inputs,
+            convert_outputs=convert_coref_scorer_outputs,
+        )
+        span_predictor = PyTorchWrapper(
+            SpanPredictor(
+                1024,
+                config.sp_embedding_size,
+                config.device,
+            ),
+            convert_inputs=convert_span_predictor_inputs,
+        )
+        # TODO combine models so output is uniform (just one forward pass)
+        # It may be reasonable to have an option to disable span prediction,
+        # and just return words as spans.
+        return coref_scorer
+
+
+def convert_coref_scorer_inputs(
+    model: Model,
+    X: Floats2d,
+    is_train: bool
+):
+    word_features = xp2torch(X, requires_grad=False)
+    # the backprop callback is a no-op: no gradient flows back to the input features
+    return ArgsKwargs(args=(word_features,), kwargs={}), lambda dX: []
+
+
+def convert_coref_scorer_outputs(
+    model: Model,
+    inputs_outputs,
+    is_train: bool
+):
+    _, outputs = inputs_outputs
+    scores, indices = outputs
+
+    def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
+        dY_t = xp2torch(dY)
+        return ArgsKwargs(
+            args=([scores],),
+            kwargs={"grad_tensors": [dY_t]},
+        )
+
+    scores_xp = torch2xp(scores)
+    indices_xp = torch2xp(indices)
+    return (scores_xp, indices_xp), convert_for_torch_backward
+
+
+# TODO This probably belongs in the component, not the model.
+def predict_span_clusters(span_predictor: Model,
+                          sent_ids: Ints1d,
+                          words: Floats2d,
+                          clusters: List[Ints1d]):
+    """
+    Predicts span clusters based on the word clusters.
+
+    Args:
+        sent_ids (Ints1d): sentence index for each word in the document
+        words (Floats2d): [n_words, emb_size] matrix containing
+            embeddings for each of the words in the text
+        clusters (List[Ints1d]): a list of clusters where each cluster
+            is a list of word indices
+
+    Returns:
+        List[List[Tuple[int, int]]]: span clusters as (start, end) offsets
+    """
+    if not clusters:
+        return []
+
+    xp = span_predictor.ops.xp
+    heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
+    scores = span_predictor.predict((sent_ids, words, heads_ids))
+    starts = scores[:, :, 0].argmax(axis=1).tolist()
+    ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
+
+    head2span = {
+        head: (start, end)
+        for head, start, end in zip(heads_ids.tolist(), starts, ends)
+    }
+
+    return [[head2span[head] for head in cluster]
+            for cluster in clusters]
+
+
+# TODO add docstring for this, maybe move to utils.
+# This might belong in the component.
+def _clusterize(
+    model,
+    scores: Floats2d,
+    top_indices: Ints2d
+):
+    xp = model.ops.xp
+    antecedents = scores.argmax(axis=1) - 1
+    not_dummy = antecedents >= 0
+    coref_span_heads = xp.arange(0, len(scores))[not_dummy]
+    antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
+    n_words = scores.shape[0]
+    nodes = [GraphNode(i) for i in range(n_words)]
+    for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
+        nodes[i].link(nodes[j])
+        assert nodes[i] is not nodes[j]
+
+    clusters = []
+    for node in nodes:
+        if len(node.links) > 0 and not node.visited:
+            cluster = []
+            stack = [node]
+            while stack:
+                current_node = stack.pop()
+                current_node.visited = True
+                cluster.append(current_node.id)
+                stack.extend(link for link in current_node.links if not link.visited)
+            assert len(cluster) > 1
+            clusters.append(sorted(cluster))
+    return sorted(clusters)
+
+
+class CorefScorer(torch.nn.Module):
+    """Combines all coref modules together to find coreferent spans.
+
+    Submodules (in the order of their usage in the pipeline):
+        lstm (torch.nn.LSTM)
+        rough_scorer (RoughScorer)
+        pw (DistancePairwiseEncoder)
+        a_scorer (AnaphoricityScorer)
+    """
+    def __init__(
+        self,
+        device: str,
+        dist_emb_size: int,
+        hidden_size: int,
+        n_layers: int,
+        dropout_rate: float,
+        roughk: int,
+        batch_size: int
+    ):
+        """
+        Args:
+            device (str): torch device to put the submodules on
+            dist_emb_size (int): size of the distance embeddings
+            hidden_size (int): size of the anaphoricity scorer's hidden layers
+            n_layers (int): number of hidden layers in the anaphoricity scorer
+            dropout_rate (float): dropout probability
+            roughk (int): number of candidate antecedents kept by the rough scorer
+            batch_size (int): batch size used by the anaphoricity scorer
+        """
+        super().__init__()
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
+        bert_emb = 1024
+        pair_emb = bert_emb * 3 + self.pw.shape
+        self.a_scorer = AnaphoricityScorer(
+            pair_emb,
+            hidden_size,
+            n_layers,
+            dropout_rate
+        ).to(device)
+        self.lstm = torch.nn.LSTM(
+            input_size=bert_emb,
+            hidden_size=bert_emb,
+            batch_first=True,
+        )
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.rough_scorer = RoughScorer(
+            bert_emb,
+            dropout_rate,
+            roughk
+        ).to(device)
+        self.batch_size = batch_size
+
+    def forward(
+        self,
+        word_features: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        This is a long method, but it is deliberately not split up so that the
+        whole data flow can be read in one place.
+
+        Args:
+            word_features: torch.Tensor containing word encodings
+
+        Returns:
+            coreference scores and top indices
+        """
+        # words    [n_words, span_emb]
+        word_features = torch.unsqueeze(word_features, dim=0)
+        words, _ = self.lstm(word_features)
+        words = words.squeeze()
+        words = self.dropout(words)
+        # Obtain bilinear scores and leave only top-k antecedents for each word
+        # top_rough_scores    [n_words, n_ants]
+        # top_indices         [n_words, n_ants]
+        top_rough_scores, top_indices = self.rough_scorer(words)
+        # Get pairwise features [n_words, n_ants, n_pw_features]
+        pw = self.pw(top_indices)
+        batch_size = self.batch_size
+        a_scores_lst: List[torch.Tensor] = []
+
+        for i in range(0, len(words), batch_size):
+            pw_batch = pw[i:i + batch_size]
+            words_batch = words[i:i + batch_size]
+            top_indices_batch = top_indices[i:i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+
+            # a_scores_batch    [batch_size, n_ants]
+            a_scores_batch = self.a_scorer(
+                all_mentions=words, mentions_batch=words_batch,
+                pw_batch=pw_batch, top_indices_batch=top_indices_batch,
+                top_rough_scores_batch=top_rough_scores_batch
+            )
+            a_scores_lst.append(a_scores_batch)
+
+        coref_scores = torch.cat(a_scores_lst, dim=0)
+        return coref_scores, top_indices
+
+
+class AnaphoricityScorer(torch.nn.Module):
+    """Calculates anaphoricity scores by passing the inputs into an FFNN."""
+    def __init__(self,
+                 in_features: int,
+                 hidden_size,
+                 n_hidden_layers,
+                 dropout_rate):
+        super().__init__()
+        if not n_hidden_layers:
+            hidden_size = in_features
+        layers = []
+        for i in range(n_hidden_layers):
+            layers.extend([torch.nn.Linear(hidden_size if i else in_features,
+                                           hidden_size),
+                           torch.nn.LeakyReLU(),
+                           torch.nn.Dropout(dropout_rate)])
+        self.hidden = torch.nn.Sequential(*layers)
+        self.out = torch.nn.Linear(hidden_size, out_features=1)
+
+    def forward(self, *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+                all_mentions: torch.Tensor,
+                mentions_batch: torch.Tensor,
+                pw_batch: torch.Tensor,
+                top_indices_batch: torch.Tensor,
+                top_rough_scores_batch: torch.Tensor,
+                ) -> torch.Tensor:
+        """Builds a pairwise matrix, scores the pairs and returns the scores.
+
+        Args:
+            all_mentions (torch.Tensor): [n_mentions, mention_emb]
+            mentions_batch (torch.Tensor): [batch_size, mention_emb]
+            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
+            top_indices_batch (torch.Tensor): [batch_size, n_ants]
+            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]
+
+        Returns:
+            torch.Tensor [batch_size, n_ants + 1]
+                anaphoricity scores for the pairs + a dummy column
+        """
+        # [batch_size, n_ants, pair_emb]
+        pair_matrix = self._get_pair_matrix(
+            all_mentions, mentions_batch, pw_batch, top_indices_batch)
+
+        # [batch_size, n_ants]
+        scores = top_rough_scores_batch + self._ffnn(pair_matrix)
+        scores = utils.add_dummy(scores, eps=True)
+
+        return scores
+
+    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Calculates anaphoricity scores.
+
+        Args:
+            x: tensor of shape [batch_size, n_ants, n_features]
+
+        Returns:
+            tensor of shape [batch_size, n_ants]
+        """
+        x = self.out(self.hidden(x))
+        return x.squeeze(2)
+
+    @staticmethod
+    def _get_pair_matrix(all_mentions: torch.Tensor,
+                         mentions_batch: torch.Tensor,
+                         pw_batch: torch.Tensor,
+                         top_indices_batch: torch.Tensor,
+                         ) -> torch.Tensor:
+        """
+        Builds the matrix used as input for AnaphoricityScorer.
+
+        Args:
+            all_mentions (torch.Tensor): [n_mentions, mention_emb],
+                all the valid mentions of the document,
+                can be on a different device
+            mentions_batch (torch.Tensor): [batch_size, mention_emb],
+                the mentions of the current batch,
+                is expected to be on the current device
+            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
+                pairwise features of the current batch,
+                is expected to be on the current device
+            top_indices_batch (torch.Tensor): [batch_size, n_ants],
+                indices of antecedents of each mention
+
+        Returns:
+            torch.Tensor: [batch_size, n_ants, pair_emb]
+        """
+        emb_size = mentions_batch.shape[1]
+        n_ants = pw_batch.shape[1]
+
+        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
+        b_mentions = all_mentions[top_indices_batch]
+        similarity = a_mentions * b_mentions
+
+        out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
+        return out
+
+
+class RoughScorer(torch.nn.Module):
+    """
+    Gives a rough estimate of the anaphoricity of two candidates; only the
+    top-scoring candidates are considered in later steps, to reduce
+    computational cost.
+    """
+    def __init__(
+        self,
+        features: int,
+        dropout_rate: float,
+        rough_k: int
+    ):
+        super().__init__()
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.bilinear = torch.nn.Linear(features, features)
+
+        self.k = rough_k
+
+    def forward(
+        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+        mentions: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Returns rough anaphoricity scores for candidates: bilinear scores for
+        each pair of words, masked so that a word can only be scored against
+        the words that precede it.
+        """
+        # [n_mentions, n_mentions]
+        pair_mask = torch.arange(mentions.shape[0])
+        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
+        pair_mask = torch.log((pair_mask > 0).to(torch.float))
+        pair_mask = pair_mask.to(mentions.device)
+        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
+        rough_scores = pair_mask + bilinear_scores
+
+        return self._prune(rough_scores)
+
+    def _prune(self,
+               rough_scores: torch.Tensor
+               ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Selects top-k rough antecedent scores for each mention.
+
+        Args:
+            rough_scores: tensor of shape [n_mentions, n_mentions], containing
+                rough antecedent scores of each mention-antecedent pair.
+
+        Returns:
+            FloatTensor of shape [n_mentions, k], top rough scores
+            LongTensor of shape [n_mentions, k], top indices
+        """
+        top_scores, indices = torch.topk(rough_scores,
+                                         k=min(self.k, len(rough_scores)),
+                                         dim=1, sorted=False)
+        return top_scores, indices
+
+
+class DistancePairwiseEncoder(torch.nn.Module):
+    """Embeds the distance between a word and each of its candidate
+    antecedents into one of nine buckets (exact for short distances,
+    log-spaced beyond that)."""
+
+    def __init__(self, embedding_size, dropout_rate):
+        super().__init__()
+        emb_size = embedding_size
+        self.distance_emb = torch.nn.Embedding(9, emb_size)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.shape = emb_size
+
+    @property
+    def device(self) -> torch.device:
+        """A workaround to get the current device (which is assumed to be the
+        device of the first parameter of one of the submodules)."""
+        return next(self.distance_emb.parameters()).device
+
+    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+                top_indices: torch.Tensor
+                ) -> torch.Tensor:
+        word_ids = torch.arange(0, top_indices.size(0), device=self.device)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
+                    ).clamp_min_(min=1)
+        log_distance = distance.to(torch.float).log2().floor_()
+        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
+        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
+        distance = self.distance_emb(distance)
+        return self.dropout(distance)
diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py
index 2306cabb7..ce074fe42 100644
--- a/spacy/tests/test_models.py
+++ b/spacy/tests/test_models.py
@@ -6,7 +6,7 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
 import numpy
 from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
 from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
-from spacy.ml.models import build_spancat_model
+from spacy.ml.models import build_spancat_model, build_wl_coref_model
 from spacy.ml.staticvectors import StaticVectors
 from spacy.ml.extract_spans import extract_spans, _get_span_indices
 from spacy.lang.en import English
@@ -269,3 +269,8 @@ def test_spancat_model_forward_backward(nO=5):
     Y, backprop = model((docs, spans), is_train=True)
     assert Y.shape == (spans.dataXd.shape[0], nO)
     backprop(Y)
+
+# TODO expand this
+def test_coref_model_init():
+    tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs())
+    model = build_wl_coref_model(tok2vec)