diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 33c278b3d..f83014344 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -1,7 +1,7 @@ from dataclasses import dataclass import warnings -from thinc.api import Model, Linear, Relu, Dropout, chain, noop +from thinc.api import Model, Linear, Relu, Dropout, chain, noop, Embed, add from thinc.types import Floats2d, Floats1d, Ints2d, Ragged from typing import List, Callable, Tuple, Any from ...tokens import Doc @@ -27,7 +27,7 @@ def build_coref( span_embedder = build_span_embedder(get_mentions, max_span_width) - with Model.define_operators({">>": chain, "&": tuplify}): + with Model.define_operators({">>": chain, "&": tuplify, "+": add}): mention_scorer = ( Linear(nI=dim, nO=hidden) @@ -37,10 +37,14 @@ def build_coref( ) mention_scorer.initialize() + #TODO make feature_embed_size a param + feature_embed_size = 20 + width_scorer = build_width_scorer(max_span_width, hidden, feature_embed_size) + bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout) bilinear.initialize() - ms = build_take_vecs() >> mention_scorer + ms = (build_take_vecs() >> mention_scorer) + width_scorer model = ( (tok2vec & noop()) @@ -129,6 +133,38 @@ class SpanEmbeddings: return self +def build_width_scorer(max_span_width, hidden_size, feature_embed_size=20): + span_width_prior = ( + Embed(nV=max_span_width, nO=feature_embed_size) + >> Linear(nI=feature_embed_size, nO=hidden_size) + >> Relu(nI=hidden_size, nO=hidden_size) + >> Dropout() + >> Linear(nI=hidden_size, nO=1) + ) + span_width_prior.initialize() + return Model( + "WidthScorer", + forward=width_score_forward, + layers=[span_width_prior]) + + +def width_score_forward(model, embeds: SpanEmbeddings, is_train) -> Tuple[Floats1d, Callable]: + # calculate widths, subtracting 1 so it's 0-index + w_ffnn = model.layers[0] + idxs = embeds.indices + widths = idxs[:,1] - idxs[:,0] - 1 + wscores, width_b = w_ffnn(widths, is_train) + + lens = embeds.vectors.lengths + + def width_score_backward(d_score: Floats1d) -> SpanEmbeddings: + + dX = width_b(d_score) + vecs = Ragged(dX, lens) + return SpanEmbeddings(idxs, vecs) + + return wscores, width_score_backward + # model converting a Doc/Mention to span embeddings # get_mentions: Callable[Doc, Pairs[int]] def build_span_embedder(