mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 12:42:20 +03:00
Delete all the coref-hoi code
This commit is contained in:
parent
abdc7d87af
commit
d0ae2590db
|
@ -11,457 +11,13 @@ from ...tokens import Doc
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from ..extract_spans import extract_spans
|
from ..extract_spans import extract_spans
|
||||||
|
|
||||||
from .coref_util import get_candidate_mentions, select_non_crossing_spans, topk
|
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.Coref.v1")
def build_coref(
    tok2vec: Model[List[Doc], List[Floats2d]],
    get_mentions: Any = get_candidate_mentions,
    hidden: int = 1000,
    dropout: float = 0.3,
    mention_limit: int = 3900,
    # TODO this needs a better name. It limits the max mentions as a ratio of
    # the token count.
    mention_limit_ratio: float = 0.4,
    max_span_width: int = 20,
    antecedent_limit: int = 50,
    feature_embed_size: int = 20,
):
    """Assemble the span-based coreference model.

    The pipeline embeds candidate spans from the tok2vec output, scores each
    span as a mention (a feed-forward scorer over the span vector plus a
    width-based score), prunes the candidates, and finally scores antecedents.

    tok2vec (Model): Token-to-vector layer; its "nO" dim sets the span width.
    get_mentions (Callable): Passed to the span embedder to propose candidate
        spans for each Doc.
    hidden (int): Hidden width of the mention scorer and the width scorer.
    dropout (float): Rate used by the Dropout layers built here.
    mention_limit (int): Passed to the coarse pruner.
    mention_limit_ratio (float): Passed to the coarse pruner.
    max_span_width (int): Passed to the span embedder and the width scorer.
    antecedent_limit (int): Passed to the antecedent scorer.
    feature_embed_size (int): Embedding size used by the width scorer. This
        used to be hard-coded to 20, which is still the default.
    RETURNS (Model): The composed model.
    """
    # Span vectors concatenate the first, last and mean token vectors (see
    # build_span_embedder), hence three times the tok2vec width.
    dim = tok2vec.get_dim("nO") * 3

    span_embedder = build_span_embedder(get_mentions, max_span_width)

    with Model.define_operators({">>": chain, "&": tuplify, "+": add}):
        mention_scorer = (
            Linear(nI=dim, nO=hidden)
            >> Relu(nI=hidden, nO=hidden)
            >> Dropout(dropout)
            >> Linear(nI=hidden, nO=hidden)
            >> Relu(nI=hidden, nO=hidden)
            >> Dropout(dropout)
            >> Linear(nI=hidden, nO=1)
        )
        mention_scorer.initialize()

        width_scorer = build_width_scorer(max_span_width, hidden, feature_embed_size)

        bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout)
        bilinear.initialize()

        # Mention score = scorer over the span vectors + width score.
        ms = (build_take_vecs() >> mention_scorer) + width_scorer

        model = (
            (tok2vec & noop())
            >> span_embedder
            >> (ms & noop())
            >> build_coarse_pruner(mention_limit, mention_limit_ratio)
            >> build_ant_scorer(bilinear, Dropout(dropout), antecedent_limit)
        )
    return model
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SpanEmbeddings:
    """Span index pairs bundled with one embedding vector per span.

    NB: We assume that the indices refer to a concatenated Floats2d that has
    one row per token in the *batch* of documents. This makes it unambiguous
    which row is in which document, because if the lengths are e.g. [10, 5],
    a span starting at 11 must be starting at token 2 of doc 1. A bug could
    potentially cause you to have a span which crosses a doc boundary though,
    which would be bad.

    The lengths in the Ragged are not the tokens per doc, but the number of
    mentions per doc.
    """

    indices: Ints2d  # Two columns: start index and end index of each span
    vectors: Ragged  # Ragged[Floats2d], one vector per span

    def __add__(self, right):
        # Out-of-place sum: keeps the left operand's indices and lengths.
        summed = Ragged(
            self.vectors.data + right.vectors.data,
            self.vectors.lengths,
        )
        return SpanEmbeddings(self.indices, summed)

    def __iadd__(self, right):
        # In-place sum into the left operand's vector data.
        self.vectors.data += right.vectors.data
        return self
