mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Training runs now
Evaluation needs fixing, and code still needs cleanup.
This commit is contained in:
parent
d22a002641
commit
8eadf3781b
|
@ -700,6 +700,7 @@ class CorefScorer(torch.nn.Module):
|
|||
"""
|
||||
# words [n_words, span_emb]
|
||||
# cluster_ids [n_words]
|
||||
self.lstm.flatten_parameters() # XXX without this there's a warning
|
||||
word_features = torch.unsqueeze(word_features, dim=0)
|
||||
words, _ = self.lstm(word_features)
|
||||
words = words.squeeze()
|
||||
|
|
|
@ -155,7 +155,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
TODO: write actual algorithm
|
||||
|
||||
docs (Iterable[Doc]): The documents to predict.
|
||||
RETURNS: The models prediction for each document.
|
||||
|
@ -165,20 +164,18 @@ class CoreferenceResolver(TrainablePipe):
|
|||
scores, idxs = self.model.predict(docs)
|
||||
# idxs is a list of mentions (start / end idxs)
|
||||
# each item in scores includes scores and a mapping from scores to mentions
|
||||
ant_idxs = idxs
|
||||
|
||||
#TODO batching
|
||||
xp = self.model.ops.xp
|
||||
|
||||
clusters_by_doc = []
|
||||
offset = 0
|
||||
for cscores, ant_idxs in scores:
|
||||
ll = cscores.shape[0]
|
||||
hi = offset + ll
|
||||
starts = xp.arange(0, len(docs[0]))
|
||||
ends = xp.arange(0, len(docs[0])) + 1
|
||||
|
||||
starts = idxs[offset:hi, 0]
|
||||
ends = idxs[offset:hi, 1]
|
||||
predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
|
||||
|
||||
clusters_by_doc = [predicted]
|
||||
|
||||
predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, cscores)
|
||||
clusters_by_doc.append(predicted)
|
||||
return clusters_by_doc
|
||||
|
||||
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user