Refactor spancat

Matthew Honnibal 2019-07-13 12:12:46 +02:00
parent f04bfc3cfd
commit bf7c7c9ce2


@@ -1,16 +1,16 @@
from .pipes import Pipe
from thinc.v2v import Maxout, Affine
from thinc.t2t import SoftAttention
from thinc.v2v import Maxout, Affine, Model
from thinc.t2v import Pooling, sum_pool
from thinc.api import zero_init
from .._ml import logistic
from thinc.api import layerize, chain
from thinc.misc import LayerNorm, Residual
from .pipes import Pipe
from .._ml import logistic, zero_init, Tok2Vec
import numpy
from thinc.neural.util import get_array_module
from ..tokens import Span
class SpanCategorizer(Pipe):
"""Predict labels for spans of text."""
@classmethod
def Model(cls, **cfg):
def Model(cls, nO, **cfg):
# TODO: Settings here
tok2vec = Tok2Vec(**cfg)
with Model.define_operators({">>": chain}):
@@ -19,17 +19,22 @@ class SpanCategorizer(Pipe):
#>> SoftAttention
>> Pooling(sum_pool)
>> LayerNorm(Residual(Maxout(tok2vec.nO)))
>> zero_init(Affine(tok2vec.nO))
>> zero_init(Affine(nO, tok2vec.nO))
>> logistic
)
return create_span_model(self.get_spans, tok2vec, span2scores)
return SpanCategorizerModel(tok2vec, span2scores)
def __init__(self, user_data_key="phrases", get_spans=None, model=True):
Pipe.__init__(self)
def __init__(self, vocab, user_data_key="phrases", max_length=10,
get_spans=None, model=True):
self.cfg = {
"user_data_key": user_data_key,
"labels": [],
"max_length": max_length
}
self.vocab = vocab
self.user_data_key = user_data_key
self.span_getter = get_spans
self.model = model
self.max_length = max_length
@property
def tok2vec(self):
@@ -116,13 +121,13 @@ class SpanCategorizer(Pipe):
get_grads.b1 = sgd.b1
get_grads.b2 = sgd.b2
token_label_matrix = _get_token_label_matrix(
[g.phrases for g in golds], [len(doc) for doc in docs], self.labels
)
golds = [getattr(g, "spans", g) for g in golds]
token_label_matrix = _get_token_label_matrix(golds, [len(d) for d in docs], self.labels)
for indices, starts, length in _batch_spans_by_length(spans):
X = _get_span_batch(tokvecs, starts, length)
Y, backprop = self.spans2scores.begin_update(X, drop=drop)
Y, backprop = self.model.span2scores.begin_update(X, drop=drop)
dY = self.get_loss((indices, Y), token_label_matrix)
dX = backprop(dY, sgd=get_grads)
for i, start in enumerate(starts):
@@ -135,6 +140,76 @@ class SpanCategorizer(Pipe):
return losses
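For reference, the per-token label matrix that update() builds via _get_token_label_matrix (defined further down in this file) is easiest to picture with a toy example. This sketch is illustrative only; the labels, span and shapes are made up:

import numpy

# One doc of 5 tokens, two labels, one gold span (1, 3) labelled "PERSON".
# Each row is a token and each column a label; a 1 marks a token that sits
# inside a gold span with that label.
labels = ["PERSON", "ORG"]
token_labels = numpy.zeros((5, len(labels)), dtype="i")
start, end, label = 1, 3, "PERSON"
token_labels[start:end, labels.index(label)] = 1
# token_labels is now: [[0 0], [1 0], [1 0], [0 0], [0 0]]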
class SpanCategorizerModel(Model):
"""Predict labels for spans, using two component models: a tok2vec model,
and a span2scores model. The input to the SpanCategorizerModel should be
a tuple (docs, spans), where the spans are an array with two columns:
(start, end).
The output is a scores array with one row per span. In the backward pass,
we take the gradient with respect to the span scores and backprop it
through the tok2vec layer to the token vectors.
A naive implementation of this process would make a single array, padded
for all spans. However, the longest span may be very long, and we might
have lots of spans, so this array could consume an enormous amount of
memory. Instead, we sort the spans by length and work in batches. This
reduces the total amount of padding, and means we do not need to hold
expanded arrays for the whole data. As a bonus, the input model also
doesn't need to worry about masking: we know that the
data it works over has no empty items.
"""
name = "spancat_model"
def __init__(self, tok2vec, span2scores):
Model.__init__(self)
self.tok2vec = tok2vec
self.span2scores = span2scores
self._layers.append(tok2vec)
self._layers.append(span2scores)
@property
def nO(self):
return self.span2scores.nO
def predict(self, docs_spans):
    """Predict scores for the spans, batching by length."""
    docs, spans = docs_spans
    tokvecs = self.tok2vec(docs)
    scores = self.ops.xp.zeros((len(spans), self.nO), dtype="f")
    for indices, starts, length in _batch_spans_by_length(spans):
        X = _get_span_batch(tokvecs, starts, length)
        batchY = self.span2scores(X)
        for i, output_idx in enumerate(indices):
            scores[output_idx] = batchY[i]
    return scores
def begin_update(self, docs_spans, drop=0.):
"""Do the forward pass of the span classification, batching the input.
Returns a tuple (scores, callback), where the callback takes an array
d_scores and performs the backward pass."""
docs, spans = docs_spans
tokvecs, bp_tokvecs = self.tok2vec.begin_update(docs, drop=drop)
scores = self.ops.xp.zeros((len(spans), self.nO), dtype="f")
backprops = []
for indices, starts, length in _batch_spans_by_length(spans):
X = _get_span_batch(tokvecs, starts, length)
batchY, backprop = self.span2scores.begin_update(X, drop=drop)
for i, output_idx in enumerate(indices):
scores[output_idx] = batchY[i]
backprops.append((indices, starts, length, backprop))
shape = tokvecs.shape
def backprop_spancat_model(d_scores, sgd=None):
d_tokvecs = self.ops.xp.zeros(shape, dtype=d_scores.dtype)
for indices, starts, length, backprop in backprops:
    dY = d_scores[indices]
    dX = backprop(dY, sgd=sgd)
    for i, start in enumerate(starts):
        d_tokvecs[start : start + length] += dX[i]
return bp_tokvecs(d_tokvecs, sgd=sgd)
return scores, backprop_spancat_model
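The _batch_spans_by_length helper is called throughout this file, but its body falls outside the hunks shown here. A minimal sketch of the grouping it presumably performs, yielding one (indices, starts, length) batch per distinct span length:

import numpy

def _batch_spans_by_length(spans):
    # Group equal-length spans so each batch is a dense, unpadded array.
    # `spans` is an (N, 2) array of (start, end) token offsets.
    lengths = spans[:, 1] - spans[:, 0]
    for length in numpy.unique(lengths):
        indices = numpy.where(lengths == length)[0]
        yield indices, spans[indices, 0], int(length)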
@layerize
def reshape_add_lengths(X, drop=0.):
xp = get_array_module(X)
@@ -149,52 +224,9 @@ def reshape_add_lengths(X, drop=0.):
return Y, backprop_reshape
def predict_spans(doc2spans, tok2vecs, span2scores):
"""Apply a model over inputs that are a tuple of (vectors, spans), where the
spans are an array of (start, end) offsets. The vectors should be a single
array concatenated for the whole batch.
The output will be a tuple (outputs, spans), where the outputs array
will have one row per span. In the backward pass, we take the gradients w.r.t.
the spans, and return the gradients w.r.t. the input vectors.
A naive implementation of this process would make a single array, padded
for all spans. However, the longest span may be very long, so this array
would consume an enormous amount of memory. Instead, we sort the spans by
length and work in batches. This reduces the total amount of padding, and
means we do not need to hold expanded arrays for the whole data. As a bonus,
the input model also doesn't need to worry about masking: we know that the
data it works over has no empty items.
"""
def apply_to_spans_forward(inputs, drop=0.0):
docs = inputs.get("docs")
tokvecs = inputs.get("tokvecs")
spans = inputs.get("spans")
if spans is None:
spans = doc2spans(docs)
if tokvecs is None:
tokvecs, bp_tokvecs = tok2vecs.begin_update(docs, drop=drop)
else:
bp_tokvecs = None
scores, backprops = _begin_update_batched(
span2scores, tokvecs, spans, drop=drop
)
shape = tokvecs.shape
def apply_to_spans_backward(d_scores, sgd=None):
d_tokvecs = _backprop_batched(shape, d_scores, backprops, sgd)
return d_tokvecs
return (scores, spans), apply_to_spans_backward
model = wrap(apply_to_spans_forward, tok2vecs, span2scores)
model.tok2vec = tok2vecs
model.span2scores = span2scores
return model
def _get_token_label_matrix(gold_phrases, lengths, labels):
"""Figure out how each token should be labelled w.r.t. some gold-standard
spans, where the labels indicate whether that token is part of the span."""
output = numpy.zeros((sum(lengths), len(labels)), dtype="i")
label2class = {label: i for i, label in enumerate(labels)}
offset = 0
@@ -208,6 +240,7 @@ def _get_token_label_matrix(gold_phrases, lengths, labels):
def _scores2spans(docs, scores, starts, ends, labels, threshold=0.5):
"""Produce labelled Span objects implied by the model's predictions."""
token_to_doc = _get_token_to_doc(docs)
output = []
# When we predict, assume only one label per identical span.
@@ -216,12 +249,15 @@ def _scores2spans(docs, scores, starts, ends, labels, threshold=0.5):
for i, start in enumerate(starts):
doc_i, offset = token_to_doc[start]
if bests[i] >= threshold:
span = Span(docs[doc_i], start, ends[i], label=labels[guesses[i])
span = Span(docs[doc_i], start, ends[i], label=labels[guesses[i]])
output.append(span)
return output
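The lines computing bests and guesses are elided by the hunk above; presumably they reduce the scores row-wise, along these lines (a sketch, not the committed code):

guesses = scores.argmax(axis=1)  # index of the best label for each span
bests = scores.max(axis=1)       # its score, compared against `threshold`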
def _get_token_to_doc(docs):
"""Map token positions within a batch to a tuple (doc_index, doc_offset).
When we flatten an array for the batch, this lets us easily find the token
each row in the flat array corresponds to."""
offset = 0
token_to_doc = {}
for i, doc in enumerate(docs):
@@ -232,6 +268,9 @@ def _get_token_to_doc(docs):
def _get_all_spans(length, max_len, offset=0):
"""List (start, end) indices of all subsequences up to `max_len`,
for a sequence of length `length`. Indices may be offset by `offset`.
"""
spans = []
for start in range(length):
for end in range(start + 1, min(start + max_len, length) + 1):
@@ -258,43 +297,9 @@ def _batch_spans_by_length(spans):
def _get_span_batch(vectors, starts, length):
"""Make a contiguous array for spans of a certain length."""
xp = get_array_module(vectors)
output = xp.zeros((len(starts), length, vectors.shape[1]), dtype=vectors.dtype)
for i, start in enumerate(starts):
output[i] = vectors[start : start + length]
return output
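A quick sanity check for _get_span_batch, with made-up shapes: six token vectors of width 4, and three spans of length 2:

import numpy
tokvecs = numpy.arange(24, dtype="f").reshape((6, 4))
batch = _get_span_batch(tokvecs, starts=numpy.array([0, 2, 3]), length=2)
assert batch.shape == (3, 2, 4)
assert (batch[1] == tokvecs[2:4]).all()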
def _predict_batched(model, tokvecs, spans):
xp = get_array_module(tokvecs)
output = xp.zeros((len(spans), model.nO), dtype="f")
for indices, starts, length in _batch_spans_by_length(spans):
X = _get_span_batch(tokvecs, starts, length)
batchY = model(X)
for i, output_idx in enumerate(indices):
output[output_idx] = batchY[i]
return output
def _begin_update_batched(model, tokvecs, spans, drop):
xp = get_array_module(tokvecs)
output = xp.zeros((len(spans), model.nO), dtype="f")
backprops = []
for indices, starts, length in _batch_spans_by_length(spans):
X = _get_span_batch(tokvecs, starts, length)
batchY, backprop = model.begin_update(X, drop=drop)
for i, output_idx in enumerate(indices):
output[output_idx] = batchY[i]
backprops.append((indices, starts, length, backprop))
return output, backprops
def _backprop_batched(shape, d_output, backprops, sgd=None):
    xp = get_array_module(d_output)
    d_tokvecs = xp.zeros(shape, dtype=d_output.dtype)
    for indices, starts, length, backprop in backprops:
        dY = d_output[indices]
        dX = backprop(dY, sgd=sgd)
        for i, start in enumerate(starts):
            d_tokvecs[start : start + length] += dX[i]
    return d_tokvecs
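To make the batch-then-scatter pattern concrete, here is a self-contained numpy sketch of the same forward/backward flow, with a mean-pool standing in for the real span2scores model; all names and shapes are illustrative:

import numpy

def batch_spans_by_length(spans):
    # Same grouping as above, restated so this sketch runs on its own.
    lengths = spans[:, 1] - spans[:, 0]
    for length in numpy.unique(lengths):
        indices = numpy.where(lengths == length)[0]
        yield indices, spans[indices, 0], int(length)

tokvecs = numpy.random.rand(10, 4).astype("f")
spans = numpy.array([[0, 2], [3, 5], [2, 6]], dtype="i")
# Forward: one dense batch per span length, scattered into the output rows.
scores = numpy.zeros((len(spans), 4), dtype="f")
for indices, starts, length in batch_spans_by_length(spans):
    X = numpy.stack([tokvecs[s : s + length] for s in starts])
    scores[indices] = X.mean(axis=1)
# Backward: per-span gradients are scattered back onto the token vectors,
# summing where spans overlap. The gradient of a mean spreads evenly.
d_scores = numpy.ones_like(scores)
d_tokvecs = numpy.zeros_like(tokvecs)
for indices, starts, length in batch_spans_by_length(spans):
    for i, start in enumerate(starts):
        d_tokvecs[start : start + length] += d_scores[indices[i]] / length
assert d_tokvecs.shape == tokvecs.shape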