diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index bb3c4c43c..b3664408e 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -700,6 +700,7 @@ class CorefScorer(torch.nn.Module):
         """
         # words           [n_words, span_emb]
         # cluster_ids     [n_words]
+        self.lstm.flatten_parameters()  # XXX without this there's a warning
         word_features = torch.unsqueeze(word_features, dim=0)
         words, _ = self.lstm(word_features)
         words = words.squeeze()
diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index d8b534962..6833a95b4 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -155,7 +155,6 @@ class CoreferenceResolver(TrainablePipe):
 
     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
-        TODO: write actual algorithm
 
         docs (Iterable[Doc]): The documents to predict.
         RETURNS: The models prediction for each document.
@@ -165,20 +164,18 @@ class CoreferenceResolver(TrainablePipe):
         scores, idxs = self.model.predict(docs)
         # idxs is a list of mentions (start / end idxs)
         # each item in scores includes scores and a mapping from scores to mentions
+        ant_idxs = idxs
 
+        #TODO batching
         xp = self.model.ops.xp
 
-        clusters_by_doc = []
-        offset = 0
-        for cscores, ant_idxs in scores:
-            ll = cscores.shape[0]
-            hi = offset + ll
+        starts = xp.arange(0, len(docs[0]))
+        ends = xp.arange(0, len(docs[0])) + 1
 
-            starts = idxs[offset:hi, 0]
-            ends = idxs[offset:hi, 1]
+        predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
+
+        clusters_by_doc = [predicted]
 
-            predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, cscores)
-            clusters_by_doc.append(predicted)
         return clusters_by_doc
 
     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: