The coref model is able to be loaded

The span predictor component is initialized but not used at all now.
Plan is to work on it after the word level clustering part is trainable
end-to-end.
This commit is contained in:
Paul O'Leary McCann 2022-03-09 19:31:11 +09:00
parent 35cc2b138f
commit c4f9c24738

View File

@ -4,7 +4,8 @@ import warnings
from thinc.api import Model, Linear, Relu, Dropout
from thinc.api import chain, noop, Embed, add, tuplify, concatenate
from thinc.api import reduce_first, reduce_last, reduce_mean
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
from thinc.api import PyTorchWrapper
from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
from typing import List, Callable, Tuple, Any
from ...tokens import Doc
from ...util import registry
@ -456,7 +457,7 @@ from typing import List, Tuple
import torch
# TODO rename this to coref_util
import .coref_util_wl as utils
from .coref_util_wl import add_dummy
# TODO rename to plain coref
@registry.architectures("spacy.WLCoref.v1")
@ -478,20 +479,23 @@ def build_wl_coref_model(
with Model.define_operators({">>": chain}):
# TODO chain tok2vec with these models
# TODO fix device - should be automatic
device = "gpu:0"
device = "cuda:0"
coref_scorer = PyTorchWrapper(
CorefScorer(
device,
embedding_size,
hidden_size,
n_hidden_layers,
dropout_rate,
dropout,
rough_k,
a_scoring_batch_size
),
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs
)
coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
span_predictor = PyTorchWrapper(
SpanPredictor(
# TODO this was hardcoded to 1024, check
@ -499,12 +503,13 @@ def build_wl_coref_model(
sp_embedding_size,
device
),
convert_inputs=convert_span_predictor_inputs
)
# TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction,
# and just return words as spans.
return coref_scorer
return coref_model
def convert_coref_scorer_inputs(
model: Model,
@ -534,6 +539,17 @@ def convert_coref_scorer_outputs(
indices_xp = torch2xp(indices)
return (scores_xp, indices_xp), convert_for_torch_backward
def convert_span_predictor_inputs(
model: Model,
X: Tuple[Ints1d, Floats2d, Ints1d],
is_train: bool
):
sent_id = xp2torch(X[0], requires_grad=False)
word_features = xp2torch(X[1], requires_grad=False)
head_ids = xp2torch(X[2], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
return argskwargs, lambda dX: []
# TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model,
sent_ids: Ints1d,
@ -747,7 +763,7 @@ class AnaphoricityScorer(torch.nn.Module):
# [batch_size, n_ants]
scores = top_rough_scores_batch + self._ffnn(pair_matrix)
scores = utils.add_dummy(scores, eps=True)
scores = add_dummy(scores, eps=True)
return scores