Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 20:22:25 +03:00)
Commit 57165f9631
@@ -1,14 +1,14 @@
 from typing import List, Tuple
 import torch

-from thinc.api import Model, chain, tuplify
+from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import Floats2d, Ints2d, Ints1d
 from thinc.util import xp2torch, torch2xp

 from ...tokens import Doc
 from ...util import registry
-from .coref_util import add_dummy, get_sentence_ids
+from .coref_util import add_dummy


 @registry.architectures("spacy.Coref.v1")
@@ -19,7 +19,6 @@ def build_wl_coref_model(
     n_hidden_layers: int = 1,  # TODO rename to "depth"?
     dropout: float = 0.3,
     # pairs to keep per mention after rough scoring
-    # TODO change to meaningful name
     rough_k: int = 50,
     # TODO is this not a training loop setting?
     a_scoring_batch_size: int = 512,
@@ -34,7 +33,6 @@ def build_wl_coref_model(
     dim = 768

     with Model.define_operators({">>": chain}):
-        # TODO chain tok2vec with these models
         coref_scorer = PyTorchWrapper(
             CorefScorer(
                 dim,
@@ -49,22 +47,14 @@ def build_wl_coref_model(
             convert_outputs=convert_coref_scorer_outputs,
         )
         coref_model = tok2vec >> coref_scorer
-        # XXX just ignore this until the coref scorer is integrated
-        # span_predictor = PyTorchWrapper(
-        #     SpanPredictor(
-        # TODO this was hardcoded to 1024, check
-        #         hidden_size,
-        #         sp_embedding_size,
-        #     ),
-        #     convert_inputs=convert_span_predictor_inputs
-        # )
-        # TODO combine models so output is uniform (just one forward pass)
-        # It may be reasonable to have an option to disable span prediction,
-        # and just return words as spans.
     return coref_model


-def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_scorer_inputs(
+    model: Model,
+    X: List[Floats2d],
+    is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -76,7 +66,7 @@ def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool)
         gradients = torch2xp(args.args[0])
         return [gradients]

-    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop


 def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
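The conversion functions changed here follow thinc's PyTorchWrapper contract: convert_inputs turns the incoming arrays into an ArgsKwargs of torch tensors plus a gradient callback, and convert_outputs maps the torch outputs back to xp arrays plus a callback that wraps incoming gradients for torch.autograd.backward. The following is a minimal, self-contained sketch of that contract with a stand-in torch.nn.Linear module and made-up sizes; it illustrates the pattern rather than reproducing the spaCy code.

import torch
from thinc.api import ArgsKwargs, PyTorchWrapper
from thinc.util import torch2xp, xp2torch


def convert_inputs(model, X, is_train):
    # Assume X is a single 2d array of word features.
    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # Gradient w.r.t. the torch input, converted back to an xp array.
        return torch2xp(args.args[0])

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


def convert_outputs(model, inputs_outputs, is_train):
    _, Y_torch = inputs_outputs
    Y = torch2xp(Y_torch)

    def convert_for_torch_backward(dY):
        # Wrap the incoming gradient so it can be fed to torch.autograd.backward.
        return ArgsKwargs(args=([Y_torch],), kwargs={"grad_tensors": [xp2torch(dY)]})

    return Y, convert_for_torch_backward


wrapped = PyTorchWrapper(
    torch.nn.Linear(4, 2),
    convert_inputs=convert_inputs,
    convert_outputs=convert_outputs,
)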
@@ -95,46 +85,13 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
     return (scores_xp, indices_xp), convert_for_torch_backward


-# TODO add docstring for this, maybe move to utils.
-# This might belong in the component.
-def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
-    xp = model.ops.xp
-    antecedents = scores.argmax(axis=1) - 1
-    not_dummy = antecedents >= 0
-    coref_span_heads = xp.arange(0, len(scores))[not_dummy]
-    antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
-    n_words = scores.shape[0]
-    nodes = [GraphNode(i) for i in range(n_words)]
-    for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
-        nodes[i].link(nodes[j])
-        assert nodes[i] is not nodes[j]
-
-    clusters = []
-    for node in nodes:
-        if len(node.links) > 0 and not node.visited:
-            cluster = []
-            stack = [node]
-            while stack:
-                current_node = stack.pop()
-                current_node.visited = True
-                cluster.append(current_node.id)
-                stack.extend(link for link in current_node.links if not link.visited)
-            assert len(cluster) > 1
-            clusters.append(sorted(cluster))
-    return sorted(clusters)
-
-
 class CorefScorer(torch.nn.Module):
-    """Combines all coref modules together to find coreferent spans.
-
-    Attributes:
-        epochs_trained (int): number of epochs the model has been trained for
-
+    """
+    Combines all coref modules together to find coreferent token pairs.
     Submodules (in the order of their usage in the pipeline):
-        rough_scorer (RoughScorer)
-        pw (PairwiseEncoder)
-        a_scorer (AnaphoricityScorer)
-        sp (SpanPredictor)
+    - rough_scorer (RoughScorer) that prunes candidate pairs
+    - pw (DistancePairwiseEncoder) that computes pairwise features
+    - a_scorer (AnaphoricityScorer) produces the final scores
     """

     def __init__(
@@ -149,50 +106,64 @@ class CorefScorer(torch.nn.Module):
     ):
         super().__init__()
         """
-        A newly created model is set to evaluation mode.
-
-        Args:
-            epochs_trained (int): the number of epochs finished
-                (useful for warm start)
+        dim: Size of the input features.
+        dist_emb_size: Size of the distance embeddings.
+        hidden_size: Size of the coreference candidate embeddings.
+        n_layers: Numbers of layers in the AnaphoricityScorer.
+        dropout_rate: Dropout probability to apply across all modules.
+        roughk: Number of candidates the RoughScorer returns.
+        batch_size: Internal batch-size for the more expensive scorer.
         """
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.batch_size = batch_size
+        # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
-        # TODO clean this up
-        bert_emb = dim
-        pair_emb = bert_emb * 3 + self.pw.shape
+        pair_emb = dim * 3 + self.pw.shape
+        self.a_scorer = AnaphoricityScorer(
+            pair_emb,
+            hidden_size,
+            n_layers,
+            dropout_rate
+        )
+        self.lstm = torch.nn.LSTM(
+            input_size=dim,
+            hidden_size=dim,
+            batch_first=True,
+        )
+        self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
+        pair_emb = dim * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
             pair_emb, hidden_size, n_layers, dropout_rate
         )
-        self.lstm = torch.nn.LSTM(
-            input_size=bert_emb,
-            hidden_size=bert_emb,
-            batch_first=True,
-        )
-        self.dropout = torch.nn.Dropout(dropout_rate)
-        self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk)
-        self.batch_size = batch_size

-    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, word_features: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        This is a massive method, but it made sense to me to not split it into
-        several ones to let one see the data flow.
+        1. LSTM encodes the incoming word_features.
+        2. The RoughScorer scores and prunes the candidates.
+        3. The DistancePairwiseEncoder embeds the distances between pairs.
+        4. The AnaphoricityScorer scores all pairs in mini-batches.

-        Args:
-            word_features: torch.Tensor containing word encodings
-        Returns:
-            coreference scores and top indices
+        word_features: torch.Tensor containing word encodings
+        returns:
+        coref_scores: n_words x roughk floats.
+        top_indices: n_words x roughk integers.
         """
-        # words           [n_words, span_emb]
-        # cluster_ids     [n_words]
         self.lstm.flatten_parameters()  # XXX without this there's a warning
         word_features = torch.unsqueeze(word_features, dim=0)
         words, _ = self.lstm(word_features)
         words = words.squeeze()
+        # words: n_words x dim
         words = self.dropout(words)
         # Obtain bilinear scores and leave only top-k antecedents for each word
-        # top_rough_scores  [n_words, n_ants]
-        # top_indices  [n_words, n_ants]
+        # top_rough_scores: (n_words x roughk)
+        # top_indices: (n_words x roughk)
        top_rough_scores, top_indices = self.rough_scorer(words)
-        # Get pairwise features [n_words, n_ants, n_pw_features]
+        # Get pairwise features
+        # (n_words x roughk x n_pw_features)
         pw = self.pw(top_indices)
         batch_size = self.batch_size
         a_scores_lst: List[torch.Tensor] = []
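Steps 2-4 of the rewritten forward docstring describe scoring the pruned candidate pairs in mini-batches. Below is a standalone sketch of that idea with made-up dimensions and a plain Linear layer standing in for the AnaphoricityScorer; it shows the slicing and concatenation pattern, not the actual module.

import torch

n_words, dim, roughk, batch_size = 100, 16, 10, 32
words = torch.randn(n_words, dim)                           # encoded words
top_indices = torch.randint(0, n_words, (n_words, roughk))  # pruned candidates
scorer = torch.nn.Linear(dim * 2, 1)                        # stand-in scorer

scores = []
for i in range(0, n_words, batch_size):
    mentions = words[i:i + batch_size]                      # (b x dim)
    antecedents = words[top_indices[i:i + batch_size]]      # (b x roughk x dim)
    pairs = torch.cat(
        [mentions.unsqueeze(1).expand_as(antecedents), antecedents], dim=-1
    )                                                       # (b x roughk x 2*dim)
    scores.append(scorer(pairs).squeeze(-1))                # (b x roughk)
coref_scores = torch.cat(scores)                            # (n_words x roughk)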
@@ -272,13 +243,8 @@ class AnaphoricityScorer(torch.nn.Module):

     def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Calculates anaphoricity scores.
-
-        Args:
-            x: tensor of shape [batch_size, n_ants, n_features]
-
-        Returns:
-            tensor of shape [batch_size, n_ants]
+        x: tensor of shape (batch_size x roughk x n_features
+        returns: tensor of shape (batch_size x rough_k)
         """
         x = self.out(self.hidden(x))
         return x.squeeze(2)
@@ -293,21 +259,18 @@ class AnaphoricityScorer(torch.nn.Module):
         """
         Builds the matrix used as input for AnaphoricityScorer.

-        Args:
-            all_mentions (torch.Tensor): [n_mentions, mention_emb],
-                all the valid mentions of the document,
-                can be on a different device
-            mentions_batch (torch.Tensor): [batch_size, mention_emb],
-                the mentions of the current batch,
-                is expected to be on the current device
-            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
-                pairwise features of the current batch,
-                is expected to be on the current device
-            top_indices_batch (torch.Tensor): [batch_size, n_ants],
-                indices of antecedents of each mention
+        all_mentions: (n_mentions x mention_emb),
+            all the valid mentions of the document,
+            can be on a different device
+        mentions_batch: (batch_size x mention_emb),
+            the mentions of the current batch.
+        pw_batch: (batch_size x roughk x pw_emb),
+            pairwise distance features of the current batch.
+        top_indices_batch: (batch_size x n_ants),
+            indices of antecedents of each mention

         Returns:
-            torch.Tensor: [batch_size, n_ants, pair_emb]
+            out: pairwise features (batch_size x n_ants x pair_emb)
         """
         emb_size = mentions_batch.shape[1]
         n_ants = pw_batch.shape[1]
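The docstring above describes the pair matrix fed to the feed-forward scorer: for each mention in the batch, its own embedding, the embeddings of its candidate antecedents, their element-wise product, and the pairwise distance features are concatenated, which is why pair_emb is sized as dim * 3 + self.pw.shape earlier in the diff. A sketch with assumed shapes:

import torch

n_mentions, batch_size, n_ants, emb, pw_emb = 30, 8, 5, 16, 4
all_mentions = torch.randn(n_mentions, emb)
mentions_batch = torch.randn(batch_size, emb)
pw_batch = torch.randn(batch_size, n_ants, pw_emb)
top_indices_batch = torch.randint(0, n_mentions, (batch_size, n_ants))

a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, -1)  # anaphor copies
b_mentions = all_mentions[top_indices_batch]                     # antecedent embeddings
similarity = a_mentions * b_mentions                             # element-wise product
pair_matrix = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
# pair_matrix: (batch_size x n_ants x emb * 3 + pw_emb)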
@@ -322,16 +285,15 @@ class AnaphoricityScorer(torch.nn.Module):

 class RoughScorer(torch.nn.Module):
     """
-    Is needed to give a roughly estimate of the anaphoricity of two candidates,
-    only top scoring candidates are considered on later steps to reduce
-    computational complexity.
+    Cheaper module that gives a rough estimate of the anaphoricity of two
+    candidates, only top scoring candidates are considered on later
+    steps to reduce computational cost.
     """

     def __init__(self, features: int, dropout_rate: float, rough_k: float):
         super().__init__()
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.bilinear = torch.nn.Linear(features, features)
-
         self.k = rough_k

     def forward(
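The rewritten docstring describes a cheap bilinear filter over all word pairs, from which only the top-k candidates per word are kept for the expensive scorer. A small sketch of that idea with assumed shapes follows; the antecedent mask construction is not part of this hunk and is reconstructed here only for illustration.

import torch

n_words, dim, rough_k = 50, 16, 10
mentions = torch.randn(n_words, dim)
bilinear = torch.nn.Linear(dim, dim)

# Only words that precede the current word are valid antecedent candidates.
pair_mask = torch.arange(n_words).unsqueeze(1) > torch.arange(n_words).unsqueeze(0)
pair_mask = torch.log(pair_mask.to(torch.float))  # 0 for valid pairs, -inf otherwise

rough_scores = pair_mask + bilinear(mentions).mm(mentions.T)
top_scores, top_indices = torch.topk(
    rough_scores, k=min(rough_k, n_words), dim=1, sorted=False
)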
@@ -348,29 +310,27 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores

-        return self._prune(rough_scores)
-
-    def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Selects top-k rough antecedent scores for each mention.
-
-        Args:
-            rough_scores: tensor of shape [n_mentions, n_mentions], containing
-                rough antecedent scores of each mention-antecedent pair.
-
-        Returns:
-            FloatTensor of shape [n_mentions, k], top rough scores
-            LongTensor of shape [n_mentions, k], top indices
-        """
         top_scores, indices = torch.topk(
             rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
         )

         return top_scores, indices


 class DistancePairwiseEncoder(torch.nn.Module):
     def __init__(self, embedding_size, dropout_rate):
+        """
+        Takes the top_indices indicating, which is a ranked
+        list for each word and its most likely corresponding
+        anaphora candidates. For each of these pairs it looks
+        up a distance embedding from a table, where the distance
+        corresponds to the log-distance.
+
+        embedding_size: int,
+            Dimensionality of the distance-embeddings table.
+        dropout_rate: float,
+            Dropout probability.
+        """
         super().__init__()
         emb_size = embedding_size
         self.distance_emb = torch.nn.Embedding(9, emb_size)
@@ -378,11 +338,12 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.shape = emb_size

     def forward(
-        self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
-        top_indices: torch.Tensor,
+        self,
+        top_indices: torch.Tensor
     ) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
+                    ).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
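The forward pass above buckets each mention-candidate distance before the embedding lookup: distances 1-4 keep their own bucket, and longer distances share log2-spaced buckets capped at index 8, which is why the embedding table has 9 rows. A small worked example with made-up distances (not part of the diff):

import torch

distance = torch.tensor([1, 2, 3, 4, 5, 8, 16, 64, 100, 1000]).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
bucket = torch.where(distance < 5, distance - 1, log_distance + 2)
print(bucket.tolist())  # [0, 1, 2, 3, 4, 5, 6, 8, 8, 8]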
@@ -3,7 +3,7 @@ import torch

 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import Floats2d, Ints1d
 from thinc.util import xp2torch, torch2xp

 from ...tokens import Doc
|
||||||
model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
|
model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
|
||||||
):
|
):
|
||||||
tok2vec, (sent_ids, head_ids) = X
|
tok2vec, (sent_ids, head_ids) = X
|
||||||
# Normally we shoudl use the input is_train, but for these two it's not relevant
|
# Normally we should use the input is_train, but for these two it's not relevant
|
||||||
|
|
||||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||||
# convert to xp and wrap in list
|
|
||||||
gradients = torch2xp(args.args[1])
|
gradients = torch2xp(args.args[1])
|
||||||
return [[gradients], None]
|
return [[gradients], None]
|
||||||
|
|
||||||
|
@@ -55,7 +54,6 @@ def convert_span_predictor_inputs(
     head_ids = xp2torch(head_ids[0], requires_grad=False)

     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
-    # TODO actually support backprop
     return argskwargs, backprop

@@ -66,15 +64,13 @@ def predict_span_clusters(
     """
     Predicts span clusters based on the word clusters.

-    Args:
-        doc (Doc): the document data
-        words (torch.Tensor): [n_words, emb_size] matrix containing
-            embeddings for each of the words in the text
-        clusters (List[List[int]]): a list of clusters where each cluster
-            is a list of word indices
+    span_predictor: a SpanPredictor instance
+    sent_ids: For each word indicates, which sentence it appears in.
+    words: Features for words.
+    clusters: Clusters inferred by the CorefScorer.

     Returns:
-        List[List[Span]]: span clusters
+        List[List[Tuple[int, int]]: span clusters
     """
     if not clusters:
         return []
@@ -141,29 +137,29 @@ class SpanPredictor(torch.nn.Module):
             # this use of dist_emb_size looks wrong but it was 64...?
             torch.nn.Linear(256, dist_emb_size),
         )
+        # TODO make the Convs also parametrizeable
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
+        # TODO make embeddings size a parameter
         self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

     def forward(
-        self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
+        self,
         sent_id,
         words: torch.Tensor,
         heads_ids: torch.Tensor,
     ) -> torch.Tensor:
         """
-        Calculates span start/end scores of words for each span head in
-        heads_ids
+        Calculates span start/end scores of words for each span
+        for each head.

-        Args:
-            doc (Doc): the document data
-            words (torch.Tensor): contextual embeddings for each word in the
-                document, [n_words, emb_size]
-            heads_ids (torch.Tensor): word indices of span heads
+        sent_id: Sentence id of each word.
+        words: features for each word in the document.
+        heads_ids: word indices of span heads

         Returns:
-            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
+            torch.Tensor: span start/end scores, (n_heads x n_words x 2)
         """
         # If we don't receive heads, return empty
         if heads_ids.nelement() == 0:
@@ -176,13 +172,13 @@ class SpanPredictor(torch.nn.Module):
         emb_ids = relative_positions + 63
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
-        # Obtain "same sentence" boolean mask, [n_heads, n_words]
+        # Obtain "same sentence" boolean mask: (n_heads x n_words)
         heads_ids = heads_ids.long()
         same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
         # To save memory, only pass candidates from one sentence for each head
         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
         # for each candidate among the words in the same sentence as span_head
-        # [n_heads, input_size * 2 + distance_emb_size]
+        # (n_heads x input_size * 2 x distance_emb_size)
         rows, cols = same_sent.nonzero(as_tuple=True)
         pair_matrix = torch.cat(
             (
@@ -194,17 +190,17 @@ class SpanPredictor(torch.nn.Module):
         )
         lengths = same_sent.sum(dim=1)
         padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
-        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
-        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
+        # (n_heads x max_sent_len)
+        padding_mask = padding_mask < lengths.unsqueeze(1)
+        # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
         # This is necessary to allow the convolution layer to look at several
         # word scores
         padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
         padded_pairs[padding_mask] = pair_matrix
-
-        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
+        res = self.ffnn(padded_pairs)  # (n_heads x n_candidates x last_layer_output)
         res = self.conv(res.permute(0, 2, 1)).permute(
             0, 2, 1
-        )  # [n_heads, n_candidates, 2]
+        )  # (n_heads x n_candidates, 2)

         scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
         scores[rows, cols] = res[padding_mask]
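The (n_heads x n_words x 2) score tensor built here holds start and end scores for every candidate word of every span head; predict_span_clusters then reduces it to (start, end) pairs, roughly by taking the argmax over words for each head. A sketch with random scores and assumed sizes:

import torch

n_heads, n_words = 3, 20
scores = torch.randn(n_heads, n_words, 2)

starts = scores[:, :, 0].argmax(dim=1)     # best start word per head
ends = scores[:, :, 1].argmax(dim=1) + 1   # best end word per head, exclusive
spans = list(zip(starts.tolist(), ends.tolist()))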
@@ -350,9 +350,7 @@ class CoreferenceResolver(TrainablePipe):

     def score(self, examples, **kwargs):
         """Score a batch of examples using LEA.
-
         For details on how LEA works and why to use it see the paper:
-
         Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
         Moosavi and Strube, 2016
         https://api.semanticscholar.org/CorpusID:17606580