mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Training runs now
Evaluation needs fixing, and code still needs cleanup.
This commit is contained in:
parent
d22a002641
commit
8eadf3781b
|
@ -700,6 +700,7 @@ class CorefScorer(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
# words [n_words, span_emb]
|
# words [n_words, span_emb]
|
||||||
# cluster_ids [n_words]
|
# cluster_ids [n_words]
|
||||||
|
self.lstm.flatten_parameters() # XXX without this there's a warning
|
||||||
word_features = torch.unsqueeze(word_features, dim=0)
|
word_features = torch.unsqueeze(word_features, dim=0)
|
||||||
words, _ = self.lstm(word_features)
|
words, _ = self.lstm(word_features)
|
||||||
words = words.squeeze()
|
words = words.squeeze()
|
||||||
|
|
|
@ -155,7 +155,6 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
TODO: write actual algorithm
|
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to predict.
|
docs (Iterable[Doc]): The documents to predict.
|
||||||
RETURNS: The models prediction for each document.
|
RETURNS: The models prediction for each document.
|
||||||
|
@ -165,20 +164,18 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
scores, idxs = self.model.predict(docs)
|
scores, idxs = self.model.predict(docs)
|
||||||
# idxs is a list of mentions (start / end idxs)
|
# idxs is a list of mentions (start / end idxs)
|
||||||
# each item in scores includes scores and a mapping from scores to mentions
|
# each item in scores includes scores and a mapping from scores to mentions
|
||||||
|
ant_idxs = idxs
|
||||||
|
|
||||||
|
#TODO batching
|
||||||
xp = self.model.ops.xp
|
xp = self.model.ops.xp
|
||||||
|
|
||||||
clusters_by_doc = []
|
starts = xp.arange(0, len(docs[0]))
|
||||||
offset = 0
|
ends = xp.arange(0, len(docs[0])) + 1
|
||||||
for cscores, ant_idxs in scores:
|
|
||||||
ll = cscores.shape[0]
|
|
||||||
hi = offset + ll
|
|
||||||
|
|
||||||
starts = idxs[offset:hi, 0]
|
predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
|
||||||
ends = idxs[offset:hi, 1]
|
|
||||||
|
clusters_by_doc = [predicted]
|
||||||
|
|
||||||
predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, cscores)
|
|
||||||
clusters_by_doc.append(predicted)
|
|
||||||
return clusters_by_doc
|
return clusters_by_doc
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
|
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user