mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 04:02:20 +03:00)

Split span predictor model into its own file

This commit is contained in:
  parent f852c5cea4
  commit 41fc092674

@@ -1,4 +1,5 @@
 from .coref import * #noqa
+from .span_predictor import * #noqa
 from .entity_linker import * # noqa
 from .multi_task import * # noqa
 from .parser import * # noqa

@@ -64,30 +64,6 @@ def build_wl_coref_model(
     return coref_model
 
 
-@registry.architectures("spacy.SpanPredictor.v1")
-def build_span_predictor(
-    tok2vec: Model[List[Doc], List[Floats2d]],
-    hidden_size: int = 1024,
-    dist_emb_size: int = 64,
-):
-    # TODO fix this
-    try:
-        dim = tok2vec.get_dim("nO")
-    except ValueError:
-        # happens with transformer listener
-        dim = 768
-
-    with Model.define_operators({">>": chain, "&": tuplify}):
-        span_predictor = PyTorchWrapper(
-            SpanPredictor(dim, hidden_size, dist_emb_size),
-            convert_inputs=convert_span_predictor_inputs,
-        )
-        # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata("coref_head_clusters")
-        model = (tok2vec & head_info) >> span_predictor
-
-    return model
-
-
 def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc

@@ -120,61 +96,6 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
     return (scores_xp, indices_xp), convert_for_torch_backward
 
 
-def convert_span_predictor_inputs(
-    model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
-):
-    tok2vec, (sent_ids, head_ids) = X
-    # Normally we should use the input is_train, but for these two it's not relevant
-
-    def backprop(args: ArgsKwargs) -> List[Floats2d]:
-        # convert to xp and wrap in list
-        gradients = torch2xp(args.args[1])
-        return [[gradients], None]
-
-    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
-    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
-    if not head_ids[0].size:
-        head_ids = torch.empty(size=(0,))
-    else:
-        head_ids = xp2torch(head_ids[0], requires_grad=False)
-
-    argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
-    # TODO actually support backprop
-    return argskwargs, backprop
-
-
-# TODO This probably belongs in the component, not the model.
-def predict_span_clusters(
-    span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
-):
-    """
-    Predicts span clusters based on the word clusters.
-
-    Args:
-        doc (Doc): the document data
-        words (torch.Tensor): [n_words, emb_size] matrix containing
-            embeddings for each of the words in the text
-        clusters (List[List[int]]): a list of clusters where each cluster
-            is a list of word indices
-
-    Returns:
-        List[List[Span]]: span clusters
-    """
-    if not clusters:
-        return []
-
-    xp = span_predictor.ops.xp
-    heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
-    scores = span_predictor.predict((sent_ids, words, heads_ids))
-    starts = scores[:, :, 0].argmax(axis=1).tolist()
-    ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
-
-    head2span = {
-        head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
-    }
-
-    return [[head2span[head] for head in cluster] for cluster in clusters]
-
-
 # TODO add docstring for this, maybe move to utils.
 # This might belong in the component.

@@ -205,36 +126,6 @@ def _clusterize(model, scores: Floats2d, top_indices: Ints2d):
     return sorted(clusters)
 
 
-def build_get_head_metadata(prefix):
-    # TODO this name is awful, fix it
-    model = Model(
-        "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
-    )
-    return model
-
-
-def head_data_forward(model, docs, is_train):
-    """A layer to generate the extra data needed for the span predictor."""
-    sent_ids = []
-    head_ids = []
-    prefix = model.attrs["prefix"]
-    for doc in docs:
-        sids = model.ops.asarray2i(get_sentence_ids(doc))
-        sent_ids.append(sids)
-        heads = []
-        for key, sg in doc.spans.items():
-            if not key.startswith(prefix):
-                continue
-            for span in sg:
-                # TODO warn if spans are more than one token
-                heads.append(span[0].i)
-        heads = model.ops.asarray2i(heads)
-        head_ids.append(heads)
-    # each of these is a list with one entry per doc
-    # backprop is just a placeholder
-    # TODO it would probably be better to have a list of tuples than two lists of arrays
-    return (sent_ids, head_ids), lambda x: []
-
-
 class CorefScorer(torch.nn.Module):
     """Combines all coref modules together to find coreferent spans.

@@ -481,97 +372,6 @@ class RoughScorer(torch.nn.Module):
         return top_scores, indices
 
 
-class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
-        super().__init__()
-        # input size = single token size
-        # 64 = probably distance emb size
-        # TODO check that dist_emb_size use is correct
-        self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
-            torch.nn.ReLU(),
-            torch.nn.Dropout(0.3),
-            # TODO seems weird the 256 isn't a parameter???
-            torch.nn.Linear(hidden_size, 256),
-            torch.nn.ReLU(),
-            torch.nn.Dropout(0.3),
-            # this use of dist_emb_size looks wrong but it was 64...?
-            torch.nn.Linear(256, dist_emb_size),
-        )
-        self.conv = torch.nn.Sequential(
-            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
-        )
-        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far
-
-    def forward(
-        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
-        sent_id,
-        words: torch.Tensor,
-        heads_ids: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Calculates span start/end scores of words for each span head in
-        heads_ids
-
-        Args:
-            doc (Doc): the document data
-            words (torch.Tensor): contextual embeddings for each word in the
-                document, [n_words, emb_size]
-            heads_ids (torch.Tensor): word indices of span heads
-
-        Returns:
-            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
-        """
-        # If we don't receive heads, return empty
-        if heads_ids.nelement() == 0:
-            return torch.empty(size=(0,))
-        # Obtain distance embedding indices, [n_heads, n_words]
-        relative_positions = heads_ids.unsqueeze(1) - torch.arange(
-            words.shape[0]
-        ).unsqueeze(0)
-        # make all valid distances positive
-        emb_ids = relative_positions + 63
-        # "too_far"
-        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
-        # Obtain "same sentence" boolean mask, [n_heads, n_words]
-        heads_ids = heads_ids.long()
-        same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
-        # To save memory, only pass candidates from one sentence for each head
-        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
-        # for each candidate among the words in the same sentence as span_head
-        # [n_heads, input_size * 2 + distance_emb_size]
-        rows, cols = same_sent.nonzero(as_tuple=True)
-        pair_matrix = torch.cat(
-            (
-                words[heads_ids[rows]],
-                words[cols],
-                self.emb(emb_ids[rows, cols]),
-            ),
-            dim=1,
-        )
-        lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
-        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
-        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
-        # This is necessary to allow the convolution layer to look at several
-        # word scores
-        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
-        padded_pairs[padding_mask] = pair_matrix
-
-        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
-        res = self.conv(res.permute(0, 2, 1)).permute(
-            0, 2, 1
-        )  # [n_heads, n_candidates, 2]
-
-        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
-        scores[rows, cols] = res[padding_mask]
-        # Make sure that start <= head <= end during inference
-        if not self.training:
-            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
-            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
-            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
-            return scores + valid_positions
-        return scores
-
-
 class DistancePairwiseEncoder(torch.nn.Module):

spacy/ml/models/span_predictor.py (new file, 215 lines)

@@ -0,0 +1,215 @@
from typing import List, Tuple
import torch

from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.util import xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry
from .coref_util import get_sentence_ids


@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
    tok2vec: Model[List[Doc], List[Floats2d]],
    hidden_size: int = 1024,
    dist_emb_size: int = 64,
):
    # TODO fix this
    try:
        dim = tok2vec.get_dim("nO")
    except ValueError:
        # happens with transformer listener
        dim = 768

    with Model.define_operators({">>": chain, "&": tuplify}):
        span_predictor = PyTorchWrapper(
            SpanPredictor(dim, hidden_size, dist_emb_size),
            convert_inputs=convert_span_predictor_inputs,
        )
        # TODO use proper parameter for prefix
        head_info = build_get_head_metadata("coref_head_clusters")
        model = (tok2vec & head_info) >> span_predictor

    return model
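
Because the factory is registered under "spacy.SpanPredictor.v1", it can be resolved by name through spaCy's architecture registry once this module has been imported (which the __init__.py change above takes care of). The snippet below is an illustrative sketch by the editor, not part of the commit; the alias spacy_registry is only there to avoid shadowing the module-level registry import.

# Editor's sketch, not part of the diff: look up the registered architecture
# by its string name, as a training config would do.
from spacy.util import registry as spacy_registry

build_fn = spacy_registry.architectures.get("spacy.SpanPredictor.v1")
assert build_fn is build_span_predictor  # same function object as defined above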
def convert_span_predictor_inputs(
|
||||
model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
|
||||
):
|
||||
tok2vec, (sent_ids, head_ids) = X
|
||||
# Normally we shoudl use the input is_train, but for these two it's not relevant
|
||||
|
||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||
# convert to xp and wrap in list
|
||||
gradients = torch2xp(args.args[1])
|
||||
return [[gradients], None]
|
||||
|
||||
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
|
||||
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
|
||||
if not head_ids[0].size:
|
||||
head_ids = torch.empty(size=(0,))
|
||||
else:
|
||||
head_ids = xp2torch(head_ids[0], requires_grad=False)
|
||||
|
||||
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
||||
# TODO actually support backprop
|
||||
return argskwargs, backprop
|
||||
|
||||
|
||||
# TODO This probably belongs in the component, not the model.
|
||||
def predict_span_clusters(
|
||||
span_predictor: Model, sent_ids: Ints1d, words: Floats2d, clusters: List[Ints1d]
|
||||
):
|
||||
"""
|
||||
Predicts span clusters based on the word clusters.
|
||||
|
||||
Args:
|
||||
doc (Doc): the document data
|
||||
words (torch.Tensor): [n_words, emb_size] matrix containing
|
||||
embeddings for each of the words in the text
|
||||
clusters (List[List[int]]): a list of clusters where each cluster
|
||||
is a list of word indices
|
||||
|
||||
Returns:
|
||||
List[List[Span]]: span clusters
|
||||
"""
|
||||
if not clusters:
|
||||
return []
|
||||
|
||||
xp = span_predictor.ops.xp
|
||||
heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
|
||||
scores = span_predictor.predict((sent_ids, words, heads_ids))
|
||||
starts = scores[:, :, 0].argmax(axis=1).tolist()
|
||||
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
|
||||
|
||||
head2span = {
|
||||
head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends)
|
||||
}
|
||||
|
||||
return [[head2span[head] for head in cluster] for cluster in clusters]
|
||||
|
||||
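
To make the argmax decoding above concrete, here is a tiny worked example from the editor with made-up numbers (not part of the diff): column 0 of the score tensor ranks start candidates and column 1 ranks end candidates for each head.

import numpy as np

# Editor's sketch: one head over a four-word document; scores has shape
# [n_heads, n_words, 2]. argmax over the word axis picks the best start token
# and the best (exclusive) end token for that head.
scores = np.array([[[0.1, 0.0],
                    [2.0, 0.2],
                    [0.3, 1.5],
                    [0.0, 0.1]]])
starts = scores[:, :, 0].argmax(axis=1).tolist()      # [1]
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()  # [3] -> span covers words 1..2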

# TODO this should maybe have a different name from the component
class SpanPredictor(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
        super().__init__()
        # input size = single token size
        # 64 = probably distance emb size
        # TODO check that dist_emb_size use is correct
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            # TODO seems weird the 256 isn't a parameter???
            torch.nn.Linear(hidden_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            # this use of dist_emb_size looks wrong but it was 64...?
            torch.nn.Linear(256, dist_emb_size),
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

    def forward(
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        sent_id,
        words: torch.Tensor,
        heads_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
        """
        # If we don't receive heads, return empty
        if heads_ids.nelement() == 0:
            return torch.empty(size=(0,))
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = heads_ids.unsqueeze(1) - torch.arange(
            words.shape[0]
        ).unsqueeze(0)
        # make all valid distances positive
        emb_ids = relative_positions + 63
        # "too_far"
        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        heads_ids = heads_ids.long()
        same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_heads, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat(
            (
                words[heads_ids[rows]],
                words[cols],
                self.emb(emb_ids[rows, cols]),
            ),
            dim=1,
        )
        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
        padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(
            0, 2, 1
        )  # [n_heads, n_candidates, 2]

        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
        scores[rows, cols] = res[padding_mask]
        # Make sure that start <= head <= end during inference
        if not self.training:
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores
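
As a quick sanity check on the shapes flowing through forward, the following editor's sketch (not part of the diff, sizes are made up; dist_emb_size has to stay 64 so the FFNN output matches the first Conv1d's 64 input channels) runs the module on random data.

# Editor's sketch: 10 words with 20-dim embeddings split over two sentences,
# and two span heads. The output has shape [n_heads, n_words, 2].
predictor = SpanPredictor(input_size=20, hidden_size=32, dist_emb_size=64)
words = torch.rand(10, 20)
sent_id = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
heads_ids = torch.tensor([2, 7])
scores = predictor(sent_id, words, heads_ids)
assert scores.shape == (2, 10, 2)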

def build_get_head_metadata(prefix):
    # TODO this name is awful, fix it
    model = Model(
        "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
    )
    return model


def head_data_forward(model, docs, is_train):
    """A layer to generate the extra data needed for the span predictor."""
    sent_ids = []
    head_ids = []
    prefix = model.attrs["prefix"]
    for doc in docs:
        sids = model.ops.asarray2i(get_sentence_ids(doc))
        sent_ids.append(sids)
        heads = []
        for key, sg in doc.spans.items():
            if not key.startswith(prefix):
                continue
            for span in sg:
                # TODO warn if spans are more than one token
                heads.append(span[0].i)
        heads = model.ops.asarray2i(heads)
        head_ids.append(heads)
    # each of these is a list with one entry per doc
    # backprop is just a placeholder
    # TODO it would probably be better to have a list of tuples than two lists of arrays
    return (sent_ids, head_ids), lambda x: []
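
For orientation, this editor's sketch (not part of the diff; the pipeline setup and the span-group key are assumptions) shows what the head-data layer extracts from a Doc that already carries head clusters under the configured prefix.

# Editor's sketch: a Doc with sentence boundaries and one head cluster stored
# under a key matching the "coref_head_clusters" prefix. The layer returns one
# sentence-id array and one head-index array per doc; its backprop is a no-op.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")
doc = nlp("Sarah called her mother. She was out.")
doc.spans["coref_head_clusters_1"] = [doc[0:1], doc[2:3]]  # "Sarah", "her"

layer = build_get_head_metadata("coref_head_clusters")
(sent_ids, head_ids), backprop = layer([doc], is_train=False)
print(head_ids[0])  # head token indices, here [0 2]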