diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index e58afd05b..bb3c4c43c 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -4,7 +4,7 @@ import warnings
 from thinc.api import Model, Linear, Relu, Dropout
 from thinc.api import chain, noop, Embed, add, tuplify, concatenate
 from thinc.api import reduce_first, reduce_last, reduce_mean
-from thinc.api import PyTorchWrapper
+from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
@@ -455,6 +455,7 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
 
 from typing import List, Tuple
 import torch
+from thinc.util import xp2torch, torch2xp
 
 # TODO rename this to coref_util
 from .coref_util_wl import add_dummy
@@ -475,6 +476,7 @@ def build_wl_coref_model(
     # span predictor embeddings
     sp_embedding_size: int = 64,
 ):
+    dim = tok2vec.get_dim("nO")
 
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
@@ -483,6 +485,7 @@ def build_wl_coref_model(
         coref_scorer = PyTorchWrapper(
             CorefScorer(
                 device,
+                dim,
                 embedding_size,
                 hidden_size,
                 n_hidden_layers,
@@ -513,11 +516,20 @@ def build_wl_coref_model(
 
 def convert_coref_scorer_inputs(
     model: Model,
-    X: Floats2d,
+    X: List[Floats2d],
     is_train: bool
 ):
-    word_features = xp2torch(X, requires_grad=False)
-    return ArgsKwargs(args=(word_features, ), kwargs={}), lambda dX: []
+    # The input here is List[Floats2d], one for each doc
+    # just use the first
+    # TODO real batching
+    X = X[0]
+
+    word_features = xp2torch(X, requires_grad=is_train)
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[0])
+        return [gradients]
+    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
 
 
 def convert_coref_scorer_outputs(
@@ -529,7 +541,7 @@ def convert_coref_scorer_outputs(
     scores, indices = outputs
 
     def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
-        dY_t = xp2torch(dY)
+        dY_t = xp2torch(dY[0])
         return ArgsKwargs(
             args=([scores],),
             kwargs={"grad_tensors": [dY_t]},
@@ -633,6 +645,7 @@ class CorefScorer(torch.nn.Module):
     def __init__(
         self,
         device: str,
+        dim: int,  # tok2vec size
         dist_emb_size: int,
         hidden_size: int,
         n_layers: int,
@@ -650,7 +663,8 @@ class CorefScorer(torch.nn.Module):
         """
         # device, dist_emb_size, hidden_size, n_layers, dropout_rate
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
-        bert_emb = 1024
+        #TODO clean this up
+        bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
             pair_emb,
diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py
index 88997f5e3..6b4bbc8ba 100644
--- a/spacy/ml/models/coref_util.py
+++ b/spacy/ml/models/coref_util.py
@@ -193,6 +193,11 @@ def select_non_crossing_spans(
     # selected.append(selected[0]) # this seems a bit weird?
     return selected
 
+def create_head_span_idxs(ops, doclen: int):
+    """Helper function to create single-token span indices."""
+    aa = ops.xp.arange(0, doclen)
+    bb = ops.xp.arange(0, doclen) + 1
+    return ops.asarray2i([aa, bb]).T
 
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
@@ -201,7 +206,13 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
         cluster = []
         for span in val:
             # TODO check that there isn't an off-by-one error here
-            cluster.append((span.start, span.end))
+            #cluster.append((span.start, span.end))
+            # TODO This conversion should be happening earlier in processing
+            head_i = span.root.i
+            cluster.append( (head_i, head_i + 1) )
+
+        # don't want duplicates
+        cluster = list(set(cluster))
         out.append(cluster)
     return out
 
@@ -210,7 +221,11 @@ def create_gold_scores(
     ments: Ints2d, clusters: List[List[Tuple[int, int]]]
 ) -> List[List[bool]]:
     """Given mentions considered for antecedents and gold clusters,
-    construct a gold score matrix. This does not include the placeholder."""
+    construct a gold score matrix. This does not include the placeholder.
+
+    In the gold matrix, the value of a true antecedent is True, and otherwise
+    it is False. These will be converted to 1/0 values later.
+    """
     # make a mapping of mentions to cluster id
     # id is not important but equality will be
     ment2cid = {}
diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 94677e2bf..d8b534962 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -18,6 +18,7 @@ from ..vocab import Vocab
 from ..ml.models.coref_util import (
     create_gold_scores,
     MentionClusters,
+    create_head_span_idxs,
     get_clusters_from_doc,
     get_predicted_clusters,
     DEFAULT_CLUSTER_PREFIX,
@@ -26,7 +27,8 @@ from ..ml.models.coref_util import (
 
 from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
 
-default_config = """
+# TODO remove this - kept for reference for now
+old_default_config = """
 [model]
 @architectures = "spacy.Coref.v1"
 max_span_width = 20
@@ -49,6 +51,35 @@ rows = [2000, 2000, 1000, 1000, 1000, 1000]
 attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
 include_static_vectors = false
 
+[model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = ${model.tok2vec.embed.width}
+window_size = 1
+maxout_pieces = 3
+depth = 2
+"""
+
+default_config = """
+[model]
+@architectures = "spacy.WLCoref.v1"
+embedding_size = 20
+hidden_size = 1024
+n_hidden_layers = 1
+dropout = 0.3
+rough_k = 50
+a_scoring_batch_size = 512
+sp_embedding_size = 64
+
+[model.tok2vec]
+@architectures = "spacy.Tok2Vec.v2"
+
+[model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = 64
+rows = [2000, 2000, 1000, 1000, 1000, 1000]
+attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
+include_static_vectors = false
+
 [model.tok2vec.encode]
 @architectures = "spacy.MaxoutWindowEncoder.v2"
 width = ${model.tok2vec.embed.width}
@@ -210,7 +241,9 @@ class CoreferenceResolver(TrainablePipe):
         inputs = [example.predicted for example in examples]
         preds, backprop = self.model.begin_update(inputs)
         score_matrix, mention_idx = preds
+
         loss, d_scores = self.get_loss(examples, score_matrix, mention_idx)
+        # TODO check shape here
         backprop((d_scores, mention_idx))
 
         if sgd is not None:
@@ -292,15 +325,24 @@ class CoreferenceResolver(TrainablePipe):
         offset = 0
         gradients = []
         total_loss = 0
+        #TODO change this
+        # 1. do not handle batching (add it back later)
+        # 2. don't do index conversion (no mentions, just word indices)
+        # 3. convert words to spans (if necessary) in gold and predictions
+
+        # massage score matrix to be shaped correctly
+        score_matrix = [ (score_matrix, None) ]
 
         for example, (cscores, cidx) in zip(examples, score_matrix):
 
             ll = cscores.shape[0]
             hi = offset + ll
             clusters = get_clusters_from_doc(example.reference)
-            gscores = create_gold_scores(mention_idx[offset:hi], clusters)
+            span_idxs = create_head_span_idxs(ops, len(example.predicted))
+            gscores = create_gold_scores(span_idxs, clusters)
             gscores = ops.asarray2f(gscores)
-            top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
+            #top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
+            top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
             # now add the placeholder
             gold_placeholder = ~top_gscores.any(axis=1).T
             gold_placeholder = xp.expand_dims(gold_placeholder, 1)
@@ -319,6 +361,8 @@ class CoreferenceResolver(TrainablePipe):
 
             offset = hi
 
+        # Undo the wrapping
+        gradients = gradients[0][0]
         return total_loss, gradients
 
     def initialize(
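
Note on the new helper (illustrative, not part of the patch): create_head_span_idxs builds one single-token span per word, which is what lets get_clusters_from_doc and create_gold_scores above work on head indices instead of full mention spans. A minimal sketch of its output, assuming Thinc's NumpyOps as the ops argument and an arbitrary document length of 3:

# Illustrative sketch only; duplicates the helper added to coref_util.py above.
# NumpyOps and the document length of 3 are assumptions for the example.
from thinc.api import NumpyOps

def create_head_span_idxs(ops, doclen: int):
    """Helper function to create single-token span indices."""
    aa = ops.xp.arange(0, doclen)
    bb = ops.xp.arange(0, doclen) + 1
    return ops.asarray2i([aa, bb]).T

ops = NumpyOps()
print(create_head_span_idxs(ops, 3))
# [[0 1]
#  [1 2]
#  [2 3]]
# Each row is a (start, end) pair covering a single token, matching the
# (head_i, head_i + 1) tuples now produced by get_clusters_from_doc.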