mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 16:07:41 +03:00)

clean up unused imports + black formatting

This commit is contained in:
parent 683f470852
commit 6b51258a58
@@ -1,27 +1,22 @@
from dataclasses import dataclass
import warnings

from thinc.api import Model, Linear, Relu, Dropout
from thinc.api import chain, noop, Embed, add, tuplify, concatenate
from thinc.api import reduce_first, reduce_last, reduce_mean
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
from typing import List, Callable, Tuple, Any
from ...tokens import Doc
from ...util import registry
from ..extract_spans import extract_spans

from typing import List, Tuple
import torch

from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.util import xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry
from .coref_util import add_dummy, get_sentence_ids


@registry.architectures("spacy.Coref.v1")
def build_wl_coref_model(
    tok2vec: Model[List[Doc], List[Floats2d]],
    embedding_size: int = 20,
    hidden_size: int = 1024,
    n_hidden_layers: int = 1, # TODO rename to "depth"?
    n_hidden_layers: int = 1,  # TODO rename to "depth"?
    dropout: float = 0.3,
    # pairs to keep per mention after rough scoring
    # TODO change to meaningful name
@@ -30,7 +25,7 @@ def build_wl_coref_model(
    a_scoring_batch_size: int = 512,
    # span predictor embeddings
    sp_embedding_size: int = 64,
    ):
):
    # TODO fix this
    try:
        dim = tok2vec.get_dim("nO")
@@ -48,10 +43,10 @@ def build_wl_coref_model(
                n_hidden_layers,
                dropout,
                rough_k,
                a_scoring_batch_size
                a_scoring_batch_size,
            ),
            convert_inputs=convert_coref_scorer_inputs,
            convert_outputs=convert_coref_scorer_outputs
            convert_outputs=convert_coref_scorer_outputs,
        )
        coref_model = tok2vec >> coref_scorer
        # XXX just ignore this until the coref scorer is integrated
@@ -68,12 +63,13 @@ def build_wl_coref_model(
    # and just return words as spans.
    return coref_model


@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
    tok2vec: Model[List[Doc], List[Floats2d]],
    hidden_size: int = 1024,
    dist_emb_size: int = 64,
    ):
):
    # TODO fix this
    try:
        dim = tok2vec.get_dim("nO")
@@ -84,22 +80,16 @@ def build_span_predictor(
    with Model.define_operators({">>": chain, "&": tuplify}):
        span_predictor = PyTorchWrapper(
            SpanPredictor(dim, hidden_size, dist_emb_size),
            convert_inputs=convert_span_predictor_inputs
            convert_inputs=convert_span_predictor_inputs,
        )
        # TODO use proper parameter for prefix
        head_info = build_get_head_metadata(
            "coref_head_clusters"
        )
        head_info = build_get_head_metadata("coref_head_clusters")
        model = (tok2vec & head_info) >> span_predictor

    return model


def convert_coref_scorer_inputs(
    model: Model,
    X: List[Floats2d],
    is_train: bool
):
def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
    # The input here is List[Floats2d], one for each doc
    # just use the first
    # TODO real batching
@@ -111,14 +101,10 @@ def convert_coref_scorer_inputs(
        gradients = torch2xp(args.args[0])
        return [gradients]

    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


def convert_coref_scorer_outputs(
    model: Model,
    inputs_outputs,
    is_train: bool
):
def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
    _, outputs = inputs_outputs
    scores, indices = outputs

@@ -135,9 +121,7 @@ def convert_coref_scorer_outputs(


def convert_span_predictor_inputs(
    model: Model,
    X: Tuple[Ints1d, Floats2d, Ints1d],
    is_train: bool
    model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
):
    tok2vec, (sent_ids, head_ids) = X
    # Normally we should use the input is_train, but for these two it's not relevant
@@ -160,10 +144,9 @@ def convert_span_predictor_inputs(


# TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model,
                          sent_ids: Ints1d,
                          words: Floats2d,
                          clusters: List[Ints1d]):
def predict_span_clusters(
    span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
):
    """
    Predicts span clusters based on the word clusters.

@@ -187,20 +170,15 @@ def predict_span_clusters(span_predictor: Model,
    ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()

    head2span = {
        head: (start, end)
        for head, start, end in zip(heads_ids.tolist(), starts, ends)
        head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
    }

    return [[head2span[head] for head in cluster]
            for cluster in clusters]
    return [[head2span[head] for head in cluster] for cluster in clusters]


# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
def _clusterize(
        model,
        scores: Floats2d,
        top_indices: Ints2d
):
def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
    xp = model.ops.xp
    antecedents = scores.argmax(axis=1) - 1
    not_dummy = antecedents >= 0
@@ -229,15 +207,14 @@ def _clusterize(

def build_get_head_metadata(prefix):
    # TODO this name is awful, fix it
    model = Model("HeadDataProvider",
                  attrs={'prefix': prefix},
                  forward=head_data_forward)
    model = Model(
        "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
    )
    return model


def head_data_forward(model, docs, is_train):
    """A layer to generate the extra data needed for the span predictor.
    """
    """A layer to generate the extra data needed for the span predictor."""
    sent_ids = []
    head_ids = []
    prefix = model.attrs["prefix"]
@@ -271,15 +248,16 @@ class CorefScorer(torch.nn.Module):
        a_scorer (AnaphoricityScorer)
        sp (SpanPredictor)
    """

    def __init__(
        self,
        dim: int, # tok2vec size
        dim: int,  # tok2vec size
        dist_emb_size: int,
        hidden_size: int,
        n_layers: int,
        dropout_rate: float,
        roughk: int,
        batch_size: int
        batch_size: int,
    ):
        super().__init__()
        """
@@ -290,14 +268,11 @@ class CorefScorer(torch.nn.Module):
                (useful for warm start)
        """
        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
        #TODO clean this up
        # TODO clean this up
        bert_emb = dim
        pair_emb = bert_emb * 3 + self.pw.shape
        self.a_scorer = AnaphoricityScorer(
            pair_emb,
            hidden_size,
            n_layers,
            dropout_rate
            pair_emb, hidden_size, n_layers, dropout_rate
        )
        self.lstm = torch.nn.LSTM(
            input_size=bert_emb,
@@ -305,17 +280,10 @@ class CorefScorer(torch.nn.Module):
            batch_first=True,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.rough_scorer = RoughScorer(
            bert_emb,
            dropout_rate,
            roughk
        )
        self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
        self.batch_size = batch_size

    def forward(
        self,
        word_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This is a massive method, but it made sense to me to not split it into
        several ones to let one see the data flow.
@@ -327,7 +295,7 @@ class CorefScorer(torch.nn.Module):
        """
        # words           [n_words, span_emb]
        # cluster_ids     [n_words]
        self.lstm.flatten_parameters() # XXX without this there's a warning
        self.lstm.flatten_parameters()  # XXX without this there's a warning
        word_features = torch.unsqueeze(word_features, dim=0)
        words, _ = self.lstm(word_features)
        words = words.squeeze()
@@ -342,16 +310,18 @@ class CorefScorer(torch.nn.Module):
        a_scores_lst: List[torch.Tensor] = []

        for i in range(0, len(words), batch_size):
            pw_batch = pw[i:i + batch_size]
            words_batch = words[i:i + batch_size]
            top_indices_batch = top_indices[i:i + batch_size]
            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
            pw_batch = pw[i : i + batch_size]
            words_batch = words[i : i + batch_size]
            top_indices_batch = top_indices[i : i + batch_size]
            top_rough_scores_batch = top_rough_scores[i : i + batch_size]

            # a_scores_batch    [batch_size, n_ants]
            a_scores_batch = self.a_scorer(
                all_mentions=words, mentions_batch=words_batch,
                pw_batch=pw_batch, top_indices_batch=top_indices_batch,
                top_rough_scores_batch=top_rough_scores_batch
                all_mentions=words,
                mentions_batch=words_batch,
                pw_batch=pw_batch,
                top_indices_batch=top_indices_batch,
                top_rough_scores_batch=top_rough_scores_batch,
            )
            a_scores_lst.append(a_scores_batch)

@@ -360,33 +330,35 @@ class CorefScorer(torch.nn.Module):


class AnaphoricityScorer(torch.nn.Module):
    """ Calculates anaphoricity scores by passing the inputs into a FFNN """
    def __init__(self,
                 in_features: int,
                 hidden_size,
                 n_hidden_layers,
                 dropout_rate):
    """Calculates anaphoricity scores by passing the inputs into a FFNN"""

    def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
        super().__init__()
        hidden_size = hidden_size
        if not n_hidden_layers:
            hidden_size = in_features
        layers = []
        for i in range(n_hidden_layers):
            layers.extend([torch.nn.Linear(hidden_size if i else in_features,
                                           hidden_size),
                           torch.nn.LeakyReLU(),
                           torch.nn.Dropout(dropout_rate)])
            layers.extend(
                [
                    torch.nn.Linear(hidden_size if i else in_features, hidden_size),
                    torch.nn.LeakyReLU(),
                    torch.nn.Dropout(dropout_rate),
                ]
            )
        self.hidden = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(hidden_size, out_features=1)

    def forward(self, *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                all_mentions: torch.Tensor,
                mentions_batch: torch.Tensor,
                pw_batch: torch.Tensor,
                top_indices_batch: torch.Tensor,
                top_rough_scores_batch: torch.Tensor,
                ) -> torch.Tensor:
        """ Builds a pairwise matrix, scores the pairs and returns the scores.
    def forward(
        self,
        *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        all_mentions: torch.Tensor,
        mentions_batch: torch.Tensor,
        pw_batch: torch.Tensor,
        top_indices_batch: torch.Tensor,
        top_rough_scores_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Builds a pairwise matrix, scores the pairs and returns the scores.

        Args:
            all_mentions (torch.Tensor): [n_mentions, mention_emb]
@@ -401,7 +373,8 @@ class AnaphoricityScorer(torch.nn.Module):
        """
        # [batch_size, n_ants, pair_emb]
        pair_matrix = self._get_pair_matrix(
            all_mentions, mentions_batch, pw_batch, top_indices_batch)
            all_mentions, mentions_batch, pw_batch, top_indices_batch
        )

        # [batch_size, n_ants]
        scores = top_rough_scores_batch + self._ffnn(pair_matrix)
@@ -423,11 +396,12 @@ class AnaphoricityScorer(torch.nn.Module):
        return x.squeeze(2)

    @staticmethod
    def _get_pair_matrix(all_mentions: torch.Tensor,
                         mentions_batch: torch.Tensor,
                         pw_batch: torch.Tensor,
                         top_indices_batch: torch.Tensor,
                         ) -> torch.Tensor:
    def _get_pair_matrix(
        all_mentions: torch.Tensor,
        mentions_batch: torch.Tensor,
        pw_batch: torch.Tensor,
        top_indices_batch: torch.Tensor,
    ) -> torch.Tensor:
        """
        Builds the matrix used as input for AnaphoricityScorer.

@@ -464,12 +438,8 @@ class RoughScorer(torch.nn.Module):
    only top scoring candidates are considered on later steps to reduce
    computational complexity.
    """
    def __init__(
            self,
            features: int,
            dropout_rate: float,
            rough_k: float
    ):

    def __init__(self, features: int, dropout_rate: float, rough_k: float):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.bilinear = torch.nn.Linear(features, features)
@@ -478,7 +448,7 @@ class RoughScorer(torch.nn.Module):

    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        mentions: torch.Tensor
        mentions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns rough anaphoricity scores for candidates, which consist of
@@ -493,9 +463,7 @@ class RoughScorer(torch.nn.Module):

        return self._prune(rough_scores)

    def _prune(self,
               rough_scores: torch.Tensor
               ) -> Tuple[torch.Tensor, torch.Tensor]:
    def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Selects top-k rough antecedent scores for each mention.

@@ -507,9 +475,9 @@ class RoughScorer(torch.nn.Module):
            FloatTensor of shape [n_mentions, k], top rough scores
            LongTensor of shape [n_mentions, k], top indices
        """
        top_scores, indices = torch.topk(rough_scores,
                                         k=min(self.k, len(rough_scores)),
                                         dim=1, sorted=False)
        top_scores, indices = torch.topk(
            rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
        )
        return top_scores, indices


@@ -523,7 +491,7 @@ class SpanPredictor(torch.nn.Module):
            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            #TODO seems weird the 256 isn't a parameter???
            # TODO seems weird the 256 isn't a parameter???
            torch.nn.Linear(hidden_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
@@ -531,15 +499,16 @@ class SpanPredictor(torch.nn.Module):
            torch.nn.Linear(256, dist_emb_size),
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                sent_id,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        sent_id,
        words: torch.Tensor,
        heads_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids
@@ -557,37 +526,44 @@ class SpanPredictor(torch.nn.Module):
        if heads_ids.nelement() == 0:
            return torch.empty(size=(0,))
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
        relative_positions = heads_ids.unsqueeze(1) - torch.arange(
            words.shape[0]
        ).unsqueeze(0)
        # make all valid distances positive
        emb_ids = relative_positions + 63
        # "too_far"
        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        heads_ids = heads_ids.long()
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
        same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_heads, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)
        pair_matrix = torch.cat(
            (
                words[heads_ids[rows]],
                words[cols],
                self.emb(emb_ids[rows, cols]),
            ),
            dim=1,
        )
        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(
            0, 2, 1
        )  # [n_heads, n_candidates, 2]

        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
        scores[rows, cols] = res[padding_mask]
        # Make sure that start <= head <= end during inference
        if not self.training:
@@ -597,8 +573,8 @@ class SpanPredictor(torch.nn.Module):
            return scores + valid_positions
        return scores

class DistancePairwiseEncoder(torch.nn.Module):


class DistancePairwiseEncoder(torch.nn.Module):
    def __init__(self, embedding_size, dropout_rate):
        super().__init__()
        emb_size = embedding_size
@@ -606,12 +582,12 @@ class DistancePairwiseEncoder(torch.nn.Module):
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.shape = emb_size

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                top_indices: torch.Tensor
        ) -> torch.Tensor:
    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        top_indices: torch.Tensor,
    ) -> torch.Tensor:
        word_ids = torch.arange(0, top_indices.size(0))
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                    ).clamp_min_(min=1)
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
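Editor's note, not part of the commit: the diff above leans on thinc's PyTorchWrapper conversion hooks (convert_inputs/convert_outputs returning ArgsKwargs plus a backprop callback). Below is a minimal standalone sketch of that pattern. The toy torch.nn.Linear module, the array sizes, and the convert_inputs helper name are made up for illustration; it assumes thinc v8+ and torch are installed and is not the spaCy coref model itself.

import numpy
import torch
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.util import xp2torch, torch2xp


def convert_inputs(model, X, is_train):
    # Move the input array onto torch, mirroring what convert_coref_scorer_inputs
    # does for the first doc's word features.
    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # Map the torch gradient for the input tensor back to a numpy/cupy array.
        return torch2xp(args.args[0])

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


# A toy torch module standing in for CorefScorer / SpanPredictor.
wrapped = PyTorchWrapper(torch.nn.Linear(4, 2), convert_inputs=convert_inputs)
Y = wrapped.predict(numpy.zeros((3, 4), dtype="f"))
print(Y.shape)  # (3, 2)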