diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index 3a981308b..4223f32ed 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -1,16 +1,19 @@
-from .pipes import Pipe
-from thinc.v2v import Maxout, Affine
-from thinc.t2t import SoftAttention
+import numpy
+from thinc.v2v import Maxout, Affine, Model
 from thinc.t2v import Pooling, sum_pool
-from thinc.api import zero_init
-from .._ml import logistic
+from thinc.api import layerize, chain
+from thinc.misc import LayerNorm, Residual
+from thinc.neural.util import get_array_module
+from .pipes import Pipe
+from .._ml import logistic, zero_init, Tok2Vec
+from ..tokens import Span
 
 
 class SpanCategorizer(Pipe):
     """Predict labels for spans of text."""
 
     @classmethod
-    def Model(cls, **cfg):
+    def Model(cls, nO, **cfg):
         # TODO: Settings here
         tok2vec = Tok2Vec(**cfg)
         with Model.define_operators({">>": chain}):
@@ -19,17 +22,22 @@ class SpanCategorizer(Pipe):
                 #>> SoftAttention
                 >> Pooling(sum_pool)
                 >> LayerNorm(Residual(Maxout(tok2vec.nO)))
-                >> zero_init(Affine(tok2vec.nO))
+                >> zero_init(Affine(nO, tok2vec.nO))
                 >> logistic
             )
-        return create_span_model(self.get_spans, tok2vec, span2scores)
+        return SpanCategorizerModel(tok2vec, span2scores)
 
-    def __init__(self, user_data_key="phrases", get_spans=None, model=True):
-        Pipe.__init__(self)
+    def __init__(self, vocab, user_data_key="phrases", max_length=10,
+                 get_spans=None, model=True):
+        self.cfg = {
+            "user_data_key": user_data_key,
+            "labels": [],
+            "max_length": max_length
+        }
+        self.vocab = vocab
         self.user_data_key = user_data_key
         self.span_getter = get_spans
         self.model = model
-        self.max_length = 10
 
     @property
     def tok2vec(self):
@@ -116,13 +124,13 @@ class SpanCategorizer(Pipe):
         get_grads.b1 = sgd.b1
         get_grads.b2 = sgd.b2
 
-        token_label_matrix = _get_token_label_matrix(
-            [g.phrases for g in golds], [len(doc) for doc in docs], self.labels
-        )
+        golds = [getattr(g, "spans", g) for g in golds]
+
+        token_label_matrix = _get_token_label_matrix(golds, [len(d) for d in docs], self.labels)
 
         for indices, starts, length in _batch_spans_by_length(spans):
             X = _get_span_batch(tokvecs, starts, length)
-            Y, backprop = self.spans2scores.begin_update(X, drop=drop)
+            Y, backprop = self.model.span2scores.begin_update(X, drop=drop)
             dY = self.get_loss((indices, Y), token_label_matrix)
             dX = backprop(dY, sgd=get_grads)
             for i, start in enumerate(starts):
@@ -135,6 +143,82 @@ class SpanCategorizer(Pipe):
         return losses
 
 
