diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index f65409a80..4b1483e3c 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -129,22 +129,24 @@ class CoreferenceResolver(TrainablePipe):
 
         DOCS: https://spacy.io/api/coref#predict (TODO)
         """
-        scores, idxs = self.model.predict(docs)
-        # idxs is a list of mentions (start / end idxs)
-        # each item in scores includes scores and a mapping from scores to mentions
-        ant_idxs = idxs
+        #print("DOCS", docs)
+        out = []
+        for doc in docs:
+            scores, idxs = self.model.predict([doc])
+            # idxs is a list of mentions (start / end idxs)
+            # each item in scores includes scores and a mapping from scores to mentions
+            ant_idxs = idxs
 
-        # TODO batching
-        xp = self.model.ops.xp
+            # TODO batching
+            xp = self.model.ops.xp
 
-        starts = xp.arange(0, len(docs[0]))
-        ends = xp.arange(0, len(docs[0])) + 1
+            starts = xp.arange(0, len(doc))
+            ends = xp.arange(0, len(doc)) + 1
 
-        predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
+            predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
+            out.append(predicted)
 
-        clusters_by_doc = [predicted]
-
-        return clusters_by_doc
+        return out
 
     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
         """Modify a batch of Doc objects, using pre-computed scores.
@@ -203,17 +205,21 @@ class CoreferenceResolver(TrainablePipe):
             return losses
         set_dropout_rate(self.model, drop)
 
-        inputs = [example.predicted for example in examples]
-        preds, backprop = self.model.begin_update(inputs)
-        score_matrix, mention_idx = preds
+        total_loss = 0
 
-        loss, d_scores = self.get_loss(examples, score_matrix, mention_idx)
-        # TODO check shape here
-        backprop((d_scores, mention_idx))
+        for eg in examples:
+            # TODO does this even work?
+            preds, backprop = self.model.begin_update([eg.predicted])
+            score_matrix, mention_idx = preds
+
+            loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
+            total_loss += loss
+            # TODO check shape here
+            backprop((d_scores, mention_idx))
 
         if sgd is not None:
             self.finish_update(sgd)
-        losses[self.name] += loss
+        losses[self.name] += total_loss
         return losses
 
     def rehearse(
@@ -288,48 +294,35 @@ class CoreferenceResolver(TrainablePipe):
         ops = self.model.ops
         xp = ops.xp
 
-        offset = 0
-        gradients = []
-        total_loss = 0
-        # TODO change this
-        # 1. do not handle batching (add it back later)
-        # 2. don't do index conversion (no mentions, just word indices)
-        # 3. convert words to spans (if necessary) in gold and predictions
+        # TODO if there is more than one example, give an error
+        # (or actually rework this to take multiple things)
+        example = examples[0]
+        cscores = score_matrix
+        cidx = mention_idx
 
-        # massage score matrix to be shaped correctly
-        score_matrix = [(score_matrix, None)]
-        for example, (cscores, cidx) in zip(examples, score_matrix):
+        clusters = get_clusters_from_doc(example.reference)
+        span_idxs = create_head_span_idxs(ops, len(example.predicted))
+        gscores = create_gold_scores(span_idxs, clusters)
+        gscores = ops.asarray2f(gscores)
+        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
+        top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
+        # now add the placeholder
+        gold_placeholder = ~top_gscores.any(axis=1).T
+        gold_placeholder = xp.expand_dims(gold_placeholder, 1)
+        top_gscores = xp.concatenate((gold_placeholder, top_gscores), 1)
 
-            ll = cscores.shape[0]
-            hi = offset + ll
+        # boolean to float
+        top_gscores = ops.asarray2f(top_gscores)
 
-            clusters = get_clusters_from_doc(example.reference)
-            span_idxs = create_head_span_idxs(ops, len(example.predicted))
-            gscores = create_gold_scores(span_idxs, clusters)
-            gscores = ops.asarray2f(gscores)
-            # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
-            top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
-            # now add the placeholder
-            gold_placeholder = ~top_gscores.any(axis=1).T
-            gold_placeholder = xp.expand_dims(gold_placeholder, 1)
-            top_gscores = xp.concatenate((gold_placeholder, top_gscores), 1)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=RuntimeWarning)
+            log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
+            log_norm = ops.softmax(cscores, axis=1)
+            grad = log_norm - log_marg
+            #gradients.append((grad, cidx))
+            loss = float((grad**2).sum())
 
-            # boolean to float
-            top_gscores = ops.asarray2f(top_gscores)
-
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=RuntimeWarning)
-                log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
-                log_norm = ops.softmax(cscores, axis=1)
-                grad = log_norm - log_marg
-                gradients.append((grad, cidx))
-                total_loss += float((grad**2).sum())
-
-            offset = hi
-
-        # Undo the wrapping
-        gradients = gradients[0][0]
-        return total_loss, gradients
+        return loss, grad
 
     def initialize(
         self,
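
Note on the reworked get_loss: the gradient it computes is the usual marginal log-likelihood gradient for coreference, i.e. the difference between the softmax over all candidate antecedent scores and the softmax renormalized over the gold antecedents only. Masking happens via log(0) == -inf on non-gold entries, which is why the patch suppresses the RuntimeWarning, and the patch reports the squared gradient sum as the loss value. A minimal numpy sketch of just that computation, with toy values (scores, gold) that are illustrative and not part of the patch or the spaCy API:

    import numpy as np

    def softmax(x, axis=1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    # One mention row; column 0 plays the role of the dummy antecedent.
    scores = np.array([[0.1, 2.0, -1.0, 0.5]], dtype="f")  # model scores
    gold = np.array([[0.0, 1.0, 0.0, 1.0]], dtype="f")     # 1.0 = gold antecedent

    with np.errstate(divide="ignore"):
        # log(0) == -inf zeroes out non-gold candidates inside the softmax
        marg = softmax(scores + np.log(gold), axis=1)
    norm = softmax(scores, axis=1)   # softmax over all candidates
    grad = norm - marg               # gradient w.r.t. the raw scores
    loss = float((grad ** 2).sum())  # squared sum, as reported by the patch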