Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-18 12:12:20 +03:00
clean up unused imports + black formatting
This commit is contained in: parent 683f470852, commit 6b51258a58
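This commit applies the black code formatter and drops imports that are no longer used. As a hedged illustration only (the file path and the invocation below are assumptions, not stated on this page), the formatting half of such a cleanup can be reproduced with black's Python API:

    # Sketch: reformat one source file with black's API (black.format_str / black.Mode).
    # The target path is hypothetical; black does not remove unused imports,
    # so that part of the cleanup would still be done by hand.
    import black

    path = "spacy/ml/models/coref.py"  # assumed location of the diffed file
    with open(path) as f:
        src = f.read()
    formatted = black.format_str(src, mode=black.Mode())
    with open(path, "w") as f:
        f.write(formatted)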
@@ -1,27 +1,22 @@
from dataclasses import dataclass
import warnings

from thinc.api import Model, Linear, Relu, Dropout
from thinc.api import chain, noop, Embed, add, tuplify, concatenate
from thinc.api import reduce_first, reduce_last, reduce_mean
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
from typing import List, Callable, Tuple, Any
from ...tokens import Doc
from ...util import registry
from ..extract_spans import extract_spans

from typing import List, Tuple
import torch

from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.util import xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry
from .coref_util import add_dummy, get_sentence_ids


@registry.architectures("spacy.Coref.v1")
def build_wl_coref_model(
tok2vec: Model[List[Doc], List[Floats2d]],
embedding_size: int = 20,
hidden_size: int = 1024,
n_hidden_layers: int = 1, # TODO rename to "depth"?
n_hidden_layers: int = 1, # TODO rename to "depth"?
dropout: float = 0.3,
# pairs to keep per mention after rough scoring
# TODO change to meaningful name
@@ -30,7 +25,7 @@ def build_wl_coref_model(
a_scoring_batch_size: int = 512,
# span predictor embeddings
sp_embedding_size: int = 64,
):
):
# TODO fix this
try:
dim = tok2vec.get_dim("nO")
@@ -48,10 +43,10 @@ def build_wl_coref_model(
n_hidden_layers,
dropout,
rough_k,
a_scoring_batch_size
a_scoring_batch_size,
),
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs
convert_outputs=convert_coref_scorer_outputs,
)
coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
@@ -68,12 +63,13 @@ def build_wl_coref_model(
# and just return words as spans.
return coref_model


@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
tok2vec: Model[List[Doc], List[Floats2d]],
hidden_size: int = 1024,
dist_emb_size: int = 64,
):
):
# TODO fix this
try:
dim = tok2vec.get_dim("nO")
@@ -84,22 +80,16 @@ def build_span_predictor(
with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = PyTorchWrapper(
SpanPredictor(dim, hidden_size, dist_emb_size),
convert_inputs=convert_span_predictor_inputs
convert_inputs=convert_span_predictor_inputs,
)
# TODO use proper parameter for prefix
head_info = build_get_head_metadata(
"coref_head_clusters"
)
head_info = build_get_head_metadata("coref_head_clusters")
model = (tok2vec & head_info) >> span_predictor

return model


def convert_coref_scorer_inputs(
model: Model,
X: List[Floats2d],
is_train: bool
):
def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
# The input here is List[Floats2d], one for each doc
# just use the first
# TODO real batching
@@ -111,14 +101,10 @@ def convert_coref_scorer_inputs(
gradients = torch2xp(args.args[0])
return [gradients]

return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
return ArgsKwargs(args=(word_features,), kwargs={}), backprop


def convert_coref_scorer_outputs(
model: Model,
inputs_outputs,
is_train: bool
):
def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
_, outputs = inputs_outputs
scores, indices = outputs

@@ -135,9 +121,7 @@ def convert_coref_scorer_outputs(


def convert_span_predictor_inputs(
model: Model,
X: Tuple[Ints1d, Floats2d, Ints1d],
is_train: bool
model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
):
tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but for these two it's not relevant
@@ -160,10 +144,9 @@ def convert_span_predictor_inputs(


# TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model,
sent_ids: Ints1d,
words: Floats2d,
clusters: List[Ints1d]):
def predict_span_clusters(
span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
):
"""
Predicts span clusters based on the word clusters.

@@ -187,20 +170,15 @@ def predict_span_clusters(span_predictor: Model,
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()

head2span = {
head: (start, end)
for head, start, end in zip(heads_ids.tolist(), starts, ends)
head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
}

return [[head2span[head] for head in cluster]
for cluster in clusters]
return [[head2span[head] for head in cluster] for cluster in clusters]


# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
def _clusterize(
model,
scores: Floats2d,
top_indices: Ints2d
):
def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
xp = model.ops.xp
antecedents = scores.argmax(axis=1) - 1
not_dummy = antecedents >= 0
@@ -229,15 +207,14 @@ def _clusterize(

def build_get_head_metadata(prefix):
# TODO this name is awful, fix it
model = Model("HeadDataProvider",
attrs={'prefix': prefix},
forward=head_data_forward)
model = Model(
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
)
return model


def head_data_forward(model, docs, is_train):
"""A layer to generate the extra data needed for the span predictor.
"""
"""A layer to generate the extra data needed for the span predictor."""
sent_ids = []
head_ids = []
prefix = model.attrs["prefix"]
@@ -271,15 +248,16 @@ class CorefScorer(torch.nn.Module):
a_scorer (AnaphoricityScorer)
sp (SpanPredictor)
"""

def __init__(
self,
dim: int, # tok2vec size
dim: int, # tok2vec size
dist_emb_size: int,
hidden_size: int,
n_layers: int,
dropout_rate: float,
roughk: int,
batch_size: int
batch_size: int,
):
super().__init__()
"""
@@ -290,14 +268,11 @@ class CorefScorer(torch.nn.Module):
(useful for warm start)
"""
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
#TODO clean this up
# TODO clean this up
bert_emb = dim
pair_emb = bert_emb * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb,
hidden_size,
n_layers,
dropout_rate
pair_emb, hidden_size, n_layers, dropout_rate
)
self.lstm = torch.nn.LSTM(
input_size=bert_emb,
@@ -305,17 +280,10 @@ class CorefScorer(torch.nn.Module):
batch_first=True,
)
self.dropout = torch.nn.Dropout(dropout_rate)
self.rough_scorer = RoughScorer(
bert_emb,
dropout_rate,
roughk
)
self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
self.batch_size = batch_size

def forward(
self,
word_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a massive method, but it made sense to me to not split it into
several ones to let one see the data flow.
@@ -327,7 +295,7 @@ class CorefScorer(torch.nn.Module):
"""
# words [n_words, span_emb]
# cluster_ids [n_words]
self.lstm.flatten_parameters() # XXX without this there's a warning
self.lstm.flatten_parameters() # XXX without this there's a warning
word_features = torch.unsqueeze(word_features, dim=0)
words, _ = self.lstm(word_features)
words = words.squeeze()
@@ -342,16 +310,18 @@ class CorefScorer(torch.nn.Module):
a_scores_lst: List[torch.Tensor] = []

for i in range(0, len(words), batch_size):
pw_batch = pw[i:i + batch_size]
words_batch = words[i:i + batch_size]
top_indices_batch = top_indices[i:i + batch_size]
top_rough_scores_batch = top_rough_scores[i:i + batch_size]
pw_batch = pw[i : i + batch_size]
words_batch = words[i : i + batch_size]
top_indices_batch = top_indices[i : i + batch_size]
top_rough_scores_batch = top_rough_scores[i : i + batch_size]

# a_scores_batch [batch_size, n_ants]
a_scores_batch = self.a_scorer(
all_mentions=words, mentions_batch=words_batch,
pw_batch=pw_batch, top_indices_batch=top_indices_batch,
top_rough_scores_batch=top_rough_scores_batch
all_mentions=words,
mentions_batch=words_batch,
pw_batch=pw_batch,
top_indices_batch=top_indices_batch,
top_rough_scores_batch=top_rough_scores_batch,
)
a_scores_lst.append(a_scores_batch)

@@ -360,33 +330,35 @@ class CorefScorer(torch.nn.Module):


class AnaphoricityScorer(torch.nn.Module):
""" Calculates anaphoricity scores by passing the inputs into a FFNN """
def __init__(self,
in_features: int,
hidden_size,
n_hidden_layers,
dropout_rate):
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""

def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
super().__init__()
hidden_size = hidden_size
if not n_hidden_layers:
hidden_size = in_features
layers = []
for i in range(n_hidden_layers):
layers.extend([torch.nn.Linear(hidden_size if i else in_features,
hidden_size),
torch.nn.LeakyReLU(),
torch.nn.Dropout(dropout_rate)])
layers.extend(
[
torch.nn.Linear(hidden_size if i else in_features, hidden_size),
torch.nn.LeakyReLU(),
torch.nn.Dropout(dropout_rate),
]
)
self.hidden = torch.nn.Sequential(*layers)
self.out = torch.nn.Linear(hidden_size, out_features=1)

def forward(self, *, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
top_rough_scores_batch: torch.Tensor,
) -> torch.Tensor:
""" Builds a pairwise matrix, scores the pairs and returns the scores.
def forward(
self,
*, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
top_rough_scores_batch: torch.Tensor,
) -> torch.Tensor:
"""Builds a pairwise matrix, scores the pairs and returns the scores.

Args:
all_mentions (torch.Tensor): [n_mentions, mention_emb]
@@ -401,7 +373,8 @@ class AnaphoricityScorer(torch.nn.Module):
"""
# [batch_size, n_ants, pair_emb]
pair_matrix = self._get_pair_matrix(
all_mentions, mentions_batch, pw_batch, top_indices_batch)
all_mentions, mentions_batch, pw_batch, top_indices_batch
)

# [batch_size, n_ants]
scores = top_rough_scores_batch + self._ffnn(pair_matrix)
@@ -423,11 +396,12 @@ class AnaphoricityScorer(torch.nn.Module):
return x.squeeze(2)

@staticmethod
def _get_pair_matrix(all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
) -> torch.Tensor:
def _get_pair_matrix(
all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
) -> torch.Tensor:
"""
Builds the matrix used as input for AnaphoricityScorer.

@@ -464,12 +438,8 @@ class RoughScorer(torch.nn.Module):
only top scoring candidates are considered on later steps to reduce
computational complexity.
"""
def __init__(
self,
features: int,
dropout_rate: float,
rough_k: float
):

def __init__(self, features: int, dropout_rate: float, rough_k: float):
super().__init__()
self.dropout = torch.nn.Dropout(dropout_rate)
self.bilinear = torch.nn.Linear(features, features)
@@ -478,7 +448,7 @@ class RoughScorer(torch.nn.Module):

def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
mentions: torch.Tensor
mentions: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns rough anaphoricity scores for candidates, which consist of
@@ -493,9 +463,7 @@ class RoughScorer(torch.nn.Module):

return self._prune(rough_scores)

def _prune(self,
rough_scores: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Selects top-k rough antecedent scores for each mention.

@@ -507,9 +475,9 @@ class RoughScorer(torch.nn.Module):
FloatTensor of shape [n_mentions, k], top rough scores
LongTensor of shape [n_mentions, k], top indices
"""
top_scores, indices = torch.topk(rough_scores,
k=min(self.k, len(rough_scores)),
dim=1, sorted=False)
top_scores, indices = torch.topk(
rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
)
return top_scores, indices


@@ -523,7 +491,7 @@ class SpanPredictor(torch.nn.Module):
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
#TODO seems weird the 256 isn't a parameter???
# TODO seems weird the 256 isn't a parameter???
torch.nn.Linear(hidden_size, 256),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
@@ -531,15 +499,16 @@ class SpanPredictor(torch.nn.Module):
torch.nn.Linear(256, dist_emb_size),
)
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(64, 4, 3, 1, 1),
torch.nn.Conv1d(4, 2, 3, 1, 1)
torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
)
self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far

def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
sent_id,
words: torch.Tensor,
heads_ids: torch.Tensor) -> torch.Tensor:
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
sent_id,
words: torch.Tensor,
heads_ids: torch.Tensor,
) -> torch.Tensor:
"""
Calculates span start/end scores of words for each span head in
heads_ids
@@ -557,37 +526,44 @@ class SpanPredictor(torch.nn.Module):
if heads_ids.nelement() == 0:
return torch.empty(size=(0,))
# Obtain distance embedding indices, [n_heads, n_words]
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
words.shape[0]
).unsqueeze(0)
# make all valid distances positive
emb_ids = relative_positions + 63
# "too_far"
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
# Obtain "same sentence" boolean mask, [n_heads, n_words]
heads_ids = heads_ids.long()
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
# To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head
# [n_heads, input_size * 2 + distance_emb_size]
rows, cols = same_sent.nonzero(as_tuple=True)
pair_matrix = torch.cat((
words[heads_ids[rows]],
words[cols],
self.emb(emb_ids[rows, cols]),
), dim=1)
pair_matrix = torch.cat(
(
words[heads_ids[rows]],
words[cols],
self.emb(emb_ids[rows, cols]),
),
dim=1,
)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
padding_mask = padding_mask < lengths.unsqueeze(1) # [n_heads, max_sent_len]
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several
# word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
padded_pairs[padding_mask] = pair_matrix

res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(
0, 2, 1
) # [n_heads, n_candidates, 2]

scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference
if not self.training:
@@ -597,8 +573,8 @@ class SpanPredictor(torch.nn.Module):
return scores + valid_positions
return scores

class DistancePairwiseEncoder(torch.nn.Module):

class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):
super().__init__()
emb_size = embedding_size
@@ -606,12 +582,12 @@ class DistancePairwiseEncoder(torch.nn.Module):
self.dropout = torch.nn.Dropout(dropout_rate)
self.shape = emb_size

def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
top_indices: torch.Tensor
) -> torch.Tensor:
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
top_indices: torch.Tensor,
) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0))
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1)
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
distance = torch.where(distance < 5, distance - 1, log_distance + 2)
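For reference, the file touched by this diff registers two trainable architectures, "spacy.Coref.v1" and "spacy.SpanPredictor.v1". The sketch below shows how a registered architecture can be looked up and built; it assumes an installation that includes this experimental coref code, and the HashEmbedCNN settings are illustrative values only, not taken from the commit:

    # Hedged sketch: resolve the architecture registered in the diff and build it
    # with a stand-in tok2vec layer. All parameter values here are assumptions.
    from spacy import registry

    build_tok2vec = registry.architectures.get("spacy.HashEmbedCNN.v2")
    tok2vec = build_tok2vec(
        width=64,
        depth=2,
        embed_size=2000,
        window_size=1,
        maxout_pieces=3,
        subword_features=True,
        pretrained_vectors=None,
    )

    build_coref = registry.architectures.get("spacy.Coref.v1")
    coref_model = build_coref(tok2vec)  # Thinc Model: List[Doc] -> coreference scores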