+class SpanCategorizerModel(Model):
+    """Predict labels for spans, using two component models: a tok2vec model,
+    and a span2scores model. The input to the SpanCategorizerModel should be
+    a tuple (docs, spans), where the spans are an array with two columns:
+    (start, end).
+
+    The output will be an array of scores with one row per span. In the
+    backward pass, we take the gradients w.r.t. the scores, and backprop
+    through the input vectors.
+
+    A naive implementation of this process would make a single array, padded
+    for all spans. However, the longest span may be very long, and we might
+    have lots of spans, so this array could consume an enormous amount of
+    memory. Instead, we sort the spans by length and work in batches. This
+    reduces the total amount of padding, and means we do not need to hold
+    expanded arrays for the whole data. As a bonus, the input model also
+    doesn't need to worry about masking: we know that the
+    data it works over has no empty items.
+    """
+    name = "spancat_model"
+
+    def __init__(self, tok2vec, span2scores):
+        """Store the two component models and register them as layers."""
+        Model.__init__(self)
+        self.tok2vec = tok2vec
+        self.span2scores = span2scores
+        self._layers.append(tok2vec)
+        self._layers.append(span2scores)
+
+    @property
+    def nO(self):
+        """Output width, as reported by the span2scores model."""
+        return self.span2scores.nO
+
+    def predict(self, docs_spans):
+        """Predict scores for the spans, batching by length."""
+        docs, spans = docs_spans
+        tokvecs = self.tok2vec(docs)
+        scores = self.ops.xp.zeros((len(spans), self.nO), dtype="f")
+        for indices, starts, length in _batch_spans_by_length(spans):
+            X = _get_span_batch(tokvecs, starts, length)
+            batchY = self.span2scores(X)
+            for i, output_idx in enumerate(indices):
+                scores[output_idx] = batchY[i]
+        return scores
+
+    def begin_update(self, docs_spans, drop=0.):
+        """Do the forward pass of the span classification, batching the input.
+
+        Returns a tuple (scores, callback), where the callback takes an array
+        d_scores and performs the backward pass."""
+        docs, spans = docs_spans
+        tokvecs, bp_tokvecs = self.tok2vec.begin_update(docs, drop=drop)
+        scores = self.ops.xp.zeros((len(spans), self.nO), dtype="f")
+        backprops = []
+        for indices, starts, length in _batch_spans_by_length(spans):
+            X = _get_span_batch(tokvecs, starts, length)
+            batchY, backprop = self.span2scores.begin_update(X, drop=drop)
+            for i, output_idx in enumerate(indices):
+                scores[output_idx] = batchY[i]
+            backprops.append((indices, starts, length, backprop))
+        shape = tokvecs.shape
+
+        def backprop_spancat_model(d_scores, sgd=None):
+            d_tokvecs = self.ops.xp.zeros(shape, dtype=d_scores.dtype)
+            for indices, starts, length, backprop in backprops:
+                dY = d_scores[indices]
+                dX = backprop(dY, sgd=sgd)
+                for i, start in enumerate(starts):
+                    # Spans can overlap, so gradients are accumulated per token.
+                    d_tokvecs[start : start + length] += dX[i, :length]
+            return bp_tokvecs(d_tokvecs, sgd=sgd)
+
+        return scores, backprop_spancat_model
+
+
 @layerize
 def reshape_add_lengths(X, drop=0.):
     xp = get_array_module(X)
@@ -149,52 +233,9 @@ def reshape_add_lengths(X, drop=0.):
     return Y, backprop_reshape
 
 
-def predict_spans(doc2spans, tok2vecs, span2scores):
-    """Apply a model over inputs that are a tuple of (vectors, spans), where the
-    spans are an array of (start, end) offsets. The vectors should be a single
-    array concatenated for the whole batch.
-
-    The output will be a tuple (outputs, spans), where the outputs array
-    will have one row per span. In the backward pass, we take the gradients w.r.t.
-    the spans, and return the gradients w.r.t. the input vectors.
-
-    A naive implementation of this process would make a single array, padded
-    for all spans. However, the longest span may be very long, so this array
-    would consume an enormous amount of memory. Instead, we sort the spans by
-    length and work in batches. This reduces the total amount of padding, and
-    means we do not need to hold expanded arrays for the whole data. As a bonus,
-    the input model also doesn't need to worry about masking: we know that the
-    data it works over has no empty items.
-    """
-
-    def apply_to_spans_forward(inputs, drop=0.0):
-        docs = inputs.get("docs")
-        tokvecs = inputs.get("tokvecs")
-        spans = inputs.get("spans")
-        if spans is None:
-            spans = doc2spans(docs)
-        if tokvecs is None:
-            tokvecs, bp_tokvecs = tok2vecs.begin_update(docs, drop=drop)
-        else:
-            bp_tokvecs = None
-        scores, backprop_scores = _begin_update_batched(
-            span2scores, tokvecs, spans, drop=drop
-        )
-        shape = tokvecs.shape
-
-        def apply_to_spans_backward(d_scores, sgd=None):
-            d_tokvecs = _backprop_batched(shape, d_scores, backprops, sgd)
-            return d_tokvecs
-
-        return (scores, spans), apply_to_spans_backward
-
-    model = wrap(apply_to_spans_forward, tok2vecs, span2scores)
-    model.tok2vec = tok2vec
-    model.span2scores = span2scores
-    return model
-
-
 def _get_token_label_matrix(gold_phrases, lengths, labels):