|
|
||||||
|
|
||||||
|
|
||||||
def build_width_scorer(max_span_width, hidden_size, feature_embed_size=20):
    """Build the model that scores a span from its width alone.

    max_span_width (int): Number of rows in the width embedding table.
    hidden_size (int): Hidden width of the feed-forward scorer.
    feature_embed_size (int): Size of each width embedding.
    RETURNS (Model): A "WidthScorer" model wrapping the feed-forward network,
        which is also registered under the ref "width_prior".
    """
    # Bind ">>" here so the builder does not depend on being called from
    # inside a caller's `Model.define_operators` block.
    with Model.define_operators({">>": chain}):
        span_width_prior = (
            Embed(nV=max_span_width, nO=feature_embed_size)
            >> Linear(nI=feature_embed_size, nO=hidden_size)
            >> Relu(nI=hidden_size, nO=hidden_size)
            # NOTE(review): Dropout() uses the library's default rate, not the
            # model-level dropout setting -- confirm this is intended.
            >> Dropout()
            >> Linear(nI=hidden_size, nO=1)
        )
    span_width_prior.initialize()
    model = Model("WidthScorer", forward=width_score_forward, layers=[span_width_prior])
    model.set_ref("width_prior", span_width_prior)
    return model
|
|
||||||
|
|
||||||
|
|
||||||
def width_score_forward(
    model, embeds: SpanEmbeddings, is_train
) -> Tuple[Floats1d, Callable]:
    """Score each span from its width using the "width_prior" sub-network.

    Returns the scores and a callback that wraps the sub-network's gradient
    in a SpanEmbeddings with the same indices and per-doc lengths as the input.
    """
    width_ffnn = model.get_ref("width_prior")
    span_idxs = embeds.indices
    # calculate widths, subtracting 1 so it's 0-index
    span_widths = span_idxs[:, 1] - span_idxs[:, 0] - 1
    scores, backprop_ffnn = width_ffnn(span_widths, is_train)

    mention_counts = embeds.vectors.lengths

    def width_score_backward(d_score: Floats1d) -> SpanEmbeddings:
        d_widths = backprop_ffnn(d_score)
        return SpanEmbeddings(span_idxs, Ragged(d_widths, mention_counts))

    return scores, width_score_backward
