Add fake batching

Fake batching means the pipeline component still accepts a batch of docs, but
internally calls the model once per doc in a loop rather than passing the whole
batch through at once. It feels like this should break something, but it worked
in testing.
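
In pseudocode the pattern is roughly the following (a minimal sketch, not the
exact component code; postprocess_one is a hypothetical stand-in for the
per-doc cluster extraction):

    def predict(self, docs):
        out = []
        for doc in docs:
            # the model still receives a list, just a list of length one
            scores, idxs = self.model.predict([doc])
            # hypothetical helper: turn one doc's scores into clusters
            out.append(postprocess_one(doc, scores, idxs))
        return out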

This also changes the signature of some of the pipeline functions, though I
don't think that's a problem.

Tested with a batch size of 2 only, so more testing is needed, but this is a
start.
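
A quick way to exercise the batched path (the pipeline path and example texts
below are assumptions, not part of this change; it assumes a trained pipeline
with this coref component loaded):

    import spacy

    nlp = spacy.load("path/to/coref_pipeline")  # hypothetical trained pipeline
    texts = ["John called Sarah. She answered.", "The dog chased its tail."]
    # batch of 2, so predict() loops over two docs internally
    for doc in nlp.pipe(texts, batch_size=2):
        print(doc.spans)  # clusters are written into doc.spans by set_annotations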
Paul O'Leary McCann 2022-03-18 19:46:58 +09:00
parent 1a79d18796
commit a098849112


@@ -129,7 +129,10 @@ class CoreferenceResolver(TrainablePipe):
         DOCS: https://spacy.io/api/coref#predict (TODO)
         """
-        scores, idxs = self.model.predict(docs)
-        # idxs is a list of mentions (start / end idxs)
-        # each item in scores includes scores and a mapping from scores to mentions
-        ant_idxs = idxs
+        #print("DOCS", docs)
+        out = []
+        for doc in docs:
+            scores, idxs = self.model.predict([doc])
+            # idxs is a list of mentions (start / end idxs)
+            # each item in scores includes scores and a mapping from scores to mentions
+            ant_idxs = idxs
@@ -137,14 +140,13 @@ class CoreferenceResolver(TrainablePipe):
-        # TODO batching
-        xp = self.model.ops.xp
-        starts = xp.arange(0, len(docs[0]))
-        ends = xp.arange(0, len(docs[0])) + 1
-
-        predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
-
-        clusters_by_doc = [predicted]
-        return clusters_by_doc
+            # TODO batching
+            xp = self.model.ops.xp
+            starts = xp.arange(0, len(doc))
+            ends = xp.arange(0, len(doc)) + 1
+
+            predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, scores)
+            out.append(predicted)
+        return out
 
     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
         """Modify a batch of Doc objects, using pre-computed scores.
@@ -203,17 +205,21 @@ class CoreferenceResolver(TrainablePipe):
             return losses
         set_dropout_rate(self.model, drop)
 
-        inputs = [example.predicted for example in examples]
-        preds, backprop = self.model.begin_update(inputs)
-        score_matrix, mention_idx = preds
-        loss, d_scores = self.get_loss(examples, score_matrix, mention_idx)
-        # TODO check shape here
-        backprop((d_scores, mention_idx))
+        total_loss = 0
+        for eg in examples:
+            # TODO does this even work?
+            preds, backprop = self.model.begin_update([eg.predicted])
+            score_matrix, mention_idx = preds
+            loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
+            total_loss += loss
+            # TODO check shape here
+            backprop((d_scores, mention_idx))
 
         if sgd is not None:
             self.finish_update(sgd)
-        losses[self.name] += loss
+        losses[self.name] += total_loss
         return losses
 
     def rehearse(
@@ -288,20 +294,11 @@ class CoreferenceResolver(TrainablePipe):
         ops = self.model.ops
         xp = ops.xp
 
-        offset = 0
-        gradients = []
-        total_loss = 0
-        # TODO change this
-        # 1. do not handle batching (add it back later)
-        # 2. don't do index conversion (no mentions, just word indices)
-        # 3. convert words to spans (if necessary) in gold and predictions
-
-        # massage score matrix to be shaped correctly
-        score_matrix = [(score_matrix, None)]
-        for example, (cscores, cidx) in zip(examples, score_matrix):
-            ll = cscores.shape[0]
-            hi = offset + ll
-
-            clusters = get_clusters_from_doc(example.reference)
-            span_idxs = create_head_span_idxs(ops, len(example.predicted))
+        # TODO if there is more than one example, give an error
+        # (or actually rework this to take multiple things)
+        example = examples[0]
+        cscores = score_matrix
+        cidx = mention_idx
+
+        clusters = get_clusters_from_doc(example.reference)
+        span_idxs = create_head_span_idxs(ops, len(example.predicted))
@@ -322,14 +319,10 @@ class CoreferenceResolver(TrainablePipe):
-            log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
-            log_norm = ops.softmax(cscores, axis=1)
-            grad = log_norm - log_marg
-            gradients.append((grad, cidx))
-            total_loss += float((grad**2).sum())
-
-            offset = hi
-
-        # Undo the wrapping
-        gradients = gradients[0][0]
-        return total_loss, gradients
+        log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
+        log_norm = ops.softmax(cscores, axis=1)
+        grad = log_norm - log_marg
+        #gradients.append((grad, cidx))
+        loss = float((grad**2).sum())
+
+        return loss, grad
 
     def initialize(
         self,
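
For reference, the grad computed in get_loss is consistent with a marginal
log-likelihood objective; this is a reading of the code above, not something
stated in the commit. Despite the names, log_norm and log_marg are softmax
outputs. With s a row of antecedent scores, p = softmax(s), and m the 0/1 mask
of gold antecedents (top_gscores):

    \mathcal{L} = -\log \sum_j m_j\, p_j,
    \qquad
    \frac{\partial \mathcal{L}}{\partial s_i}
      = p_i - \frac{m_i\, p_i}{\sum_j m_j\, p_j}
      = \operatorname{softmax}(s)_i - \operatorname{softmax}(s + \log m)_i,

which matches grad = log_norm - log_marg. The scalar returned as loss is the
squared sum of this gradient, used for reporting in losses[self.name] rather
than as the quantity being differentiated.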