from typing import List, Tuple

import torch
from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d
from thinc.util import xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry
from .coref_util import get_sentence_ids


@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
    tok2vec: Model[List[Doc], List[Floats2d]],
    hidden_size: int = 1024,
    dist_emb_size: int = 64,
):
    # TODO: getting the width should not need a hardcoded fallback
    try:
        dim = tok2vec.get_dim("nO")
    except ValueError:
        # happens with transformer listener
        dim = 768

    with Model.define_operators({">>": chain, "&": tuplify}):
        span_predictor = PyTorchWrapper(
            SpanPredictor(dim, hidden_size, dist_emb_size),
            convert_inputs=convert_span_predictor_inputs,
        )
        # TODO use proper parameter for prefix
        head_info = build_get_head_metadata("coref_head_clusters")
        model = (tok2vec & head_info) >> span_predictor

    return model


def convert_span_predictor_inputs(
    model: Model,
    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
    is_train: bool,
):
    tok2vec, (sent_ids, head_ids) = X
    # Normally we should use the input is_train, but for these two it's not relevant

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # convert to xp and wrap in list
        gradients = torch2xp(args.args[1])
        return [[gradients], None]

    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
    if not head_ids[0].size:
        head_ids = torch.empty(size=(0,))
    else:
        head_ids = xp2torch(head_ids[0], requires_grad=False)

    argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
    # TODO actually support backprop
    return argskwargs, backprop


# TODO This probably belongs in the component, not the model.
def predict_span_clusters(
    span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
) -> List[List[Tuple[int, int]]]:
    """
    Predicts span clusters based on the word clusters.

    Args:
        span_predictor (Model): the wrapped span predictor model
        sent_ids (Ints1d): sentence id of each word in the document
        words (Floats2d): [n_words, emb_size] matrix containing
            embeddings for each of the words in the text
        clusters (List[Ints1d]): a list of clusters where each cluster
            is a list of word indices

    Returns:
        List[List[Tuple[int, int]]]: span clusters as (start, end) token
            offsets, one span per head word
    """
    if not clusters:
        return []

    xp = span_predictor.ops.xp
    heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
    scores = span_predictor.predict((sent_ids, words, heads_ids))
    starts = scores[:, :, 0].argmax(axis=1).tolist()
    ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()

    head2span = {
        head: (start, end)
        for head, start, end in zip(heads_ids.tolist(), starts, ends)
    }

    return [[head2span[head] for head in cluster] for cluster in clusters]


# TODO this should maybe have a different name from the component
class SpanPredictor(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
        super().__init__()
        # input_size = width of a single token embedding
        # dist_emb_size = width of the distance embedding
        # TODO check that dist_emb_size use is correct
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            # TODO the 256 should probably be a parameter
            torch.nn.Linear(hidden_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            # NB: this output width must match the Conv1d input channels below;
            # using dist_emb_size here only works because its default (64) matches
            torch.nn.Linear(256, dist_emb_size),
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        sent_id,
        words: torch.Tensor,
        heads_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids.

        Args:
            sent_id (torch.Tensor): sentence id of each word in the document
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
        """
        # If we don't receive heads, return empty
        if heads_ids.nelement() == 0:
            return torch.empty(size=(0,))
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = heads_ids.unsqueeze(1) - torch.arange(
            words.shape[0]
        ).unsqueeze(0)
        # make all valid distances positive
        emb_ids = relative_positions + 63
        # map out-of-range distances to the "too_far" bucket
        emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127
        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        heads_ids = heads_ids.long()
        same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_candidates_total, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat(
            (
                words[heads_ids[rows]],
                words[cols],
                self.emb(emb_ids[rows, cols]),
            ),
            dim=1,
        )
        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(
            0, 2, 1
        )  # [n_heads, n_candidates, 2]

        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
        scores[rows, cols] = res[padding_mask]
        # Make sure that start <= head <= end during inference
        if not self.training:
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores


def build_get_head_metadata(prefix):
    # TODO this name is awful, fix it
    model = Model(
        "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
    )
    return model


def head_data_forward(model, docs, is_train):
    """A layer to generate the extra data needed for the span predictor."""
    sent_ids = []
    head_ids = []
    prefix = model.attrs["prefix"]
    for doc in docs:
        sids = model.ops.asarray2i(get_sentence_ids(doc))
        sent_ids.append(sids)
        heads = []
        for key, sg in doc.spans.items():
            if not key.startswith(prefix):
                continue
            for span in sg:
                # TODO warn if spans are more than one token
                heads.append(span[0].i)
        heads = model.ops.asarray2i(heads)
        head_ids.append(heads)
    # each of these is a list with one entry per doc
    # backprop is just a placeholder
    # TODO it would probably be better to have a list of tuples than two lists of arrays
    return (sent_ids, head_ids), lambda x: []
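

if __name__ == "__main__":
    # Illustrative smoke test, not part of spaCy's API: call the torch
    # SpanPredictor directly on random tensors to show the expected shapes.
    # In real use the module is wrapped by build_span_predictor and receives
    # its inputs through convert_span_predictor_inputs; the sizes and head
    # indices below are arbitrary assumptions chosen for demonstration only.
    torch.manual_seed(0)
    n_words, emb_size = 20, 96
    words = torch.rand(n_words, emb_size)
    # two sentences of ten words each
    sent_id = torch.arange(n_words) // 10
    # one head word in each sentence
    heads_ids = torch.tensor([3, 12])
    predictor = SpanPredictor(input_size=emb_size, hidden_size=1024, dist_emb_size=64)
    scores = predictor(sent_id, words, heads_ids)
    # [n_heads, n_words, 2]: start/end scores for every word, per head
    print(scores.shape)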