From bf7c7c9ce27368f4da50103a523c3e6a2cf37137 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal <honnibal+gh@gmail.com>
Date: Sat, 13 Jul 2019 12:12:46 +0200
Subject: [PATCH] Refactor spancat

---
 spacy/pipeline/spancat.py | 197 +++++++++++++++++++-------------------
 1 file changed, 101 insertions(+), 96 deletions(-)

diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index 3a981308b..4223f32ed 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -1,16 +1,16 @@
-from .pipes import Pipe
-from thinc.v2v import Maxout, Affine
-from thinc.t2t import SoftAttention
+from thinc.v2v import Maxout, Affine, Model
 from thinc.t2v import Pooling, sum_pool
-from thinc.api import zero_init
-from .._ml import logistic
+from thinc.api import layerize, chain
+from thinc.misc import LayerNorm, Residual
+from .pipes import Pipe
+from .._ml import logistic, zero_init, Tok2Vec
 
 
 class SpanCategorizer(Pipe):
     """Predict labels for spans of text."""
 
     @classmethod
-    def Model(cls, **cfg):
+    def Model(cls, nO, **cfg):
         # TODO: Settings here
         tok2vec = Tok2Vec(**cfg)
         with Model.define_operators({">>": chain}):
@@ -19,17 +19,22 @@ class SpanCategorizer(Pipe):
                 #>> SoftAttention
                 >> Pooling(sum_pool)
                 >> LayerNorm(Residual(Maxout(tok2vec.nO)))
-                >> zero_init(Affine(tok2vec.nO))
+                >> zero_init(Affine(nO, tok2vec.nO))
                 >> logistic
             )
-        return create_span_model(self.get_spans, tok2vec, span2scores)
+        return SpanCategorizerModel(tok2vec, span2scores)
 
-    def __init__(self, user_data_key="phrases", get_spans=None, model=True):
-        Pipe.__init__(self)
+    def __init__(self, vocab, user_data_key="phrases", max_length=10,
+            get_spans=None, model=True):
+        self.cfg = {
+            "user_data_key": user_data_key,
+            "labels": [],
+            "max_length": max_length
+        }
+        self.vocab = vocab
         self.user_data_key = user_data_key
         self.span_getter = get_spans
         self.model = model
-        self.max_length = 10
 
     @property
     def tok2vec(self):
@@ -116,13 +121,13 @@ class SpanCategorizer(Pipe):
         get_grads.b1 = sgd.b1
         get_grads.b2 = sgd.b2
 
-        token_label_matrix = _get_token_label_matrix(
-            [g.phrases for g in golds], [len(doc) for doc in docs], self.labels
-        )
+        golds = [getattr(g, "spans", g) for g in golds]
+
+        token_label_matrix = _get_token_label_matrix(golds, [len(d) for d in docs], self.labels)
 
         for indices, starts, length in _batch_spans_by_length(spans):
             X = _get_span_batch(tokvecs, starts, length)
-            Y, backprop = self.spans2scores.begin_update(X, drop=drop)
+            Y, backprop = self.model.spans2scores.begin_update(X, drop=drop)
             dY = self.get_loss((indices, Y), token_label_matrix)
             dX = backprop(dY, sgd=get_grads)
             for i, start in enumerate(starts):
@@ -135,6 +140,76 @@ class SpanCategorizer(Pipe):
         return losses
 
 