+    """Figure out how each token should be labelled w.r.t. some gold-standard
+    spans, where the labels indicate whether that token is part of the span."""
     output = numpy.zeros((sum(lengths), len(labels)), dtype="i")
     label2class = {label: i for i, label in enumerate(labels)}
     offset = 0
@@ -208,6 +249,7 @@ def _get_token_label_matrix(gold_phrases, lengths, labels):
 
 
 def _scores2spans(docs, scores, starts, ends, labels, threshold=0.5):
+    """Produce labelled Span objects implied by the model's predictions."""
     token_to_doc = _get_token_to_doc(docs)
     output = []
     # When we predict, assume only one label per identical span.
@@ -216,12 +258,15 @@ def _scores2spans(docs, scores, starts, ends, labels, threshold=0.5):
     for i, start in enumerate(starts):
         doc_i, offset = token_to_doc[start]
         if bests[i] >= threshold:
-            span = Span(docs[doc_i], start, ends[i], label=labels[guesses[i])
+            span = Span(docs[doc_i], start, ends[i], label=labels[guesses[i]])
             output.append(span)
     return output
 
 
 def _get_token_to_doc(docs):
+    """Map token positions within a batch to a tuple (doc_index, doc_offset).
+    When we flatten an array for the batch, this lets us easily find the token
+    each row in the flat array corresponds to."""
     offset = 0
     token_to_doc = {}
     for i, doc in enumerate(docs):
@@ -232,6 +277,9 @@ def _get_token_to_doc(docs):
 
 
 def _get_all_spans(length, max_len, offset=0):
+    """List (start, end) indices of all subsequences up to `max_len`,
+    for a sequence of length `length`. Indices may be offset by `offset`.
+    """
     spans = []
     for start in range(length):
         for end in range(i + 1, min(i + 1 + max_len, length)):
@@ -258,43 +306,9 @@ def _batch_spans_by_length(spans):
 
 
 def _get_span_batch(vectors, starts, length):
+    """Make a contiguous array for spans of a certain length."""
     xp = get_array_module(vectors)
     output = xp.zeros((len(starts), length, vectors.shape[1]))
     for i, start in enumerate(starts):
         output[i] = vectors[start : start + length]
     return output
-
-
-def _predict_batched(model, tokvecs, spans):
-    xp = get_array_module(vectors)
-    output = xp.zeros((len(spans), model.nO), dtype="f")
-    for indices, starts, length in _batch_spans_by_length(spans):
-        X = _get_span_batch(tokvecs, starts, length)
-        batchY = model(X)
-        for i, output_idx in enumerate(indices):
-            output[output_idx] = Y[i]
-    return output
-
-
-def _begin_update_batched(model, tokvecs, spans, drop):
-    xp = get_array_module(vectors)
-    output = xp.zeros((len(spans), model.nO), dtype="f")
-    backprops = []
-    for indices, starts, length in _batch_spans_by_length(spans):
-        X = _get_span_batch(tokvecs, starts, length)
-        batchY, backprop = model.begin_update(X, drop=drop)
-        for i, output_idx in enumerate(indices):
-            output[output_idx] = Y[i]
-        backprops.append((indices, starts, length, backprop))
-    return output, backprops
-
-
-def _backprop_batched(shape, d_output, backprops):
-    xp = get_array_module(d_output)
-    d_tokvecs = xp.zeros(shape, dtype=d_output.dtype)
-    for indices, starts, ends, backprop in backprops:
-        dY = d_output[indices]
-        dX = backprop(dY)
-        for i, (start, end) in enumerate(zip(starts, ends)):
-            d_tokvecs[start:end] += dX[i, : end - start]
-    return d_tokvecs