gearing up SpanPredictor for gold-heads

This commit is contained in:
Kádár Ákos 2022-03-24 16:06:20 +01:00
parent 150e7c46d7
commit 706b2e6f25
2 changed files with 35 additions and 19 deletions

View File

@ -53,7 +53,6 @@ def build_wl_coref_model(
convert_inputs=convert_coref_scorer_inputs, convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs convert_outputs=convert_coref_scorer_outputs
) )
coref_model = tok2vec >> coref_scorer coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated # XXX just ignore this until the coref scorer is integrated
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
@ -62,7 +61,6 @@ def build_wl_coref_model(
hidden_size, hidden_size,
sp_embedding_size, sp_embedding_size,
), ),
convert_inputs=convert_span_predictor_inputs convert_inputs=convert_span_predictor_inputs
) )
# TODO combine models so output is uniform (just one forward pass) # TODO combine models so output is uniform (just one forward pass)
@ -84,14 +82,15 @@ def build_span_predictor(
dim = 768 dim = 768
with Model.define_operators({">>": chain, "&": tuplify}): with Model.define_operators({">>": chain, "&": tuplify}):
# TODO fix device - should be automatic
device = "cuda:0"
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
SpanPredictor(dim, dist_emb_size, device), SpanPredictor(dim, dist_emb_size),
convert_inputs=convert_span_predictor_inputs convert_inputs=convert_span_predictor_inputs
) )
# TODO use proper parameter for prefix # TODO use proper parameter for prefix
head_info = build_get_head_metadata("coref_head_clusters") head_info = build_get_head_metadata(
"span_coref_head_clusters",
"coref_head_clusters"
)
model = (tok2vec & head_info) >> span_predictor model = (tok2vec & head_info) >> span_predictor
return model return model
@ -148,7 +147,7 @@ def convert_span_predictor_inputs(
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
# TODO actually support backprop # TODO actually support backprop
return argskwargs, lambda dX: [] return argskwargs, lambda dX: [[]]
# TODO This probably belongs in the component, not the model. # TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model, def predict_span_clusters(span_predictor: Model,
@ -217,18 +216,27 @@ def _clusterize(
clusters.append(sorted(cluster)) clusters.append(sorted(cluster))
return sorted(clusters) return sorted(clusters)
def build_get_head_metadata(prefix):
def build_get_head_metadata(update_prefix, predict_prefix):
# TODO this name is awful, fix it # TODO this name is awful, fix it
model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward) model = Model("HeadDataProvider",
attrs={
"update_prefix": update_prefix,
"predict_prefix": predict_prefix
},
forward=head_data_forward)
return model return model
def head_data_forward(model, docs, is_train): def head_data_forward(model, docs, is_train):
"""A layer to generate the extra data needed for the span predictor. """A layer to generate the extra data needed for the span predictor.
""" """
sent_ids = [] sent_ids = []
head_ids = [] head_ids = []
prefix = model.attrs["prefix"] if is_train:
prefix = model.attrs["update_prefix"]
else:
prefix = model.attrs["predict_prefix"]
for doc in docs: for doc in docs:
sids = model.ops.asarray2i(get_sentence_ids(doc)) sids = model.ops.asarray2i(get_sentence_ids(doc))
sent_ids.append(sids) sent_ids.append(sids)
@ -557,11 +565,9 @@ class SpanPredictor(torch.nn.Module):
words[cols], words[cols],
self.emb(emb_ids[rows, cols]), self.emb(emb_ids[rows, cols]),
), dim=1) ), dim=1)
input(len(heads_ids))
lengths = same_sent.sum(dim=1) lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0) padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len] padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
input(padding_mask.shape)
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size] # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several # This is necessary to allow the convolution layer to look at several
# word scores # word scores

View File

@ -417,6 +417,7 @@ DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)[
default_config={ default_config={
"model": DEFAULT_SPAN_PREDICTOR_MODEL, "model": DEFAULT_SPAN_PREDICTOR_MODEL,
"input_prefix": "coref_head_clusters", "input_prefix": "coref_head_clusters",
"target_prefix": "span_head_target_clusters",
"output_prefix": "coref_clusters", "output_prefix": "coref_clusters",
}, },
default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None}, default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
@ -426,6 +427,7 @@ def make_span_predictor(
name: str, name: str,
model, model,
input_prefix: str = "coref_head_clusters", input_prefix: str = "coref_head_clusters",
target_prefix: str = "span_head_target_clusters",
output_prefix: str = "coref_clusters", output_prefix: str = "coref_clusters",
) -> "SpanPredictor": ) -> "SpanPredictor":
"""Create a SpanPredictor component.""" """Create a SpanPredictor component."""
@ -444,12 +446,14 @@ class SpanPredictor(TrainablePipe):
name: str = "span_predictor", name: str = "span_predictor",
*, *,
input_prefix: str = "coref_head_clusters", input_prefix: str = "coref_head_clusters",
target_prefix: str = "span_coref_head_clusters",
output_prefix: str = "coref_clusters", output_prefix: str = "coref_clusters",
) -> None: ) -> None:
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.name = name self.name = name
self.input_prefix = input_prefix self.input_prefix = input_prefix
self.target_prefix = target_prefix
self.output_prefix = output_prefix self.output_prefix = output_prefix
self.cfg = {} self.cfg = {}
@ -511,13 +515,18 @@ class SpanPredictor(TrainablePipe):
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
total_loss = 0 total_loss = 0
docs = [eg.predicted for eg in examples]
for eg in examples: for doc, eg in zip(docs, examples):
span_scores, backprop = self.model.begin_update([eg.predicted]) # replicates the EntityLinker's behaviour and
# copies annotations over https://bit.ly/3iweDcW
for key, sg in eg.reference.spans.items():
if key.startswith(self.target_prefix):
doc.spans[key] = [doc[span.start:span.end] for span in sg]
span_scores, backprop = self.model.begin_update([doc])
loss, d_scores = self.get_loss([eg], span_scores) loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
backprop((d_scores, mention_idx)) backprop(d_scores)
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
@ -564,7 +573,7 @@ class SpanPredictor(TrainablePipe):
for eg in examples: for eg in examples:
# get gold data # get gold data
gold = doc2clusters(eg.reference, self.output_prefix) gold = doc2clusters(eg.predicted, self.target_prefix)
# flatten the gold data # flatten the gold data
starts = [] starts = []
ends = [] ends = []
@ -605,6 +614,7 @@ class SpanPredictor(TrainablePipe):
doc = ex.predicted doc = ex.predicted
assert len(doc) > 2, "Coreference requires at least two tokens" assert len(doc) > 2, "Coreference requires at least two tokens"
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]] doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
doc.spans[f"{self.target_prefix}_0"] = [doc[0:1], doc[1:2]]
X.append(ex.predicted) X.append(ex.predicted)
Y.append(ex.reference) Y.append(ex.reference)