+class SpanCategorizerModel(Model):
+    """Predict labels for spans, using two component models: a tok2vec model,
+    and a span2scores model. The input to the SpanCategorizerModel should be
+    a tuple (docs, spans), where the spans are an array with two columns:
+    (start, end).
+
+    The output will be a tuple (outputs, spans), where the outputs array
+    will have one row per span. In the backward pass, we take the gradients w.r.t.
+    the spans, and backprop through the input vectors.
+
+    A naive implementation of this process would make a single array, padded
+    for all spans. However, the longest span may be very long, and we might
+    have lots of spans, so this array could consume an enormous amount of
+    memory. Instead, we sort the spans by length and work in batches. This
+    reduces the total amount of padding, and means we do not need to hold
+    expanded arrays for the whole data. As a bonus, the input model also
+    doesn't need to worry about masking: we know that the
+    data it works over has no empty items.
+    """
+    name = "spancat_model"
+    def __init__(self, tok2vec, span2scores):
+        Model.__init__(self)
+        self.tok2vec = tok2vec
+        self.span2scores = span2scores
+        self._layers.append(tok2vec)
+        self._layers.append(span2scores)
+        
+    @property
+    def nO(self):
+        return span2scores.nO
+
+    def predict(self, docs_spans):
+        """Predict scores for the spans, batching by length."""
+        scores = self.ops.xp.zeros((len(spans), model.nO), dtype="f")
+        for indices, starts, length in _batch_spans_by_length(spans):
+            X = _get_span_batch(tokvecs, starts, length)
+            batchY = model(X)
+            for i, output_idx in enumerate(indices):
+                scores[output_idx] = Y[i]
+        return scores
+
+    def begin_update(self, docs_spans, drop=0.):
+        """Do the forward pass of the span classification, batching the input.
+        
+        Returns a tuple (scores, callback), where the callback takes an array
+        d_scores and performs the backward pass."""
+        docs, spans = docs_spans
+        tokvecs, bp_tokvecs = self.tok2vec.begin_update(docs, drop=drop)
+        scores = self.ops.xp.zeros((len(spans), model.nO), dtype="f")
+        backprops = []
+        for indices, starts, length in _batch_spans_by_length(spans):
+            X = _get_span_batch(tokvecs, starts, length)
+            batchY, backprop = self.span2scores.begin_update(X, drop=drop)
+            for i, output_idx in enumerate(indices):
+                scores[output_idx] = Y[i]
+            backprops.append((indices, starts, length, backprop))
+        shape = tokvecs.shape
+
+        def backprop_spancat_model(d_scores, sgd=None):
+            d_tokvecs = self.ops.xp.zeros(shape, dtype=d_output.dtype)
+            for indices, starts, ends, backprop in backprops:
+                dY = d_output[indices]
+                dX = backprop(dY)
+                for i, (start, end) in enumerate(zip(starts, ends)):
+                    d_tokvecs[start:end] += dX[i, : end - start]
+            return bp_tokvecs(d_tokvecs, sgd=sgd)
+
+        return scores, backprop_spancat_model
+
+
 @layerize
 def reshape_add_lengths(X, drop=0.):
     xp = get_array_module(X)
@@ -149,52 +224,9 @@ def reshape_add_lengths(X, drop=0.):
     return Y, backprop_reshape
 
 
-def predict_spans(doc2spans, tok2vecs, span2scores):
-    """Apply a model over inputs that are a tuple of (vectors, spans), where the
-    spans are an array of (start, end) offsets. The vectors should be a single
-    array concatenated for the whole batch.
-
-    The output will be a tuple (outputs, spans), where the outputs array
-    will have one row per span. In the backward pass, we take the gradients w.r.t.
-    the spans, and return the gradients w.r.t. the input vectors.
-
-    A naive implementation of this process would make a single array, padded
-    for all spans. However, the longest span may be very long, so this array
-    would consume an enormous amount of memory. Instead, we sort the spans by
-    length and work in batches. This reduces the total amount of padding, and
-    means we do not need to hold expanded arrays for the whole data. As a bonus,
-    the input model also doesn't need to worry about masking: we know that the
-    data it works over has no empty items.
-    """
-
-    def apply_to_spans_forward(inputs, drop=0.0):
-        docs = inputs.get("docs")
-        tokvecs = inputs.get("tokvecs")
-        spans = inputs.get("spans")
-        if spans is None:
-            spans = doc2spans(docs)
-        if tokvecs is None:
-            tokvecs, bp_tokvecs = tok2vecs.begin_update(docs, drop=drop)
-        else:
-            bp_tokvecs = None
-        scores, backprop_scores = _begin_update_batched(
-            span2scores, tokvecs, spans, drop=drop
-        )
-        shape = tokvecs.shape
-
-        def apply_to_spans_backward(d_scores, sgd=None):
-            d_tokvecs = _backprop_batched(shape, d_scores, backprops, sgd)
-            return d_tokvecs
-
-        return (scores, spans), apply_to_spans_backward
-
-    model = wrap(apply_to_spans_forward, tok2vecs, span2scores)
-    model.tok2vec = tok2vec
-    model.span2scores = span2scores
-    return model
-
-
 def _get_token_label_matrix(gold_phrases, lengths, labels):
