Training runs now

Evaluation needs fixing, and code still needs cleanup.
This commit is contained in:
Paul O'Leary McCann 2022-03-14 19:02:17 +09:00
parent d22a002641
commit 8eadf3781b
2 changed files with 8 additions and 10 deletions

View File

@ -700,6 +700,7 @@ class CorefScorer(torch.nn.Module):
""" """
# words [n_words, span_emb] # words [n_words, span_emb]
# cluster_ids [n_words] # cluster_ids [n_words]
self.lstm.flatten_parameters() # XXX without this there's a warning
word_features = torch.unsqueeze(word_features, dim=0) word_features = torch.unsqueeze(word_features, dim=0)
words, _ = self.lstm(word_features) words, _ = self.lstm(word_features)
words = words.squeeze() words = words.squeeze()

View File

@ -155,7 +155,6 @@ class CoreferenceResolver(TrainablePipe):
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.
TODO: write actual algorithm
docs (Iterable[Doc]): The documents to predict. docs (Iterable[Doc]): The documents to predict.
RETURNS: The models prediction for each document. RETURNS: The models prediction for each document.
@ -165,20 +164,18 @@ class CoreferenceResolver(TrainablePipe):
scores, idxs = self.model.predict(docs) scores, idxs = self.model.predict(docs)
# idxs is a list of mentions (start / end idxs) # idxs is a list of mentions (start / end idxs)
# each item in scores includes scores and a mapping from scores to mentions # each item in scores includes scores and a mapping from scores to mentions
ant_idxs = idxs
#TODO batching
xp = self.model.ops.xp xp = self.model.ops.xp
clusters_by_doc = [] starts = xp.arange(0, len(docs[0]))
offset = 0 ends = xp.arange(0, len(docs[0])) + 1
for cscores, ant_idxs in scores:
ll = cscores.shape[0]
hi = offset + ll
starts = idxs[offset:hi, 0] predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
ends = idxs[offset:hi, 1]
clusters_by_doc = [predicted]
predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, cscores)
clusters_by_doc.append(predicted)
return clusters_by_doc return clusters_by_doc
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: