Mirror of https://github.com/explosion/spaCy.git
Commit e38e84a677
@@ -15,7 +15,11 @@ from .coref_util import get_sentence_ids
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
     hidden_size: int = 1024,
-    dist_emb_size: int = 64,
+    distance_embedding_size: int = 64,
+    conv_channels: int = 4,
+    window_size: int = 1,
+    max_distance: int = 128,
+    prefix: str = "coref_head_clusters"
 ):
     # TODO add model return types
     # TODO fix this
@@ -27,11 +31,18 @@ def build_span_predictor(
 
     with Model.define_operators({">>": chain, "&": tuplify}):
         span_predictor = PyTorchWrapper(
-            SpanPredictor(dim, hidden_size, dist_emb_size),
+            SpanPredictor(
+                dim,
+                hidden_size,
+                distance_embedding_size,
+                conv_channels,
+                window_size,
+                max_distance
+            ),
             convert_inputs=convert_span_predictor_inputs,
        )
         # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata("coref_head_clusters")
+        head_info = build_get_head_metadata(prefix)
         model = (tok2vec & head_info) >> span_predictor
 
     return model
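Note on the combinators in this hunk: ">>" is bound to thinc's chain (feed one model's output into the next) and "&" to tuplify (run several models on the same input and return a tuple of their outputs), so (tok2vec & head_info) >> span_predictor hands the wrapped PyTorch module a (token embeddings, head metadata) pair. A minimal, self-contained sketch with toy no-op layers, not the real coref model, assuming a thinc version that ships tuplify:

    from thinc.api import Model, chain, noop, tuplify

    with Model.define_operators({">>": chain, "&": tuplify}):
        # both no-op branches receive the same input, so the tuplified
        # model returns that input twice as a tuple
        model = noop() & noop()

    assert model.predict("x") == ("x", "x")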
@@ -122,8 +133,21 @@ def head_data_forward(model, docs, is_train):
 
 # TODO this should maybe have a different name from the component
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        dist_emb_size: int,
+        conv_channels: int,
+        window_size: int,
+        max_distance: int
+
+    ):
         super().__init__()
+        if max_distance % 2 != 0:
+            raise ValueError(
+                "max_distance has to be an even number"
+            )
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
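Why __init__ now rejects odd values: the embedding table allocated below has max_distance rows, of which the last is a "too_far" catch-all and the remaining max_distance - 1 cover a symmetric window of relative positions ±(max_distance - 2)/2; that half-width is only a whole number when max_distance is even. A hedged arithmetic sketch with the default value:

    md = 128                  # max_distance
    half = (md - 2) // 2      # 63: widest relative distance kept exactly
    n_valid = 2 * half + 1    # 127 buckets for positions -63..63
    assert n_valid + 1 == md  # plus one "too_far" bucket fills the table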
@@ -138,12 +162,15 @@ class SpanPredictor(torch.nn.Module):
             # this use of dist_emb_size looks wrong but it was 64...?
             torch.nn.Linear(256, dist_emb_size),
         )
-        # TODO make the Convs also parametrizeable
+        kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
-            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
+            torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
         )
         # TODO make embeddings size a parameter
-        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far
+        self.max_distance = max_distance
+        # handle distances between +-(max_distance - 2) / 2
+        self.emb = torch.nn.Embedding(max_distance, dist_emb_size)
 
     def forward(
         self,
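One thing to watch in this hunk: the fifth positional argument to torch.nn.Conv1d is the padding, which stays hardcoded at 1. That preserves the sequence length only for the default window_size = 1 (kernel_size = 3); larger windows would shrink the output. A hedged, runnable sketch of the padding == window_size relationship that keeps the length constant (toy shapes, not the real model):

    import torch

    dist_emb_size, conv_channels, window_size = 64, 4, 2
    kernel_size = window_size * 2 + 1  # 5
    # padding=window_size keeps output length equal to input length
    conv = torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size,
                           stride=1, padding=window_size)
    x = torch.randn(1, dist_emb_size, 10)  # (batch, channels, seq_len)
    assert conv(x).shape == (1, conv_channels, 10)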
@@ -169,10 +196,11 @@ class SpanPredictor(torch.nn.Module):
         relative_positions = heads_ids.unsqueeze(1) - torch.arange(
             words.shape[0]
         ).unsqueeze(0)
+        md = self.max_distance
         # make all valid distances positive
-        emb_ids = relative_positions + 63
+        emb_ids = relative_positions + (md - 2) // 2
         # "too_far"
-        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
+        emb_ids[(emb_ids < 0) + (emb_ids > md - 2)] = md - 1
         # Obtain "same sentence" boolean mask: (n_heads x n_words)
         heads_ids = heads_ids.long()
         same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
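For intuition, a self-contained walk-through of the bucketing above with the default max_distance = 128: valid relative positions in [-63, 63] shift onto buckets 0..126, and anything outside that window lands in the single "too_far" bucket 127. (The code's + on boolean masks acts as a logical OR; | below is equivalent.)

    import torch

    md = 128  # max_distance; must be even per the __init__ check
    relative_positions = torch.tensor([-70, -63, 0, 63, 70])
    emb_ids = relative_positions + (md - 2) // 2          # shift -63..63 to 0..126
    emb_ids[(emb_ids < 0) | (emb_ids > md - 2)] = md - 1  # "too_far"
    print(emb_ids)  # tensor([127,   0,  63, 126, 127])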
@@ -25,7 +25,11 @@ default_span_predictor_config = """
 [model]
 @architectures = "spacy.SpanPredictor.v1"
 hidden_size = 1024
-dist_emb_size = 64
+distance_embedding_size = 64
+conv_channels = 4
+window_size = 1
+max_distance = 128
+prefix = coref_head_clusters
 
 [model.tok2vec]
 @architectures = "spacy.Tok2Vec.v2"
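The new config keys mirror the builder's signature one-to-one; the unquoted prefix value is read back as a plain string by the config parser. A quick sanity check, sketched with thinc's Config on just the [model] block shown above:

    from thinc.api import Config

    cfg = Config().from_str("""
    [model]
    @architectures = "spacy.SpanPredictor.v1"
    hidden_size = 1024
    distance_embedding_size = 64
    conv_channels = 4
    window_size = 1
    max_distance = 128
    prefix = coref_head_clusters
    """)
    assert cfg["model"]["max_distance"] == 128
    assert cfg["model"]["prefix"] == "coref_head_clusters"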