|
|
||||||
|
|
||||||
|
|
||||||
# model converting a Doc/Mention to span embeddings
|
|
||||||
# get_mentions: Callable[Doc, Pairs[int]]
|
|
||||||
def build_span_embedder(
    get_mentions: Callable,
    max_span_width: int = 20,
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
    """Build the model that turns token vectors and Docs into span embeddings.

    get_mentions (Callable): Stored as an attr; called per Doc in the forward
        pass to propose candidate spans.
    max_span_width (int): Stored as an attr and handed to get_mentions.
    RETURNS (Model): A "SpanEmbedding" model whose span reducer is registered
        under the ref "span_reducer".
    """
    with Model.define_operators({">>": chain, "|": concatenate}):
        extractor = extract_spans()
        # Each span is represented by its first, last and mean token vectors.
        pooling = reduce_first() | reduce_last() | reduce_mean()
        span_reduce = extractor >> pooling

    embedder_attrs = {
        "get_mentions": get_mentions,
        # XXX might be better to make this an implicit parameter in the
        # mention generator
        "max_span_width": max_span_width,
    }
    embedder = Model(
        "SpanEmbedding",
        forward=span_embeddings_forward,
        attrs=embedder_attrs,
        layers=[span_reduce],
    )
    embedder.set_ref("span_reducer", span_reduce)
    return embedder
|
|
||||||
|
|
||||||
|
|
||||||
def span_embeddings_forward(
    model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
) -> Tuple[SpanEmbeddings, Callable]:
    """Embed the candidate mentions of every Doc in the batch.

    model (Model): Provides the "get_mentions" and "max_span_width" attrs and
        the "span_reducer" ref.
    inputs (Tuple): Per-doc token vectors and the matching Docs.
    is_train (bool): Forwarded to the span reducer.
    RETURNS (Tuple): The SpanEmbeddings (mention index pairs plus one vector
        per mention, grouped per doc) and the backprop callback.
    """
    ops = model.ops
    xp = ops.xp

    tokvecs, docs = inputs

    get_mentions = model.attrs["get_mentions"]
    max_span_width = model.attrs["max_span_width"]

    # Collect the per-doc index arrays and concatenate once at the end rather
    # than re-copying the accumulated array for every doc. The empty (0, 2)
    # seed keeps the result well-formed when no doc yields any mention.
    mention_chunks = [ops.alloc2i(0, 2)]
    docmenlens = []  # number of mentions per doc
    for doc in docs:
        starts, ends = get_mentions(doc, max_span_width)
        docmenlens.append(len(starts))
        mention_chunks.append(ops.asarray2i([starts, ends]).transpose())
    mentions = xp.concatenate(mention_chunks)

    # TODO support attention here
    tokvecs = xp.concatenate(tokvecs)
    doclens = [len(doc) for doc in docs]
    tokvecs_r = Ragged(tokvecs, doclens)
    mentions_r = Ragged(mentions, docmenlens)

    span_reduce = model.get_ref("span_reducer")
    spanvecs, span_reduce_back = span_reduce((tokvecs_r, mentions_r), is_train)

    embeds = Ragged(spanvecs, docmenlens)

    def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
        # Only the token-vector gradient is used; the second item is ignored.
        grad, _ = span_reduce_back(dY.vectors.data)

        # Split the batch-level gradient back into one slice per doc.
        oweights = []
        offset = 0
        for doclen in doclens:
            hi = offset + doclen
            oweights.append(grad.data[offset:hi])
            offset = hi

        return oweights, docs

    return SpanEmbeddings(mentions, embeds), backprop_span_embed
|
|
||||||
|
|
||||||
|
|
||||||
def build_coarse_pruner(
    mention_limit: int,
    mention_limit_ratio: float,
) -> Model[SpanEmbeddings, SpanEmbeddings]:
    """Build the "CoarsePruner" model.

    Both limits are stored as attrs and read back by `coarse_prune` in the
    forward pass.
    """
    pruner_attrs = {
        "mention_limit": mention_limit,
        "mention_limit_ratio": mention_limit_ratio,
    }
    return Model("CoarsePruner", forward=coarse_prune, attrs=pruner_attrs)
|
|
||||||
|
|
||||||
|
|
||||||
def coarse_prune(
    model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]:
    """Given scores for mention, output the top non-crossing mentions.

    Mentions can contain other mentions, but candidate mentions cannot cross each other.

    model (Model): Provides the "mention_limit" and "mention_limit_ratio" attrs.
    inputs (Tuple): Raw mention scores and the matching SpanEmbeddings.
    is_train (bool): Unused here.
    RETURNS (Tuple): The scores and SpanEmbeddings of the selected mentions,
        and the backprop callback that scatters gradients back to the full
        candidate set.
    """
    rawscores, spanembeds = inputs
    scores = rawscores.flatten()
    mention_limit = model.attrs["mention_limit"]
    mention_limit_ratio = model.attrs["mention_limit_ratio"]
    # XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
    # In old code the limit can be:
    # - hard number per doc
    # - ratio of tokens in the doc

    offset = 0
    selected = []
    sellens = []
    for menlen in spanembeds.vectors.lengths:
        if menlen == 0:
            # A doc without candidate mentions has nothing to select, and
            # `ends[-1]` / `starts[0]` below would raise IndexError.
            sellens.append(0)
            continue

        hi = offset + menlen
        cscores = scores[offset:hi]

        # negate it so highest numbers come first
        # This is relatively slow but can't be skipped.
        tops = (model.ops.xp.argsort(-1 * cscores)).tolist()
        starts = spanembeds.indices[offset:hi, 0].tolist()
        ends = spanembeds.indices[offset:hi, 1].tolist()

        # estimate the doc length from the first start and the last end
        doclen = ends[-1] - starts[0]
        # XXX seems to make more sense to use menlen than doclen here?
        # coref-hoi uses doclen (number of words).
        mlimit = min(mention_limit, int(mention_limit_ratio * doclen))
        # csel is a 1d integer list
        csel = select_non_crossing_spans(tops, starts, ends, mlimit)
        # add the offset so these indices are absolute
        csel = [ii + offset for ii in csel]
        sellens.append(len(csel))
        selected.extend(csel)
        offset = hi

    selected = model.ops.asarray1i(selected)
    top_spans = spanembeds.indices[selected]
    top_vecs = spanembeds.vectors.data[selected]

    out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))

    # save some variables so the embeds can be garbage collected
    idxlen = spanembeds.indices.shape[0]
    vecshape = spanembeds.vectors.data.shape
    indices = spanembeds.indices
    # The gradient covers every input mention, so it must carry the input's
    # per-doc lengths rather than the pruned ones.
    veclens = spanembeds.vectors.lengths

    def coarse_prune_backprop(
        dY: Tuple[Floats1d, SpanEmbeddings]
    ) -> Tuple[Floats1d, SpanEmbeddings]:
        dYscores, dYembeds = dY

        # Unselected mentions get a zero gradient.
        dXscores = model.ops.alloc1f(idxlen)
        dXscores[selected] = dYscores.flatten()

        dXvecs = model.ops.alloc2f(*vecshape)
        dXvecs[selected] = dYembeds.vectors.data
        rout = Ragged(dXvecs, veclens)
        dXembeds = SpanEmbeddings(indices, rout)

        # inflate for mention scorer
        dXscores = model.ops.xp.expand_dims(dXscores, 1)

        return (dXscores, dXembeds)

    return (scores[selected], out), coarse_prune_backprop
