From c4f9c24738b6d609c0994b285931146ebf282764 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Wed, 9 Mar 2022 19:31:11 +0900 Subject: [PATCH] The coref model is able to be loaded The span predictor component is initialized but not used at all now. Plan is to work on it after the word level clustering part is trainable end-to-end. --- spacy/ml/models/coref.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 6d12ca85f..e58afd05b 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -4,7 +4,8 @@ import warnings from thinc.api import Model, Linear, Relu, Dropout from thinc.api import chain, noop, Embed, add, tuplify, concatenate from thinc.api import reduce_first, reduce_last, reduce_mean -from thinc.types import Floats2d, Floats1d, Ints2d, Ragged +from thinc.api import PyTorchWrapper +from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged from typing import List, Callable, Tuple, Any from ...tokens import Doc from ...util import registry @@ -456,7 +457,7 @@ from typing import List, Tuple import torch # TODO rename this to coref_util -import .coref_util_wl as utils +from .coref_util_wl import add_dummy # TODO rename to plain coref @registry.architectures("spacy.WLCoref.v1") @@ -478,20 +479,23 @@ def build_wl_coref_model( with Model.define_operators({">>": chain}): # TODO chain tok2vec with these models # TODO fix device - should be automatic - device = "gpu:0" + device = "cuda:0" coref_scorer = PyTorchWrapper( CorefScorer( device, embedding_size, hidden_size, n_hidden_layers, - dropout_rate, + dropout, rough_k, a_scoring_batch_size ), convert_inputs=convert_coref_scorer_inputs, convert_outputs=convert_coref_scorer_outputs ) + + coref_model = tok2vec >> coref_scorer + # XXX just ignore this until the coref scorer is integrated span_predictor = PyTorchWrapper( SpanPredictor( # TODO this was hardcoded to 1024, check @@ -499,12 +503,13 @@ def build_wl_coref_model( sp_embedding_size, device ), + convert_inputs=convert_span_predictor_inputs ) # TODO combine models so output is uniform (just one forward pass) # It may be reasonable to have an option to disable span prediction, # and just return words as spans. - return coref_scorer + return coref_model def convert_coref_scorer_inputs( model: Model, @@ -534,6 +539,17 @@ def convert_coref_scorer_outputs( indices_xp = torch2xp(indices) return (scores_xp, indices_xp), convert_for_torch_backward +def convert_span_predictor_inputs( + model: Model, + X: Tuple[Ints1d, Floats2d, Ints1d], + is_train: bool +): + sent_id = xp2torch(X[0], requires_grad=False) + word_features = xp2torch(X[1], requires_grad=False) + head_ids = xp2torch(X[2], requires_grad=False) + argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={}) + return argskwargs, lambda dX: [] + # TODO This probably belongs in the component, not the model. def predict_span_clusters(span_predictor: Model, sent_ids: Ints1d, @@ -747,7 +763,7 @@ class AnaphoricityScorer(torch.nn.Module): # [batch_size, n_ants] scores = top_rough_scores_batch + self._ffnn(pair_matrix) - scores = utils.add_dummy(scores, eps=True) + scores = add_dummy(scores, eps=True) return scores