mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
gearing up SpanPredictor for gold-heads
This commit is contained in:
parent
150e7c46d7
commit
706b2e6f25
|
@ -53,7 +53,6 @@ def build_wl_coref_model(
|
||||||
convert_inputs=convert_coref_scorer_inputs,
|
convert_inputs=convert_coref_scorer_inputs,
|
||||||
convert_outputs=convert_coref_scorer_outputs
|
convert_outputs=convert_coref_scorer_outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
coref_model = tok2vec >> coref_scorer
|
coref_model = tok2vec >> coref_scorer
|
||||||
# XXX just ignore this until the coref scorer is integrated
|
# XXX just ignore this until the coref scorer is integrated
|
||||||
span_predictor = PyTorchWrapper(
|
span_predictor = PyTorchWrapper(
|
||||||
|
@ -62,7 +61,6 @@ def build_wl_coref_model(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
sp_embedding_size,
|
sp_embedding_size,
|
||||||
),
|
),
|
||||||
|
|
||||||
convert_inputs=convert_span_predictor_inputs
|
convert_inputs=convert_span_predictor_inputs
|
||||||
)
|
)
|
||||||
# TODO combine models so output is uniform (just one forward pass)
|
# TODO combine models so output is uniform (just one forward pass)
|
||||||
|
@ -84,14 +82,15 @@ def build_span_predictor(
|
||||||
dim = 768
|
dim = 768
|
||||||
|
|
||||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||||
# TODO fix device - should be automatic
|
|
||||||
device = "cuda:0"
|
|
||||||
span_predictor = PyTorchWrapper(
|
span_predictor = PyTorchWrapper(
|
||||||
SpanPredictor(dim, dist_emb_size, device),
|
SpanPredictor(dim, dist_emb_size),
|
||||||
convert_inputs=convert_span_predictor_inputs
|
convert_inputs=convert_span_predictor_inputs
|
||||||
)
|
)
|
||||||
# TODO use proper parameter for prefix
|
# TODO use proper parameter for prefix
|
||||||
head_info = build_get_head_metadata("coref_head_clusters")
|
head_info = build_get_head_metadata(
|
||||||
|
"span_coref_head_clusters",
|
||||||
|
"coref_head_clusters"
|
||||||
|
)
|
||||||
model = (tok2vec & head_info) >> span_predictor
|
model = (tok2vec & head_info) >> span_predictor
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -148,7 +147,7 @@ def convert_span_predictor_inputs(
|
||||||
|
|
||||||
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
||||||
# TODO actually support backprop
|
# TODO actually support backprop
|
||||||
return argskwargs, lambda dX: []
|
return argskwargs, lambda dX: [[]]
|
||||||
|
|
||||||
# TODO This probably belongs in the component, not the model.
|
# TODO This probably belongs in the component, not the model.
|
||||||
def predict_span_clusters(span_predictor: Model,
|
def predict_span_clusters(span_predictor: Model,
|
||||||
|
@ -217,18 +216,27 @@ def _clusterize(
|
||||||
clusters.append(sorted(cluster))
|
clusters.append(sorted(cluster))
|
||||||
return sorted(clusters)
|
return sorted(clusters)
|
||||||
|
|
||||||
def build_get_head_metadata(prefix):
|
|
||||||
|
def build_get_head_metadata(update_prefix, predict_prefix):
|
||||||
# TODO this name is awful, fix it
|
# TODO this name is awful, fix it
|
||||||
model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
|
model = Model("HeadDataProvider",
|
||||||
|
attrs={
|
||||||
|
"update_prefix": update_prefix,
|
||||||
|
"predict_prefix": predict_prefix
|
||||||
|
},
|
||||||
|
forward=head_data_forward)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def head_data_forward(model, docs, is_train):
|
def head_data_forward(model, docs, is_train):
|
||||||
"""A layer to generate the extra data needed for the span predictor.
|
"""A layer to generate the extra data needed for the span predictor.
|
||||||
"""
|
"""
|
||||||
sent_ids = []
|
sent_ids = []
|
||||||
head_ids = []
|
head_ids = []
|
||||||
prefix = model.attrs["prefix"]
|
if is_train:
|
||||||
|
prefix = model.attrs["update_prefix"]
|
||||||
|
else:
|
||||||
|
prefix = model.attrs["predict_prefix"]
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
sids = model.ops.asarray2i(get_sentence_ids(doc))
|
sids = model.ops.asarray2i(get_sentence_ids(doc))
|
||||||
sent_ids.append(sids)
|
sent_ids.append(sids)
|
||||||
|
@ -557,11 +565,9 @@ class SpanPredictor(torch.nn.Module):
|
||||||
words[cols],
|
words[cols],
|
||||||
self.emb(emb_ids[rows, cols]),
|
self.emb(emb_ids[rows, cols]),
|
||||||
), dim=1)
|
), dim=1)
|
||||||
input(len(heads_ids))
|
|
||||||
lengths = same_sent.sum(dim=1)
|
lengths = same_sent.sum(dim=1)
|
||||||
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
|
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
|
||||||
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
||||||
input(padding_mask.shape)
|
|
||||||
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
||||||
# This is necessary to allow the convolution layer to look at several
|
# This is necessary to allow the convolution layer to look at several
|
||||||
# word scores
|
# word scores
|
||||||
|
|
|
@ -417,6 +417,7 @@ DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)[
|
||||||
default_config={
|
default_config={
|
||||||
"model": DEFAULT_SPAN_PREDICTOR_MODEL,
|
"model": DEFAULT_SPAN_PREDICTOR_MODEL,
|
||||||
"input_prefix": "coref_head_clusters",
|
"input_prefix": "coref_head_clusters",
|
||||||
|
"target_prefix": "span_head_target_clusters",
|
||||||
"output_prefix": "coref_clusters",
|
"output_prefix": "coref_clusters",
|
||||||
},
|
},
|
||||||
default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
|
default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
|
||||||
|
@ -426,6 +427,7 @@ def make_span_predictor(
|
||||||
name: str,
|
name: str,
|
||||||
model,
|
model,
|
||||||
input_prefix: str = "coref_head_clusters",
|
input_prefix: str = "coref_head_clusters",
|
||||||
|
target_prefix: str = "span_head_target_clusters",
|
||||||
output_prefix: str = "coref_clusters",
|
output_prefix: str = "coref_clusters",
|
||||||
) -> "SpanPredictor":
|
) -> "SpanPredictor":
|
||||||
"""Create a SpanPredictor component."""
|
"""Create a SpanPredictor component."""
|
||||||
|
@ -444,12 +446,14 @@ class SpanPredictor(TrainablePipe):
|
||||||
name: str = "span_predictor",
|
name: str = "span_predictor",
|
||||||
*,
|
*,
|
||||||
input_prefix: str = "coref_head_clusters",
|
input_prefix: str = "coref_head_clusters",
|
||||||
|
target_prefix: str = "span_coref_head_clusters",
|
||||||
output_prefix: str = "coref_clusters",
|
output_prefix: str = "coref_clusters",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self.input_prefix = input_prefix
|
self.input_prefix = input_prefix
|
||||||
|
self.target_prefix = target_prefix
|
||||||
self.output_prefix = output_prefix
|
self.output_prefix = output_prefix
|
||||||
|
|
||||||
self.cfg = {}
|
self.cfg = {}
|
||||||
|
@ -511,13 +515,18 @@ class SpanPredictor(TrainablePipe):
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
docs = [eg.predicted for eg in examples]
|
||||||
for eg in examples:
|
for doc, eg in zip(docs, examples):
|
||||||
span_scores, backprop = self.model.begin_update([eg.predicted])
|
# replicates the EntityLinker's behaviour and
|
||||||
|
# copies annotations over https://bit.ly/3iweDcW
|
||||||
|
for key, sg in eg.reference.spans.items():
|
||||||
|
if key.startswith(self.target_prefix):
|
||||||
|
doc.spans[key] = [doc[span.start:span.end] for span in sg]
|
||||||
|
span_scores, backprop = self.model.begin_update([doc])
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
backprop((d_scores, mention_idx))
|
backprop(d_scores)
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
@ -564,7 +573,7 @@ class SpanPredictor(TrainablePipe):
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
|
|
||||||
# get gold data
|
# get gold data
|
||||||
gold = doc2clusters(eg.reference, self.output_prefix)
|
gold = doc2clusters(eg.predicted, self.target_prefix)
|
||||||
# flatten the gold data
|
# flatten the gold data
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
|
@ -605,6 +614,7 @@ class SpanPredictor(TrainablePipe):
|
||||||
doc = ex.predicted
|
doc = ex.predicted
|
||||||
assert len(doc) > 2, "Coreference requires at least two tokens"
|
assert len(doc) > 2, "Coreference requires at least two tokens"
|
||||||
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
|
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
|
||||||
|
doc.spans[f"{self.target_prefix}_0"] = [doc[0:1], doc[1:2]]
|
||||||
X.append(ex.predicted)
|
X.append(ex.predicted)
|
||||||
Y.append(ex.reference)
|
Y.append(ex.reference)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user