|
|
||||||
|
|
||||||
|
|
||||||
def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
    """Build the "TakeVecs" model, which just gets vectors out of SpanEmbeddings."""
    # XXX Might be better to convert SpanEmbeddings to a tuple and use with_getitem
    take_vecs = Model("TakeVecs", forward=take_vecs_forward)
    return take_vecs
|
|
||||||
|
|
||||||
|
|
||||||
def take_vecs_forward(
    model, inputs: SpanEmbeddings, is_train
) -> Tuple[Floats2d, Callable]:
    """Return the raw vector data of a SpanEmbeddings plus its backprop.

    model (Model): Unused here.
    inputs (SpanEmbeddings): The span embeddings to unwrap.
    is_train (bool): Unused here.
    RETURNS (Tuple): `inputs.vectors.data` and a callback that re-wraps a
        gradient in a SpanEmbeddings using the input's indices and lengths.
    """
    idxs = inputs.indices
    lens = inputs.vectors.lengths

    def backprop(dY: Floats2d) -> SpanEmbeddings:
        vecs = Ragged(dY, lens)
        return SpanEmbeddings(idxs, vecs)

    return inputs.vectors.data, backprop
|
|
||||||
|
|
||||||
|
|
||||||
def build_ant_scorer(
    bilinear, dropout, ant_limit=50
) -> Model[Tuple[Floats1d, SpanEmbeddings], List[Floats2d]]:
    """Build the "AntScorer" model.

    The two sub-layers are registered both as layers and as the refs
    "bilinear" and "dropout"; `ant_limit` is stored as an attr.
    """
    scorer = Model(
        "AntScorer",
        forward=ant_scorer_forward,
        layers=[bilinear, dropout],
        attrs={"ant_limit": ant_limit},
    )
    for ref_name, layer in (("bilinear", bilinear), ("dropout", dropout)):
        scorer.set_ref(ref_name, layer)
    return scorer