+    """Figure out how each token should be labelled w.r.t. some gold-standard
+    spans, where the labels indicate whether that token is part of the span."""
     output = numpy.zeros((sum(lengths), len(labels)), dtype="i")
     label2class = {label: i for i, label in enumerate(labels)}
     offset = 0
@@ -208,6 +240,7 @@ def _get_token_label_matrix(gold_phrases, lengths, labels):
 
 
 def _scores2spans(docs, scores, starts, ends, labels, threshold=0.5):
+    """Produce labelled Span objects implied by the model's predictions."""
     token_to_doc = _get_token_to_doc(docs)
     output = []
     # When we predict, assume only one label per identical span.
@@ -216,12 +249,15 @@ def _scores2spans(docs, scores, starts, ends, labels, threshold=0.5):
     for i, start in enumerate(starts):
         doc_i, offset = token_to_doc[start]
         if bests[i] >= threshold:
-            span = Span(docs[doc_i], start, ends[i], label=labels[guesses[i])
+            span = Span(docs[doc_i], start, ends[i], label=labels[guesses[i]])
             output.append(span)
     return output
 
 
 def _get_token_to_doc(docs):
+    """Map token positions within a batch to a tuple (doc_index, doc_offset).
+    When we flatten an array for the batch, this lets us easily find the token
+    each row in the flat array corresponds to."""
     offset = 0
     token_to_doc = {}
     for i, doc in enumerate(docs):
@@ -232,6 +268,9 @@ def _get_token_to_doc(docs):
 
 
 def _get_all_spans(length, max_len, offset=0):
+    """List (start, end) indices of all subsequences up to `max_len`,
+    for a sequence of length `length`. Indices may be offset by `offset`.
+    """
     spans = []
     for start in range(length):
         for end in range(i + 1, min(i + 1 + max_len, length)):
@@ -258,43 +297,9 @@ def _batch_spans_by_length(spans):
 
 
 def _get_span_batch(vectors, starts, length):
+    """Make a contiguous array for spans of a certain length."""
     xp = get_array_module(vectors)
     output = xp.zeros((len(starts), length, vectors.shape[1]))
     for i, start in enumerate(starts):
         output[i] = vectors[start : start + length]
     return output
-
-
-def _predict_batched(model, tokvecs, spans):
-    xp = get_array_module(vectors)
-    output = xp.zeros((len(spans), model.nO), dtype="f")
-    for indices, starts, length in _batch_spans_by_length(spans):
-        X = _get_span_batch(tokvecs, starts, length)
-        batchY = model(X)
-        for i, output_idx in enumerate(indices):
-            output[output_idx] = Y[i]
-    return output
-
-
-def _begin_update_batched(model, tokvecs, spans, drop):
-    xp = get_array_module(vectors)
-    output = xp.zeros((len(spans), model.nO), dtype="f")
-    backprops = []
-    for indices, starts, length in _batch_spans_by_length(spans):
-        X = _get_span_batch(tokvecs, starts, length)
-        batchY, backprop = model.begin_update(X, drop=drop)
-        for i, output_idx in enumerate(indices):
-            output[output_idx] = Y[i]
-        backprops.append((indices, starts, length, backprop))
-    return output, backprops
-
-
-def _backprop_batched(shape, d_output, backprops):
-    xp = get_array_module(d_output)
-    d_tokvecs = xp.zeros(shape, dtype=d_output.dtype)
-    for indices, starts, ends, backprop in backprops:
-        dY = d_output[indices]
-        dX = backprop(dY)
-        for i, (start, end) in enumerate(zip(starts, ends)):
-            d_tokvecs[start:end] += dX[i, : end - start]
-    return d_tokvecs