diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 660ef68c5..4967e7f23 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -1,6 +1,6 @@
 from typing import List, Tuple
 
-from thinc.api import Model, chain
+from thinc.api import Model, chain, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d
 from thinc.util import torch, xp2torch, torch2xp
@@ -25,12 +25,48 @@ def build_wl_coref_model(
     tok2vec_size: int = 768,  # tok2vec size
 ):
     # TODO add model return types
-    # dim = tok2vec.maybe_get_dim("n0")
+
+    nI = None
 
     with Model.define_operators({">>": chain}):
-        coref_clusterer = PyTorchWrapper(
+        coref_clusterer = Model(
+            "coref_clusterer",
+            forward=coref_forward,
+            init=coref_init,
+            dims={"nI": nI},
+            attrs={
+                "distance_embedding_size": distance_embedding_size,
+                "hidden_size": hidden_size,
+                "depth": depth,
+                "dropout": dropout,
+                "antecedent_limit": antecedent_limit,
+                "antecedent_batch_size": antecedent_batch_size,
+            },
+        )
+
+        coref_model = tok2vec >> coref_clusterer
+    return coref_model
+
+
+def coref_init(model: Model, X=None, Y=None):
+    if model.layers:
+        return
+
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+
+    hidden_size = model.attrs["hidden_size"]
+    depth = model.attrs["depth"]
+    dropout = model.attrs["dropout"]
+    antecedent_limit = model.attrs["antecedent_limit"]
+    antecedent_batch_size = model.attrs["antecedent_batch_size"]
+    distance_embedding_size = model.attrs["distance_embedding_size"]
+
+    PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
+    model._layers = [
+        PyTorchWrapper(
             CorefClusterer(
-                tok2vec_size,
+                model.get_dim("nI"),
                 distance_embedding_size,
                 hidden_size,
                 depth,
@@ -41,13 +77,15 @@ def build_wl_coref_model(
             convert_inputs=convert_coref_clusterer_inputs,
             convert_outputs=convert_coref_clusterer_outputs,
         )
-        coref_model = tok2vec >> coref_clusterer
-    return coref_model
+        # TODO maybe we need mixed precision and grad scaling?
+    ]
 
 
-def convert_coref_clusterer_inputs(
-    model: Model, X: List[Floats2d], is_train: bool
-):
+def coref_forward(model: Model, X, is_train: bool):
+    return model.layers[0](X, is_train)
+
+
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -55,7 +93,7 @@ def convert_coref_clusterer_inputs(
     word_features = xp2torch(X, requires_grad=is_train)
 
     # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
@@ -63,9 +101,7 @@ def convert_coref_clusterer_inputs(
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
-def convert_coref_clusterer_outputs(
-    model: Model, inputs_outputs, is_train: bool
-):
+def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
     _, outputs = inputs_outputs
     scores, indices = outputs
 
@@ -115,9 +151,7 @@ class CorefClusterer(torch.nn.Module):
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
 
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -156,10 +190,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
 
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
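Note (not part of the diff): below is a minimal, self-contained sketch of the lazy-initialization pattern the patch introduces, assuming only thinc and numpy. The outer Model registers an unset "nI" dimension; the init callback infers it from sample data with get_width() and only then builds the real inner layer; forward simply delegates to layers[0]. The Linear inner layer and the make_lazy_layer/lazy_init/lazy_forward names are illustrative stand-ins for the PyTorch-wrapped CorefClusterer, not spaCy API.

from thinc.api import Model, Linear, get_width
import numpy


def make_lazy_layer(nO: int) -> Model:
    # "nI" starts out unset (None) and is filled in during initialization.
    return Model(
        "lazy_layer",
        forward=lazy_forward,
        init=lazy_init,
        dims={"nI": None, "nO": nO},
    )


def lazy_init(model: Model, X=None, Y=None):
    if model.layers:
        # Already initialized (e.g. weights were deserialized).
        return
    if X is not None and model.has_dim("nI") is None:
        # Infer the input width from the sample batch.
        model.set_dim("nI", get_width(X))
    # Build the inner layer only now that the input width is known.
    inner = Linear(nO=model.get_dim("nO"), nI=model.get_dim("nI"))
    inner.initialize(X=X)
    model._layers = [inner]


def lazy_forward(model: Model, X, is_train: bool):
    # Delegate to the lazily created inner layer.
    return model.layers[0](X, is_train)


# Usage: the input width (8) is read off the sample data, not hard-coded.
model = make_lazy_layer(nO=2)
model.initialize(X=numpy.zeros((4, 8), dtype="f"))
Y = model.predict(numpy.zeros((4, 8), dtype="f"))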