diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 68ce51bbb..ba522b1f2 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -1,18 +1,402 @@
-from typing import List
-from thinc.api import Model
-from thinc.types import Floats2d
+from dataclasses import dataclass
 
-from ...util import registry
+from thinc.api import Model, Linear, Relu, Dropout, chain, noop
+from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
+from typing import List, Callable, Tuple
 
 from ...tokens import Doc
+from ...util import registry
+
+from .coref_util import (
+    get_predicted_clusters,
+    get_candidate_mentions,
+    select_non_crossing_spans,
+    make_clean_doc,
+    create_gold_scores,
+    logsumexp,
+    topk,
+)
 
 
 @registry.architectures("spacy.Coref.v0")
-def build_coref_model(
-    tok2vec: Model[List[Doc], List[Floats2d]]
-) -> Model:
-    """Build a coref resolution model, using a provided token-to-vector component.
-    TODO.
+def build_coref(
+    tok2vec: Model[List[Doc], List[Floats2d]],
+    get_mentions: Callable = get_candidate_mentions,
+    hidden: int = 1000,
+    dropout: float = 0.3,
+    mention_limit: int = 3900,
+    max_span_width: int = 20,
+):
+    dim = tok2vec.get_dim("nO") * 3
 
-    tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
+    span_embedder = build_span_embedder(get_mentions, max_span_width)
+
+    with Model.define_operators({">>": chain, "&": tuplify}):
+
+        mention_scorer = (
+            Linear(nI=dim, nO=hidden)
+            >> Relu(nI=hidden, nO=hidden)
+            >> Dropout(dropout)
+            >> Linear(nI=hidden, nO=1)
+        )
+        mention_scorer.initialize()
+
+        bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout)
+        bilinear.initialize()
+
+        ms = build_take_vecs() >> mention_scorer
+
+        model = (
+            (tok2vec & noop())
+            >> span_embedder
+            >> (ms & noop())
+            >> build_coarse_pruner(mention_limit)
+            >> build_ant_scorer(bilinear, Dropout(dropout))
+        )
+    return model
+
+
+# TODO replace this with thinc version once PR is in
+def tuplify(layer1: Model, layer2: Model, *layers) -> Model:
+    layers = (layer1, layer2) + layers
+    names = [layer.name for layer in layers]
+    return Model(
+        "tuple(" + ", ".join(names) + ")",
+        tuplify_forward,
+        layers=layers,
+    )
+
+
+def tuplify_forward(model, X, is_train):
+    Ys = []
+    backprops = []
+    for layer in model.layers:
+        Y, backprop = layer(X, is_train)
+        Ys.append(Y)
+        backprops.append(backprop)
+
+    def backprop_tuplify(dYs):
+        dXs = [bp(dY) for bp, dY in zip(backprops, dYs)]
+        dX = dXs[0]
+        for dx in dXs[1:]:
+            dX += dx
+        return dX
+
+    return tuple(Ys), backprop_tuplify
+
+
+@dataclass
+class SpanEmbeddings:
+    indices: Ints2d  # Array with 2 columns (for start and end index)
+    vectors: Ragged  # Ragged[Floats2d] # One vector per span
+    # NB: We assume that the indices refer to a concatenated Floats2d that
+    # has one row per token in the *batch* of documents. This makes it
+    # unambiguous which row is in which document, because if the doc lengths
+    # are e.g. [10, 5], a span starting at index 11 must start at token 1 of
+    # the second doc. A bug could still produce a span that crosses a doc
+    # boundary, which would be bad.
+    # The lengths in the Ragged are not the tokens per doc, but the number of
+    # mentions per doc.
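+    # As a toy illustration (all numbers hypothetical): for a batch of two
+    # docs with [10, 5] tokens and [3, 2] candidate mentions, vectors.lengths
+    # would be [3, 2], vectors.data would have 3 + 2 = 5 rows, and indices
+    # might look like:
+    #     [[0, 2], [4, 6], [7, 9],   # doc 0: token rows 0..9
+    #      [11, 13], [12, 15]]       # doc 1: rows 10..14, offset by 10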
+
+    def __add__(self, right):
+        out = self.vectors.data + right.vectors.data
+        return SpanEmbeddings(self.indices, Ragged(out, self.vectors.lengths))
+
+    def __iadd__(self, right):
+        self.vectors.data += right.vectors.data
+        return self
+
+
+# model converting a Doc/Mention to span embeddings
+# get_mentions: Callable[[Doc, int], Tuple[List[int], List[int]]]
+def build_span_embedder(
+    get_mentions: Callable,
+    max_span_width: int = 20,
+) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
+
+    return Model(
+        "SpanEmbedding",
+        forward=span_embeddings_forward,
+        attrs={
+            "get_mentions": get_mentions,
+            # XXX might be better to make this an implicit parameter in the
+            # mention generator
+            "max_span_width": max_span_width,
+        },
+    )
+
+
+def span_embeddings_forward(
+    model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
+) -> SpanEmbeddings:
+    ops = model.ops
+    xp = ops.xp
+
+    tokvecs, docs = inputs
+
+    dim = tokvecs[0].shape[1]
+
+    get_mentions = model.attrs["get_mentions"]
+    max_span_width = model.attrs["max_span_width"]
+    mentions = ops.alloc2i(0, 2)
+    total_length = 0
+    docmenlens = []  # number of mentions per doc
+    for doc in docs:
+        starts, ends = get_mentions(doc, max_span_width)
+        docmenlens.append(len(starts))
+        cments = ops.asarray2i([starts, ends]).transpose()
+
+        mentions = xp.concatenate((mentions, cments + total_length))
+        total_length += len(doc)
+
+    # TODO support attention here
+    tokvecs = xp.concatenate(tokvecs)
+    spans = [tokvecs[ii:jj] for ii, jj in mentions.tolist()]
+    avgs = [xp.mean(ss, axis=0) for ss in spans]
+    spanvecs = ops.asarray2f(avgs)
+
+    # first and last token embeds
+    starts = [tokvecs[ii] for ii in mentions[:, 0]]
+    ends = [tokvecs[jj] for jj in mentions[:, 1]]
+
+    starts = ops.asarray2f(starts)
+    ends = ops.asarray2f(ends)
+    concat = xp.concatenate((starts, ends, spanvecs), 1)
+    embeds = Ragged(concat, docmenlens)
+
+    def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
+
+        oweights = []
+        odocs = []
+        offset = 0
+        tokoffset = 0
+        for indoc, mlen in zip(docs, dY.vectors.lengths):
+            hi = offset + mlen
+            hitok = tokoffset + len(indoc)
+            odocs.append(indoc)  # no change
+            vecs = dY.vectors.data[offset:hi]
+
+            starts = vecs[:, :dim]
+            ends = vecs[:, dim : 2 * dim]
+            spanvecs = vecs[:, 2 * dim :]
+
+            out = model.ops.alloc2f(len(indoc), dim)
+
+            for ii, (start, end) in enumerate(dY.indices[offset:hi]):
+                # adjust indexes to align with doc
+                start -= tokoffset
+                end -= tokoffset
+
+                out[start] += starts[ii]
+                out[end] += ends[ii]
+                # the span vector is a mean, so its gradient is distributed
+                # evenly over the tokens in the span
+                out[start:end] += spanvecs[ii] / (end - start)
+            oweights.append(out)
+
+            offset = hi
+            tokoffset = hitok
+        return oweights, odocs
+
+    return SpanEmbeddings(mentions, embeds), backprop_span_embed
+
+
+def build_coarse_pruner(
+    mention_limit: int,
+) -> Model[Tuple[Floats1d, SpanEmbeddings], Tuple[Floats1d, SpanEmbeddings]]:
+    model = Model(
+        "CoarsePruner",
+        forward=coarse_prune,
+        attrs={
+            "mention_limit": mention_limit,
+        },
+    )
+    return model
+
+
+def coarse_prune(
+    model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
+) -> Tuple[Floats1d, SpanEmbeddings]:
+    """Given scores for mentions, output the top non-crossing mentions.
+
+    Mentions can contain other mentions, but candidate mentions cannot cross each other.
     """
-    return tok2vec
+    rawscores, spanembeds = inputs
+    scores = rawscores.squeeze()
+    mention_limit = model.attrs["mention_limit"]
+    # XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
+    # In old code the limit can be:
+    # - hard number per doc
+    # - ratio of tokens in the doc
+
+    offset = 0
+    selected = []
+    sellens = []
+    for menlen in spanembeds.vectors.lengths:
+        hi = offset + menlen
+        cscores = scores[offset:hi]
+
+        # negate it so highest numbers come first
+        tops = (model.ops.xp.argsort(-1 * cscores)).tolist()
+        starts = spanembeds.indices[offset:hi, 0].tolist()
+        ends = spanembeds.indices[offset:hi, 1].tolist()
+
+        # csel is a 1d integer list
+        csel = select_non_crossing_spans(tops, starts, ends, mention_limit)
+        # add the offset so these indices are absolute
+        csel = [ii + offset for ii in csel]
+        # this should be constant because short choices are padded
+        sellens.append(len(csel))
+        selected += csel
+        offset += menlen
+
+    selected = model.ops.asarray1i(selected)
+    top_spans = spanembeds.indices[selected]
+    top_vecs = spanembeds.vectors.data[selected]
+
+    out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))
+
+    def coarse_prune_backprop(
+        dY: Tuple[Floats1d, SpanEmbeddings]
+    ) -> Tuple[Floats1d, SpanEmbeddings]:
+        ll = spanembeds.indices.shape[0]
+
+        dYscores, dYembeds = dY
+
+        dXscores = model.ops.alloc1f(ll)
+        dXscores[selected] = dYscores.squeeze()
+
+        dXvecs = model.ops.alloc2f(*spanembeds.vectors.data.shape)
+        dXvecs[selected] = dYembeds.vectors.data
+        # the gradient covers all candidates again, so use the pre-pruning lengths
+        rout = Ragged(dXvecs, spanembeds.vectors.lengths)
+        dXembeds = SpanEmbeddings(spanembeds.indices, rout)
+
+        # inflate for mention scorer
+        dXscores = model.ops.xp.expand_dims(dXscores, 1)
+
+        return (dXscores, dXembeds)
+
+    return (scores[selected], out), coarse_prune_backprop
+
+
+def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
+    # this just gets vectors out of spanembeddings
+    # XXX Might be better to convert SpanEmbeddings to a tuple and use with_getitem
+    return Model("TakeVecs", forward=take_vecs_forward)
+
+
+def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
+    def backprop(dY: Floats2d) -> SpanEmbeddings:
+        vecs = Ragged(dY, inputs.vectors.lengths)
+        return SpanEmbeddings(inputs.indices, vecs)
+
+    return inputs.vectors.data, backprop
+
+
+def build_ant_scorer(
+    bilinear, dropout, ant_limit=50
+) -> Model[
+    Tuple[Floats1d, SpanEmbeddings], Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
+]:
+    return Model(
+        "AntScorer",
+        forward=ant_scorer_forward,
+        layers=[bilinear, dropout],
+        attrs={
+            "ant_limit": ant_limit,
+        },
+    )
+
+
+def ant_scorer_forward(
+    model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
+) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]:
+    ops = model.ops
+    xp = ops.xp
+
+    ant_limit = model.attrs["ant_limit"]
+    # this contains the coarse bilinear in coref-hoi
+    # coarse bilinear is a single layer linear network
+    # TODO make these proper refs
+    bilinear = model.layers[0]
+    dropout = model.layers[1]
+
+    # XXX Note on dimensions: This won't work as a Ragged because the Floats2ds
+    # are not all the same dimensions. Each Floats2d is square, with size equal
+    # to the number of antecedents in the document. Actually, they would all
+    # have the same size if antecedents were padded... Needs checking.
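+    #
+    # Sketch of the per-doc shapes (hypothetical doc with 4 mentions and
+    # ant_limit 2): cvecs is (4, dim); pw_prod and pw_sum are (4, 4); the mask
+    # is 0 strictly below the diagonal and -inf elsewhere (the log of a
+    # boolean); topk then returns (4, 2) scores plus a (4, 2) index map back
+    # into the full (4, 4) grid.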
+
+    mscores, sembeds = inputs
+    vecs = sembeds.vectors  # ragged
+
+    offset = 0
+    backprops = []
+    out = []
+    for ll in vecs.lengths:
+        hi = offset + ll
+        # each iteration is one doc
+
+        # first calculate the pairwise product scores
+        cvecs = vecs.data[offset:hi]
+        source, source_b = bilinear(cvecs, is_train)
+        target, target_b = dropout(cvecs, is_train)
+        pw_prod = xp.matmul(source, target.T)
+
+        # now calculate the pairwise mention scores
+        ms = mscores[offset:hi].squeeze()
+        pw_sum = xp.expand_dims(ms, 1) + xp.expand_dims(ms, 0)
+
+        # make a mask so antecedents precede referents
+        ant_range = xp.arange(0, cvecs.shape[0])
+        # (the log of 0 is -inf; with numpy this emits a warning, which could
+        # be silenced with an errstate context)
+        mask = xp.log(
+            (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
+        ).astype(float)
+
+        scores = pw_prod + pw_sum + mask
+
+        top_scores, top_scores_idx = topk(xp, scores, ant_limit)
+        out.append((top_scores, top_scores_idx))
+
+        # In the full model these scores can be further refined. In the current
+        # state of this model we're done here, so this pruning is less important,
+        # but it's still helpful for reducing memory usage (since scores can be
+        # garbage collected when the loop exits).
+
+        offset += ll
+        backprops.append((source_b, target_b, source, target))
+
+    def backprop(
+        dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
+    ) -> Tuple[Floats2d, SpanEmbeddings]:
+        dYscores, dYembeds = dYs
+        dXembeds = Ragged(ops.alloc2f(*vecs.data.shape), vecs.lengths)
+        dXscores = ops.alloc1f(*mscores.shape)
+
+        offset = 0
+        for dy, (source_b, target_b, source, target), ll in zip(
+            dYscores, backprops, vecs.lengths
+        ):
+            dyscore, dyidx = dy
+            # the full score grid is square
+
+            fullscore = ops.alloc2f(ll, ll)
+            # cupy has no put_along_axis
+            # xp.put_along_axis(fullscore, dyidx, dyscore, 1)
+            for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
+                fullscore[ii][ridx] = rscores
+
+            # for scores = source @ target.T, the gradients are
+            # d_source = d_scores @ target and d_target = d_scores.T @ source
+            dS = source_b(fullscore @ target)
+            dT = target_b(fullscore.T @ source)
+            dXembeds.data[offset : offset + ll] = dS + dT
+
+            # The gradient can be distributed over all the rows and columns here,
+            # so aggregate it
+            section = dXscores[offset : offset + ll]
+            for ii in range(ll):
+                section[ii] = fullscore[:, ii].sum() + fullscore[ii, :].sum()
+            offset += ll
+        # make it fit back into the linear
+        dXscores = xp.expand_dims(dXscores, 1)
+        return (dXscores, SpanEmbeddings(sembeds.indices, dXembeds))
+
+    return (out, sembeds.indices), backprop
diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py
new file mode 100644
index 000000000..7c44692c3
--- /dev/null
+++ b/spacy/ml/models/coref_util.py
@@ -0,0 +1,252 @@
+from thinc.types import Ints2d
+from ...tokens import Doc
+from typing import List, Tuple
+
+# type alias to make writing this less tedious
+MentionClusters = List[List[Tuple[int, int]]]
+
+DEFAULT_CLUSTER_PREFIX = "coref_clusters"
+
+
+def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
+    """Given a doc, return the mention clusters.
+
+    This is useful for scoring.
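+
+    A sketch of the expected round trip (the span group name here is
+    hypothetical):
+
+        doc.spans["coref_clusters_1"] = [doc[0:2], doc[5:6]]
+        doc2clusters(doc)  # [[(0, 2), (5, 6)]]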
+ """ + out = [] + for name, val in doc.spans.items(): + if not name.startswith(prefix): + continue + + cluster = [] + for mention in val: + cluster.append((mention.start, mention.end)) + out.append(cluster) + return out + + +def topk(xp, arr, k, axis=None): + """Given and array and a k value, give the top values and idxs for each row.""" + + part = xp.argpartition(arr, -k, axis=1) + idxs = xp.flip(part)[:, :k] + + vals = xp.take_along_axis(arr, idxs, axis=1) + + sidxs = xp.argsort(vals, axis=1) + # map these idxs back to the original + oidxs = xp.take_along_axis(idxs, sidxs, axis=1) + svals = xp.take_along_axis(vals, sidxs, axis=1) + return svals, oidxs + + +def logsumexp(xp, arr, axis=None): + """Emulate torch.logsumexp by returning the log of summed exponentials + along each row in the given dimension. + + Reduces a 2d array to 1d.""" + # from slide 5 here: + # https://www.slideshare.net/ryokuta/cupy + hi = arr.max(axis=axis) + hi = xp.expand_dims(hi, 1) + return hi.squeeze() + xp.log(xp.exp(arr - hi).sum(axis=axis)) + + +# from model.py, refactored to be non-member +def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores): + """Get the ID of the antecedent for each span. -1 if no antecedent.""" + predicted_antecedents = [] + for i, idx in enumerate(xp.argmax(antecedent_scores, axis=1) - 1): + if idx < 0: + predicted_antecedents.append(-1) + else: + predicted_antecedents.append(antecedent_idx[i][idx]) + return predicted_antecedents + + +# from model.py, refactored to be non-member +def get_predicted_clusters( + xp, span_starts, span_ends, antecedent_idx, antecedent_scores +): + """Convert predictions to usable cluster data. + + return values: + + clusters: a list of spans (i, j) that are a cluster + + Note that not all spans will be in the final output; spans with no + antecedent or referrent are omitted from clusters and mention2cluster. + """ + # Get predicted antecedents + predicted_antecedents = get_predicted_antecedents( + xp, antecedent_idx, antecedent_scores + ) + + # Get predicted clusters + mention_to_cluster_id = {} + predicted_clusters = [] + for i, predicted_idx in enumerate(predicted_antecedents): + if predicted_idx < 0: + continue + assert i > predicted_idx, f"span idx: {i}; antecedent idx: {predicted_idx}" + # Check antecedent's cluster + antecedent = (int(span_starts[predicted_idx]), int(span_ends[predicted_idx])) + antecedent_cluster_id = mention_to_cluster_id.get(antecedent, -1) + if antecedent_cluster_id == -1: + antecedent_cluster_id = len(predicted_clusters) + predicted_clusters.append([antecedent]) + mention_to_cluster_id[antecedent] = antecedent_cluster_id + # Add mention to cluster + mention = (int(span_starts[i]), int(span_ends[i])) + predicted_clusters[antecedent_cluster_id].append(mention) + mention_to_cluster_id[mention] = antecedent_cluster_id + + predicted_clusters = [tuple(c) for c in predicted_clusters] + return predicted_clusters + + +def get_sentence_map(doc: Doc): + """For the given span, return a list of sentence indexes.""" + + si = 0 + out = [] + for sent in doc.sents: + for tok in sent: + out.append(si) + si += 1 + return out + + +def get_candidate_mentions( + doc: Doc, max_span_width: int = 20 +) -> Tuple[List[int], List[int]]: + """Given a Doc, return candidate mentions. + + This isn't a trainable layer, it just returns raw candidates. + """ + # XXX Note that in coref-hoi the indexes are designed so you actually want [i:j+1], but here + # we're using [i:j], which is more natural. 
+
+    sentence_map = get_sentence_map(doc)
+
+    begins = []
+    ends = []
+    for tok in doc:
+        si = sentence_map[tok.i]  # sentence index
+        for ii in range(1, max_span_width):
+            ei = tok.i + ii  # end index
+            if ei < len(doc) and sentence_map[ei] == si:
+                begins.append(tok.i)
+                ends.append(ei)
+
+    return (begins, ends)
+
+
+def select_non_crossing_spans(
+    idxs: List[int], starts: List[int], ends: List[int], limit: int
+) -> List[int]:
+    """Given a list of span indexes sorted by score (descending), return the
+    indexes of the spans to keep, discarding spans that cross.
+
+    Nested spans are allowed.
+    """
+    # ported from Model._extract_top_spans
+    selected = []
+    start_to_max_end = {}
+    end_to_min_start = {}
+
+    for idx in idxs:
+        if len(selected) >= limit or idx >= len(starts):
+            break
+
+        start, end = starts[idx], ends[idx]
+        cross = False
+
+        for ti in range(start, end + 1):
+            max_end = start_to_max_end.get(ti, -1)
+            if ti > start and max_end > end:
+                cross = True
+                break
+
+            min_start = end_to_min_start.get(ti, -1)
+            if ti < end and 0 <= min_start < start:
+                cross = True
+                break
+
+        if not cross:
+            # this index will be kept
+            # record it so we can exclude anything that crosses it
+            selected.append(idx)
+            max_end = start_to_max_end.get(start, -1)
+            if end > max_end:
+                start_to_max_end[start] = end
+            min_start = end_to_min_start.get(end, -1)
+            if min_start == -1 or start < min_start:
+                end_to_min_start[end] = start
+
+    # sort idxs by order in doc
+    selected = sorted(selected, key=lambda idx: (starts[idx], ends[idx]))
+    while len(selected) < limit:
+        # pad to a constant length; this seems a bit weird?
+        selected.append(selected[0])
+    return selected
+
+
+def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
+    """Given a Doc, convert the cluster spans to simple int tuple lists."""
+    out = []
+    for key, val in doc.spans.items():
+        cluster = []
+        for span in val:
+            # TODO check that there isn't an off-by-one error here
+            cluster.append((span.start, span.end))
+        out.append(cluster)
+    return out
+
+
+def make_clean_doc(nlp, doc):
+    """Return a doc with raw data but not span annotations."""
+    # Surely there is a better way to do this?
+
+    sents = [tok.is_sent_start for tok in doc]
+    words = [tok.text for tok in doc]
+    out = Doc(nlp.vocab, words=words, sent_starts=sents)
+    return out
+
+
+def create_gold_scores(
+    ments: Ints2d, clusters: List[List[Tuple[int, int]]]
+) -> List[List[bool]]:
+    """Given mentions considered for antecedents and gold clusters,
+    construct a gold score matrix.
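+    Cell (i, j) is True only when mention j precedes mention i and the two
+    share a gold cluster.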
+    This does not include the placeholder."""
+    # make a mapping of mentions to cluster id
+    # id is not important but equality will be
+    ment2cid = {}
+    for cid, cluster in enumerate(clusters):
+        for ment in cluster:
+            ment2cid[ment] = cid
+
+    ll = len(ments)
+    out = []
+    # The .tolist() call is necessary with cupy but not numpy
+    mentuples = [tuple(mm.tolist()) for mm in ments]
+    for ii, ment in enumerate(mentuples):
+        if ment not in ment2cid:
+            # this is not in a cluster so it has no antecedent
+            out.append([False] * ll)
+            continue
+
+        # this might change if no real antecedent is a candidate
+        row = []
+        cid = ment2cid[ment]
+        for jj, ante in enumerate(mentuples):
+            # antecedents must come first
+            if jj >= ii:
+                row.append(False)
+                continue
+
+            row.append(cid == ment2cid.get(ante, -1))
+
+        out.append(row)
+
+    # caller needs to convert to array, and add placeholder
+    return out
diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 9ccc2c89f..d0fecf519 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -1,6 +1,7 @@
-from typing import Iterable, Tuple, Optional, Dict, Callable, Any
+from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 
-from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
+from thinc.types import Floats2d, Ints2d
+from thinc.api import Model, Config, Optimizer, set_dropout_rate, CategoricalCrossentropy
 
 from itertools import islice
 from .trainable_pipe import TrainablePipe
@@ -12,10 +13,25 @@ from ..scorer import Scorer
 from ..tokens import Doc
 from ..vocab import Vocab
 
+from ..ml.models.coref_util import (
+    create_gold_scores,
+    MentionClusters,
+    get_clusters_from_doc,
+    logsumexp,
+    get_predicted_clusters,
+    DEFAULT_CLUSTER_PREFIX,
+    doc2clusters,
+)
+
 
 default_config = """
 [model]
 @architectures = "spacy.Coref.v0"
+max_span_width = 20
+mention_limit = 3900
+dropout = 0.3
+hidden = 1000
+@get_mentions = "spacy.CorefCandidateGenerator.v0"
 
 [model.tok2vec]
 @architectures = "spacy.Tok2Vec.v2"
@@ -41,12 +57,11 @@ DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
 
 @Language.factory(
     "coref",
-    assigns=[f"doc.spans"],
+    assigns=["doc.spans"],
     requires=["doc.spans"],
     default_config={
         "model": DEFAULT_MODEL,
-        "span_mentions": DEFAULT_MENTIONS,
-        "span_cluster_prefix": DEFAULT_CLUSTERS_PREFIX,
+        "span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
     },
     default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
 )
@@ -54,21 +69,11 @@ def make_coref(
     nlp: Language,
     name: str,
     model,
-    span_mentions: str,
-    span_cluster_prefix: str,
+    span_cluster_prefix: str = "coref",
 ) -> "CoreferenceResolver":
-    """Create a CoreferenceResolver component. TODO
+    """Create a CoreferenceResolver component."""
 
-    model (Model[List[Doc], List[Floats2d]]): A model instance that predicts ...
-    threshold (float): Cutoff to consider a prediction "positive".
-    """
-    return CoreferenceResolver(
-        nlp.vocab,
-        model,
-        name,
-        span_mentions=span_mentions,
-        span_cluster_prefix=span_cluster_prefix,
-    )
+    return CoreferenceResolver(
+        nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix
+    )
 
 
 class CoreferenceResolver(TrainablePipe):
@@ -105,9 +110,11 @@ class CoreferenceResolver(TrainablePipe):
         self.span_mentions = span_mentions
         self.span_cluster_prefix = span_cluster_prefix
         self._rehearsal_model = None
+        self.loss = CategoricalCrossentropy()
+        self.cfg = {}
 
-    def predict(self, docs: Iterable[Doc]):
+    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
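+        Returns one MentionClusters per doc: a list of clusters, where each
+        cluster is a list of (start, end) token offsets into that doc.
+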
-        TODO: write actual algorithm
@@ -116,12 +123,27 @@ class CoreferenceResolver(TrainablePipe):
 
         DOCS: https://spacy.io/api/coref#predict (TODO)
         """
+        scores, idxs = self.model.predict(docs)
+        # idxs is a list of mentions (start / end idxs)
+        # each item in scores includes scores and a mapping from scores to mentions
+
+        xp = self.model.ops.xp
+
         clusters_by_doc = []
-        for i, doc in enumerate(docs):
-            clusters = []
-            for span in doc.spans[self.span_mentions]:
-                clusters.append([span])
-            clusters_by_doc.append(clusters)
+        offset = 0
+        tokoffset = 0
+        for doc, (cscores, ant_idxs) in zip(docs, scores):
+            ll = cscores.shape[0]
+            hi = offset + ll
+
+            # mention indices are into the concatenated batch, so shift them
+            # to be relative to the current doc
+            starts = idxs[offset:hi, 0] - tokoffset
+            ends = idxs[offset:hi, 1] - tokoffset
+
+            # need to add the placeholder
+            placeholder = self.model.ops.alloc2f(cscores.shape[0], 1)
+            cscores = xp.concatenate((placeholder, cscores), 1)
+
+            predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, cscores)
+            clusters_by_doc.append(predicted)
+            offset = hi
+            tokoffset += len(doc)
         return clusters_by_doc
 
     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
@@ -133,18 +155,24 @@
 
         DOCS: https://spacy.io/api/coref#set_annotations (TODO)
         """
         if len(docs) != len(clusters_by_doc):
-            raise ValueError("Found coref clusters incompatible with the "
-                "documents provided to the 'coref' component. "
-                "This is likely a bug in spaCy.")
+            raise ValueError(
+                "Found coref clusters incompatible with the "
+                "documents provided to the 'coref' component. "
+                "This is likely a bug in spaCy."
+            )
         for doc, clusters in zip(docs, clusters_by_doc):
-            index = 0
-            for cluster in clusters:
-                key = self.span_cluster_prefix + str(index)
+            for ii, cluster in enumerate(clusters):
+                key = self.span_cluster_prefix + "_" + str(ii)
                 if key in doc.spans:
-                    raise ValueError(f"Couldn't store the results of {self.name}, as the key "
-                        f"{key} already exists in 'doc.spans'.")
-                doc.spans[key] = cluster
-                index += 1
+                    raise ValueError(
+                        f"Couldn't store the results of {self.name}, as the key "
+                        f"{key} already exists in 'doc.spans'."
+                    )
+
+                doc.spans[key] = []
+                for mention in cluster:
+                    doc.spans[key].append(doc[mention[0] : mention[1]])
 
     def update(
         self,
@@ -174,13 +202,16 @@
             # Handle cases where there are no tokens in any docs.
             return losses
         set_dropout_rate(self.model, drop)
-        scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
-        # TODO below
-        # loss, d_scores = self.get_loss(examples, scores)
-        # bp_scores(d_scores)
+
+        inputs = [example.predicted for example in examples]
+        preds, backprop = self.model.begin_update(inputs)
+        score_matrix, mention_idx = preds
+        loss, d_scores = self.get_loss(examples, score_matrix, mention_idx)
+        backprop(d_scores)
+
         if sgd is not None:
             self.finish_update(sgd)
-        # losses[self.name] += loss
+        losses[self.name] += loss
        return losses
 
     def rehearse(
@@ -236,7 +267,12 @@
             )
         )
 
-    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
+    def get_loss(
+        self,
+        examples: Iterable[Example],
+        score_matrix: List[Tuple[Floats2d, Ints2d]],
+        mention_idx: Ints2d,
+    ):
         """Find the loss and gradient of loss for the batch of documents and
         their predicted scores.
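+        score_matrix holds one (scores, antecedent index map) pair per doc,
+        as produced by the antecedent scorer; mention_idx holds the mention
+        boundaries for all docs, concatenated.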
@@ -246,9 +282,46 @@
 
         DOCS: https://spacy.io/api/coref#get_loss (TODO)
         """
-        validate_examples(examples, "CoreferenceResolver.get_loss")
-        # TODO
-        return None
+        ops = self.model.ops
+        xp = ops.xp
+
+        offset = 0
+        gradients = []
+        loss = 0
+        for example, (cscores, cidx) in zip(examples, score_matrix):
+            # assume cidx has absolute mention ids
+
+            ll = cscores.shape[0]
+            hi = offset + ll
+
+            clusters = get_clusters_from_doc(example.reference)
+            gscores = create_gold_scores(mention_idx[offset:hi], clusters)
+            gscores = xp.asarray(gscores)
+            top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
+            # now add the placeholder
+            gold_placeholder = ~top_gscores.any(axis=1).T
+            gold_placeholder = xp.expand_dims(gold_placeholder, 1)
+            top_gscores = xp.concatenate((gold_placeholder, top_gscores), 1)
+
+            # boolean to float
+            top_gscores = ops.asarray2f(top_gscores)
+
+            # add the placeholder to cscores
+            placeholder = self.model.ops.alloc2f(ll, 1)
+            cscores = xp.concatenate((placeholder, cscores), 1)
+
+            # do softmax to cscores
+            cscores = ops.softmax(cscores, axis=1)
+
+            diff = self.loss.get_grad(cscores, top_gscores)
+            diff = diff[:, 1:]
+            gradients.append((diff, cidx))
+
+            # scalar loss
+            loss += self.loss.get_loss(cscores, top_gscores)
+            offset += ll
+        return loss, gradients
 
     def initialize(
         self,
@@ -279,10 +352,39 @@
 
         DOCS: https://spacy.io/api/coref#score (TODO)
         """
+
        def clusters_getter(doc, span_key):
-            return [spans for name, spans in doc.spans.items() if name.startswith(span_key)]
+            return [
+                spans for name, spans in doc.spans.items() if name.startswith(span_key)
+            ]
+
         validate_examples(examples, "CoreferenceResolver.score")
         kwargs.setdefault("getter", clusters_getter)
         kwargs.setdefault("attr", self.span_cluster_prefix)
         kwargs.setdefault("include_label", False)
         return Scorer.score_clusters(examples, **kwargs)
+
+
+# from ..coref_scorer import Evaluator, get_cluster_info, b_cubed
+# TODO consider whether to use this
+# def score(self, examples, **kwargs):
+#     """Score a batch of examples."""
+#
+#     # TODO traditionally coref uses the average of b_cubed, muc, and ceaf.
+#     # We need to handle the average ourselves.
+#     evaluator = Evaluator(b_cubed)
+#
+#     for ex in examples:
+#         p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
+#         g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
+#
+#         cluster_info = get_cluster_info(p_clusters, g_clusters)
+#
+#         evaluator.update(cluster_info)
+#
+#     scores = {
+#         "coref_f": evaluator.get_f1(),
+#         "coref_p": evaluator.get_precision(),
+#         "coref_r": evaluator.get_recall(),
+#     }
+#     return scores