diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 6d12ca85f..e58afd05b 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -4,7 +4,8 @@ import warnings
 from thinc.api import Model, Linear, Relu, Dropout
 from thinc.api import chain, noop, Embed, add, tuplify, concatenate
 from thinc.api import reduce_first, reduce_last, reduce_mean
-from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
+from thinc.api import PyTorchWrapper
+from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
 from ...util import registry
@@ -456,7 +457,7 @@ from typing import List, Tuple
 import torch

 # TODO rename this to coref_util
-import .coref_util_wl as utils
+from .coref_util_wl import add_dummy

 # TODO rename to plain coref
 @registry.architectures("spacy.WLCoref.v1")
@@ -478,20 +479,23 @@ def build_wl_coref_model(
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
         # TODO fix device - should be automatic
-        device = "gpu:0"
+        device = "cuda:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
                 device,
                 embedding_size,
                 hidden_size,
                 n_hidden_layers,
-                dropout_rate,
+                dropout,
                 rough_k,
                 a_scoring_batch_size
             ),
             convert_inputs=convert_coref_scorer_inputs,
             convert_outputs=convert_coref_scorer_outputs
         )
+
+        coref_model = tok2vec >> coref_scorer
+        # XXX just ignore this until the coref scorer is integrated
         span_predictor = PyTorchWrapper(
             SpanPredictor(
                 # TODO this was hardcoded to 1024, check
@@ -499,12 +503,13 @@
                 sp_embedding_size,
                 device
             ),
+            convert_inputs=convert_span_predictor_inputs
         )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
-    return coref_scorer
+    return coref_model


 def convert_coref_scorer_inputs(
     model: Model,
@@ -534,6 +539,17 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward

+def convert_span_predictor_inputs(
+    model: Model,
+    X: Tuple[Ints1d, Floats2d, Ints1d],
+    is_train: bool
+):
+    sent_id = xp2torch(X[0], requires_grad=False)
+    word_features = xp2torch(X[1], requires_grad=False)
+    head_ids = xp2torch(X[2], requires_grad=False)
+    argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
+    return argskwargs, lambda dX: []
+
 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
                           sent_ids: Ints1d,
@@ -747,7 +763,7 @@ class AnaphoricityScorer(torch.nn.Module):

         # [batch_size, n_ants]
         scores = top_rough_scores_batch + self._ffnn(pair_matrix)
-        scores = utils.add_dummy(scores, eps=True)
+        scores = add_dummy(scores, eps=True)

         return scores
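
A minimal standalone sketch (not part of the diff) of the PyTorchWrapper convert_inputs pattern the patch relies on: the hook receives the thinc arrays, wraps them as ArgsKwargs of torch tensors for the wrapped module, and returns a backprop callback. TinySpanScorer and convert_inputs here are hypothetical stand-ins for the real SpanPredictor and convert_span_predictor_inputs.

import numpy
import torch
from thinc.api import ArgsKwargs, PyTorchWrapper, xp2torch


class TinySpanScorer(torch.nn.Module):
    # Hypothetical stand-in: score each head word by summing its feature vector.
    def forward(self, sent_ids, word_features, head_ids):
        return word_features.sum(dim=-1)[head_ids]


def convert_inputs(model, X, is_train):
    # X is (sent_ids, word_features, head_ids); like the diff's converter,
    # no gradient is passed back, so the backprop callback returns [].
    args = ArgsKwargs(
        args=tuple(xp2torch(x, requires_grad=False) for x in X),
        kwargs={},
    )
    return args, lambda dX: []


wrapped = PyTorchWrapper(TinySpanScorer(), convert_inputs=convert_inputs)
X = (
    numpy.arange(5, dtype="int64"),       # sent_ids
    numpy.ones((5, 8), dtype="float32"),  # word_features
    numpy.array([1, 3], dtype="int64"),   # head_ids
)
scores = wrapped.predict(X)  # default convert_outputs maps the tensor back to an array
print(scores.shape)  # (2,)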