|
|
||||||
|
|
||||||
|
|
||||||
def ant_scorer_forward(
    model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]:
    """Score candidate antecedents for every mention, one document at a time.

    For each doc the pairwise product scores and pairwise mention-score sums
    are added, masked so that only earlier mentions can be antecedents, and
    cut down to the top `ant_limit` entries per mention. A zero placeholder
    column is prepended to the kept scores.

    RETURNS (Tuple): A list with one (scores, indices) pair per doc together
        with the mention indices, and the backprop callback.
    """
    ops = model.ops
    xp = ops.xp

    ant_limit = model.attrs["ant_limit"]
    # this contains the coarse bilinear in coref-hoi
    # coarse bilinear is a single layer linear network
    # TODO make these proper refs
    bilinear = model.get_ref("bilinear")
    dropout = model.get_ref("dropout")

    mscores, sembeds = inputs
    vecs = sembeds.vectors  # ragged

    out = []
    backprops = []
    lo = 0
    for n_mentions in vecs.lengths:
        # each iteration is one doc
        hi = lo + n_mentions

        # first calculate the pairwise product scores
        doc_vecs = vecs.data[lo:hi]
        pw_prod, prod_back = pairwise_product(bilinear, dropout, doc_vecs, is_train)

        # now calculate the pairwise mention scores
        doc_mscores = mscores[lo:hi].flatten()
        pw_sum, pw_sum_back = pairwise_sum(ops, doc_mscores)

        # make a mask so antecedents precede referrents
        positions = xp.arange(0, doc_vecs.shape[0])
        distance = xp.expand_dims(positions, 1) - xp.expand_dims(positions, 0)

        # This will take the log of 0, which causes a warning, but we're doing
        # it on purpose so we can just ignore the warning.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            mask = xp.log(distance >= 1).astype("f")

        scores = pw_prod + pw_sum + mask

        top_limit = min(ant_limit, len(scores))
        top_scores, top_scores_idx = topk(xp, scores, top_limit)
        # now add the placeholder
        placeholder = ops.alloc2f(scores.shape[0], 1)
        top_scores = xp.concatenate((placeholder, top_scores), 1)

        out.append((top_scores, top_scores_idx))

        # In the full model these scores can be further refined. In the current
        # state of this model we're done here, so this pruning is less important,
        # but it's still helpful for reducing memory usage (since scores can be
        # garbage collected when the loop exits).

        lo = hi
        backprops.append((prod_back, pw_sum_back))

    # save vars for gc
    vecshape = vecs.data.shape
    veclens = vecs.lengths
    scoreshape = mscores.shape
    idxes = sembeds.indices

    def backprop(
        dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
    ) -> Tuple[Floats2d, SpanEmbeddings]:
        dYscores, dYembeds = dYs
        dXembeds = Ragged(ops.alloc2f(*vecshape), veclens)
        dXscores = ops.alloc1f(*scoreshape)

        start = 0
        for dy, (prod_back, pw_sum_back), n_mentions in zip(
            dYscores, backprops, veclens
        ):
            stop = start + n_mentions
            dyscore, dyidx = dy
            # remove the placeholder
            dyscore = dyscore[:, 1:]

            # the full score grid is square; scatter the kept gradients into it
            fullscore = ops.alloc2f(n_mentions, n_mentions)
            for row, (cols, row_grads) in enumerate(zip(dyidx, dyscore)):
                fullscore[row][cols] = row_grads

            dXembeds.data[start:stop] = prod_back(fullscore)
            dXscores[start:stop] = pw_sum_back(fullscore)

            start = stop
        # make it fit back into the linear
        dXscores = xp.expand_dims(dXscores, 1)
        return (dXscores, SpanEmbeddings(idxes, dXembeds))

    return (out, sembeds.indices), backprop
|
|
||||||
|
|
||||||
|
|
||||||
def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
    """Build the grid of pairwise sums of mention scores.

    Entry (i, j) of the result is mention_scores[i] + mention_scores[j].
    Returns the grid and its backprop callback.
    """
    # This doesn't use multiplication because two items with low mention scores
    # don't make a good candidate pair.
    xp = ops.xp
    as_column = xp.expand_dims(mention_scores, 1)
    as_row = xp.expand_dims(mention_scores, 0)
    pw_sum = as_column + as_row

    def backward(d_pwsum: Floats2d) -> Floats1d:
        # Each score contributes to one whole row and one whole column of the
        # grid, so its gradient is the sum over both.
        from_columns = d_pwsum.sum(axis=0)
        from_rows = d_pwsum.sum(axis=1)
        return from_columns + from_rows

    return pw_sum, backward
|
|
||||||
|
|
||||||
|
|
||||||
def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
    """Compute the pairwise product scores between all rows of `vecs`.

    The rows are passed through `bilinear` and the transposed rows through
    `dropout`; the result is their matrix product. Returns the product and a
    backprop callback producing the gradient with respect to `vecs`.
    """
    # A neat side effect of this is that we don't have to pass the backprops
    # around separately because the closure handles them.
    source, backprop_source = bilinear(vecs, is_train)
    target, backprop_target = dropout(vecs.T, is_train)

    def backward(d_prod: Floats2d) -> Floats2d:
        d_source = backprop_source(d_prod @ target.T)
        d_target = backprop_target(source.T @ d_prod)
        # `target` came from the transposed input, so transpose its gradient
        # back before summing.
        return d_source + d_target.T

    return source @ target, backward
|
|
||||||
|
|
||||||
|
|
||||||
# XXX here down is wl-coref
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from thinc.util import xp2torch, torch2xp
|
from thinc.util import xp2torch, torch2xp
|
||||||
|
|
||||||
# TODO rename this to coref_util
|
|
||||||
from .coref_util import add_dummy
|
from .coref_util import add_dummy
|
||||||
|
|
||||||
# TODO rename to plain coref
|
# TODO rename to plain coref
|
||||||
@registry.architectures("spacy.WLCoref.v1")
|
@registry.architectures("spacy.Coref.v1")
|
||||||
def build_wl_coref_model(
|
def build_wl_coref_model(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
embedding_size: int = 20,
|
embedding_size: int = 20,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user