The coref model can now be loaded

The span predictor component is initialized but is not used at all for now.
The plan is to work on it once the word-level clustering part is trainable
end-to-end.
This commit is contained in:
Paul O'Leary McCann 2022-03-09 19:31:11 +09:00
parent 35cc2b138f
commit c4f9c24738

View File

@ -4,7 +4,8 @@ import warnings
from thinc.api import Model, Linear, Relu, Dropout from thinc.api import Model, Linear, Relu, Dropout
from thinc.api import chain, noop, Embed, add, tuplify, concatenate from thinc.api import chain, noop, Embed, add, tuplify, concatenate
from thinc.api import reduce_first, reduce_last, reduce_mean from thinc.api import reduce_first, reduce_last, reduce_mean
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged from thinc.api import PyTorchWrapper
from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
from typing import List, Callable, Tuple, Any from typing import List, Callable, Tuple, Any
from ...tokens import Doc from ...tokens import Doc
from ...util import registry from ...util import registry
@ -456,7 +457,7 @@ from typing import List, Tuple
import torch import torch
# TODO rename this to coref_util # TODO rename this to coref_util
import .coref_util_wl as utils from .coref_util_wl import add_dummy
# TODO rename to plain coref # TODO rename to plain coref
@registry.architectures("spacy.WLCoref.v1") @registry.architectures("spacy.WLCoref.v1")
@ -478,20 +479,23 @@ def build_wl_coref_model(
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
# TODO chain tok2vec with these models # TODO chain tok2vec with these models
# TODO fix device - should be automatic # TODO fix device - should be automatic
device = "gpu:0" device = "cuda:0"
coref_scorer = PyTorchWrapper( coref_scorer = PyTorchWrapper(
CorefScorer( CorefScorer(
device, device,
embedding_size, embedding_size,
hidden_size, hidden_size,
n_hidden_layers, n_hidden_layers,
dropout_rate, dropout,
rough_k, rough_k,
a_scoring_batch_size a_scoring_batch_size
), ),
convert_inputs=convert_coref_scorer_inputs, convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs convert_outputs=convert_coref_scorer_outputs
) )
coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
SpanPredictor( SpanPredictor(
# TODO this was hardcoded to 1024, check # TODO this was hardcoded to 1024, check
@ -499,12 +503,13 @@ def build_wl_coref_model(
sp_embedding_size, sp_embedding_size,
device device
), ),
convert_inputs=convert_span_predictor_inputs convert_inputs=convert_span_predictor_inputs
) )
# TODO combine models so output is uniform (just one forward pass) # TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction, # It may be reasonable to have an option to disable span prediction,
# and just return words as spans. # and just return words as spans.
return coref_scorer return coref_model
def convert_coref_scorer_inputs( def convert_coref_scorer_inputs(
model: Model, model: Model,
@ -534,6 +539,17 @@ def convert_coref_scorer_outputs(
indices_xp = torch2xp(indices) indices_xp = torch2xp(indices)
return (scores_xp, indices_xp), convert_for_torch_backward return (scores_xp, indices_xp), convert_for_torch_backward
def convert_span_predictor_inputs(
    model: Model,
    X: Tuple[Ints1d, Floats2d, Ints1d],
    is_train: bool
):
    """Convert the span predictor's inputs into PyTorch call arguments.

    The first three items of ``X`` are each converted with ``xp2torch``
    (with ``requires_grad=False``) and passed on, in the same order, as
    positional arguments.

    model: The wrapping model (not used here).
    X: A 3-tuple of arrays. NOTE(review): presumably sentence ids, word
        features and head ids, in that order -- confirm against the
        wrapped module's forward signature.
    is_train: Whether the call happens during training (not used here).
    RETURNS: An ``ArgsKwargs`` holding the three converted tensors and a
        backprop callback that ignores its argument and returns an empty
        list.
    """

    def no_input_gradient(dX):
        # None of the inputs were converted with requires_grad=True, so an
        # empty list is handed back instead of gradients.
        return []

    # Convert the items in order: X[0], X[1], X[2].
    torch_args = tuple(
        xp2torch(X[position], requires_grad=False) for position in range(3)
    )
    return ArgsKwargs(args=torch_args, kwargs={}), no_input_gradient
# TODO This probably belongs in the component, not the model. # TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model, def predict_span_clusters(span_predictor: Model,
sent_ids: Ints1d, sent_ids: Ints1d,
@ -747,7 +763,7 @@ class AnaphoricityScorer(torch.nn.Module):
# [batch_size, n_ants] # [batch_size, n_ants]
scores = top_rough_scores_batch + self._ffnn(pair_matrix) scores = top_rough_scores_batch + self._ffnn(pair_matrix)
scores = utils.add_dummy(scores, eps=True) scores = add_dummy(scores, eps=True)
return scores return scores