Update draft of parser neural network model

Model is good, but code is messy. Currently requires Chainer, which may cause the build to fail on machines without a GPU.

Outline of the model:

We first predict context-sensitive vectors for each word in the input:

(embed_lower | embed_prefix | embed_suffix | embed_shape)
>> Maxout(token_width)
>> convolution ** 4

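In thinc's operator notation, that composes roughly like so (an illustrative
sketch, not the exact code: widths and table sizes are made up, and the
per-column feature extraction is elided):

    from thinc.api import chain, clone, concatenate
    from thinc.neural import Model, Maxout
    from thinc.neural._classes.hash_embed import HashEmbed
    from thinc.neural._classes.convolution import ExtractWindow

    width, embed_size = 128, 1000
    with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
        embed = (HashEmbed(width, embed_size)    # lower
               | HashEmbed(width, embed_size)    # prefix
               | HashEmbed(width, embed_size)    # suffix
               | HashEmbed(width, embed_size))   # shape
        convolution = ExtractWindow(nW=1) >> Maxout(width, width*3)
        tok2vec = embed >> Maxout(width) >> convolution ** 4
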
This convolutional layer is shared between the tagger and the parser, which spares the parser from needing tag features.
To boost the representation, we make a "super tag" from the POS, morphology and dependency label. The tagger predicts this
by adding a softmax layer onto the convolutional layer --- so, we're teaching the convolutional layer to give us a
representation that's one affine transform away from this informative lexical information. This is obviously good for the
parser (which backprops to the convolutions too).

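Continuing the sketch above, the tagger head is just a softmax stacked on the
shared layer, as in the pipeline code below (n_tags stands in for the size of
the super-tag set):

    from thinc.api import chain
    from thinc.neural import Softmax

    n_tags = 500   # illustrative
    tagger = chain(tok2vec, Softmax(n_tags, width))
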
The parser model makes a state vector by concatenating the vector representations of its context tokens. Current
results suggest that using only a few context tokens works well. Maybe this is a bug.

The current context tokens:

* S0, S1, S2: Top three words on the stack
* B0, B1: First two words of the buffer
* S0L1, S0L2: Leftmost and second leftmost children of S0
* S0R1, S0R2: Rightmost and second rightmost children of S0
* S1L1, S1L2, S1R1, S1R2: Likewise for S1
* B0L1, B0L2: Leftmost and second leftmost children of B0

This makes the state vector quite long: 13*T, where T is the token vector width (128 is working well). Fortunately,
there's a way to structure the computation to save some expense (and make it more GPU friendly).

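In shapes, the naive assembly of one state vector looks like this (a numpy
sketch; the ids are invented, with -1 marking a missing context token):

    import numpy

    T = 128                                          # token vector width
    token_vectors = numpy.zeros((50, T), dtype='f')  # CNN output, one 50-word doc
    ids = numpy.array([3, 2, -1, 4, 5, 1, -1, 0,
                       -1, -1, -1, -1, -1], dtype='i')
    parts = [token_vectors[i] if i >= 0 else numpy.zeros((T,), dtype='f')
             for i in ids]
    state_vector = numpy.concatenate(parts)          # shape: (13*T,) == (1664,)
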
The parser typically visits 2*N states for a sentence of length N (although it may visit more, if it back-tracks
with a non-monotonic transition). A naive implementation would require 2*N (B, 13*T) @ (13*T, H) matrix multiplications
for a batch of size B. We can instead perform one (B*N, T) @ (T, 13*H) multiplication, to pre-compute the hidden
weights for each positional feature wrt the words in the batch. (Note that our token vectors come from the CNN
-- so we can't play this trick over the vocabulary. That's how Stanford's NN parser works --- and why its model
is so big.)

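Here's the trick in plain numpy, for a single sentence (illustrative shapes;
the real version runs over the whole batch and keeps the maxout pieces):

    import numpy

    N, T, H, F = 50, 128, 64, 13
    tokens = numpy.random.rand(N, T).astype('f')  # CNN output for the sentence
    W = numpy.random.rand(F, H, T).astype('f')    # one (T, H) map per position
    # One big multiplication, done once per sentence, not once per state:
    cached = numpy.tensordot(tokens, W, axes=[[1], [2]])   # (N, F, H)

    def state_hidden(ids):
        # Hidden layer for one state: sum 13 pre-computed rows of the cache,
        # skipping missing context tokens (id == -1).
        out = numpy.zeros((H,), dtype='f')
        for f, i in enumerate(ids):
            if i >= 0:
                out += cached[i, f]
        return out
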
This pre-computation strategy allows a nice compromise between GPU-friendliness and implementation simplicity.
The CNN and the wide lower layer are computed on the GPU, and then the precomputed hidden weights are moved
to the CPU, before we start the transition-based parsing process. This makes a lot of things much easier.
We don't have to worry about variable-length batch sizes, and we don't have to implement the dynamic oracle
in CUDA to train.

Currently the parser's loss function is multilabel log loss, as the dynamic oracle allows multiple actions to be
0 cost. For a zero-cost ("gold") class, the gradient with respect to its score is:

(exp(score) / Z) - (exp(score) / gZ)

Where Z is the sum of exp(score) over the valid classes, and gZ is the same sum restricted to the zero-cost classes;
for a costed class only the first term applies. I'm very interested in regressing on the cost directly,
but so far this isn't working well.

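A minimal numpy rendering of that gradient (costs and is_valid come from the
dynamic oracle, as in cpu_log_loss in the diff below):

    import numpy

    def multilabel_log_loss_grad(scores, costs, is_valid):
        scores = scores - scores.max()     # shift for numerical stability
        exps = numpy.exp(scores) * is_valid
        g_exps = exps * (costs <= 0)       # zero-cost ('gold') classes only
        Z = exps.sum() + 1e-10
        gZ = g_exps.sum() + 1e-10
        return exps / Z - g_exps / gZ
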
Machinery is in place for beam-search, which has been working well for the linear model. Beam search should benefit
greatly from the pre-computation trick.

Matthew Honnibal, 2017-05-12 16:09:15 -05:00
commit 827b5af697 (parent b44f7e259c)
4 changed files with 394 additions and 276 deletions


@@ -1,12 +1,12 @@
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
-from thinc.neural import Model, ReLu, Maxout, Softmax, Affine
+from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
+from thinc.neural.ops import NumpyOps, CupyOps
 from thinc.neural._classes.convolution import ExtractWindow
 from thinc.neural._classes.static_vectors import StaticVectors
 from thinc.neural._classes.batchnorm import BatchNorm
 from thinc.neural._classes.resnet import Residual
 from thinc import describe
 from thinc.describe import Dimension, Synapses, Biases, Gradient
 from thinc.neural._classes.affine import _set_dimensions_if_needed
@@ -78,7 +78,7 @@ class PrecomputableAffine(Model):
     d_b=Gradient("b")
 )
 class PrecomputableMaxouts(Model):
-    def __init__(self, nO=None, nI=None, nF=None, pieces=2, **kwargs):
+    def __init__(self, nO=None, nI=None, nF=None, pieces=3, **kwargs):
         Model.__init__(self, **kwargs)
         self.nO = nO
         self.nP = pieces
@@ -87,88 +87,71 @@ class PrecomputableMaxouts(Model):
     def begin_update(self, X, drop=0.):
         # X: (b, i)
-        # Yfp: (f, b, o, p)
-        # Yf: (f, b, o)
-        # Xf: (b, f, i)
-        # dY: (b, o)
+        # Yfp: (b, f, o, p)
+        # Xf: (f, b, i)
         # dYp: (b, o, p)
         # W: (f, o, p, i)
         # b: (o, p)
-        # Yfp = numpy.einsum('bi,fopi->fbop', X, self.W)
-        Yfp = self.ops.xp.tensordot(X, self.W,
-            axes=[[1], [3]]).transpose((1, 0, 2, 3))
-        Yfp = self.ops.xp.ascontiguousarray(Yfp)
+        # bi,opfi->bfop
+        # bop,fopi->bfi
+        # bop,fbi->opfi : fopi
+        tensordot = self.ops.xp.tensordot
+        ascontiguous = self.ops.xp.ascontiguousarray
+        Yfp = tensordot(X, self.W, axes=[[1], [3]])
         Yfp += self.b
-        Yf = self.ops.allocate((self.nF, X.shape[0], self.nO))
-        which = self.ops.allocate((self.nF, X.shape[0], self.nO), dtype='i')
-        for i in range(self.nF):
-            Yf[i], which[i] = self.ops.maxout(Yfp[i])
-        def backward(dY_ids, sgd=None):
-            dY, ids = dY_ids
+        def backward(dYp_ids, sgd=None):
+            dYp, ids = dYp_ids
             Xf = X[ids]
-            dYp = self.ops.allocate((dY.shape[0], self.nO, self.nP))
-            for i in range(self.nF):
-                dYp += self.ops.backprop_maxout(dY, which[i], self.nP)
-            #dXf = numpy.einsum('bop,fopi->bfi', dYp, self.W)
-            dXf = self.ops.xp.tensordot(dYp, self.W, axes=[[1,2], [1,2]])
-            #dW = numpy.einsum('bfi,bop->fopi', Xf, dYp)
-            dW = self.ops.xp.tensordot(Xf, dYp, axes=[[0], [0]])
-            dW = dW.transpose((0, 2, 3, 1))
-            db = dYp.sum(axis=0)
-            self.d_W += dW
-            self.d_b += db
+            dXf = tensordot(dYp, self.W, axes=[[1, 2], [1,2]])
+            dW = tensordot(dYp, Xf, axes=[[0], [0]])
+            self.d_W += dW.transpose((2, 0, 1, 3))
+            self.d_b += dYp.sum(axis=0)
             if sgd is not None:
                 sgd(self._mem.weights, self._mem.gradient, key=self.id)
             return dXf
-        return Yf, backward
+        return Yfp, backward

 def get_col(idx):
     def forward(X, drop=0.):
+        if isinstance(X, numpy.ndarray):
+            ops = NumpyOps()
+        else:
+            ops = CupyOps()
         assert len(X.shape) <= 3
-        output = Model.ops.xp.ascontiguousarray(X[:, idx])
+        output = ops.xp.ascontiguousarray(X[:, idx])
         def backward(y, sgd=None):
-            dX = Model.ops.allocate(X.shape)
+            dX = ops.allocate(X.shape)
             dX[:, idx] += y
             return dX
         return output, backward
     return layerize(forward)

-def build_tok2vec(lang, width, depth=2, embed_size=1000):
-    cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
-    with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
-        #static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
-        lower = get_col(cols.index(LOWER)) >> HashEmbed(width, embed_size)
-        prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width//4, embed_size)
-        suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width//4, embed_size)
-        shape = get_col(cols.index(SHAPE)) >> HashEmbed(width//4, embed_size)
-        tok2vec = (
-            doc2feats(cols)
-            >> with_flatten(
-                #(static | prefix | suffix | shape)
-                (lower | prefix | suffix | shape)
-                >> Maxout(width)
-                >> (ExtractWindow(nW=1) >> Maxout(width, width*3))
-                >> (ExtractWindow(nW=1) >> Maxout(width, width*3))
-            )
-        )
-    return tok2vec
+def zero_init(model):
+    def _hook(self, X, y=None):
+        self.W.fill(0)
+    model.on_data_hooks.append(_hook)
+    return model

-def doc2feats(cols):
+def doc2feats(cols=None):
+    cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
     def forward(docs, drop=0.):
         feats = [doc.to_array(cols) for doc in docs]
         feats = [model.ops.asarray(f, dtype='uint64') for f in feats]
         return feats, None
     model = layerize(forward)
+    model.cols = cols
     return model

 def print_shape(prefix):
     def forward(X, drop=0.):
         return X, lambda dX, **kwargs: dX
@@ -186,52 +169,12 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
 @layerize
 def flatten(seqs, drop=0.):
-    ops = Model.ops
+    if isinstance(seqs[0], numpy.ndarray):
+        ops = NumpyOps()
+    else:
+        ops = CupyOps()
     lengths = [len(seq) for seq in seqs]
     def finish_update(d_X, sgd=None):
         return ops.unflatten(d_X, lengths)
     X = ops.xp.vstack(seqs)
     return X, finish_update

-#def build_feature_precomputer(model, feat_maps):
-#    '''Allow a model to be "primed" by pre-computing input features in bulk.
-#
-#    This is used for the parser, where we want to take a batch of documents,
-#    and compute vectors for each (token, position) pair. These vectors can then
-#    be reused, especially for beam-search.
-#
-#    Let's say we're using 12 features for each state, e.g. word at start of
-#    buffer, three words on stack, their children, etc. In the normal arc-eager
-#    system, a document of length N is processed in 2*N states. This means we'll
-#    create 2*N*12 feature vectors --- but if we pre-compute, we only need
-#    N*12 vector computations. The saving for beam-search is much better:
-#    if we have a beam of k, we'll normally make 2*N*12*K computations --
-#    so we can save the factor k. This also gives a nice CPU/GPU division:
-#    we can do all our hard maths up front, packed into large multiplications,
-#    and do the hard-to-program parsing on the CPU.
-#    '''
-#    def precompute(input_vectors):
-#        cached, backprops = zip(*[lyr.begin_update(input_vectors)
-#                                  for lyr in feat_maps)
-#        def forward(batch_token_ids, drop=0.):
-#            output = ops.allocate((batch_size, output_width))
-#            # i: batch index
-#            # j: position index (i.e. N0, S0, etc
-#            # tok_i: Index of the token within its document
-#            for i, token_ids in enumerate(batch_token_ids):
-#                for j, tok_i in enumerate(token_ids):
-#                    output[i] += cached[j][tok_i]
-#            def backward(d_vector, sgd=None):
-#                d_inputs = ops.allocate((batch_size, n_feat, vec_width))
-#                for i, token_ids in enumerate(batch_token_ids):
-#                    for j in range(len(token_ids)):
-#                        d_inputs[i][j] = backprops[j](d_vector, sgd)
-#                # Return the IDs, so caller can associate to correct token
-#                return (batch_token_ids, d_inputs)
-#            return vector, backward
-#        return chain(layerize(forward), model)
-#    return precompute


@@ -1,9 +1,12 @@
+# cython: infer_types=True
+# cython: profile=True
 # coding: utf8
 from __future__ import unicode_literals

 from thinc.api import chain, layerize, with_getitem
 from thinc.neural import Model, Softmax
 import numpy
+cimport numpy as np

 from .tokens.doc cimport Doc
 from .syntax.parser cimport Parser
@@ -11,63 +14,96 @@ from .syntax.parser cimport Parser
 from .syntax.ner cimport BiluoPushDown
 from .syntax.arc_eager cimport ArcEager
 from .tagger import Tagger
-from ._ml import build_tok2vec, flatten
+from .gold cimport GoldParse
+
+from thinc.api import add, layerize, chain, clone, concatenate
+from thinc.neural import Model, Maxout, Softmax, Affine
+from thinc.neural._classes.hash_embed import HashEmbed
+from thinc.neural.util import to_categorical
+from thinc.neural._classes.convolution import ExtractWindow
+from thinc.neural._classes.resnet import Residual
+from thinc.neural._classes.batchnorm import BatchNorm as BN
+from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
+from ._ml import flatten, get_col, doc2feats
+# TODO: The disorganization here is pretty embarrassing. At least it's only
+# internals.
+from .syntax.parser import get_templates as get_feature_templates
+from .attrs import DEP, ENT_TYPE

 class TokenVectorEncoder(object):
     '''Assign position-sensitive vectors to tokens, using a CNN or RNN.'''
-    def __init__(self, vocab, **cfg):
+    def __init__(self, vocab, token_vector_width, **cfg):
         self.vocab = vocab
-        self.model = build_tok2vec(vocab.lang, **cfg)
+        self.doc2feats = doc2feats()
+        self.model = self.build_model(vocab.lang, token_vector_width, **cfg)
         self.tagger = chain(
             self.model,
-            flatten,
-            Softmax(self.vocab.morphology.n_tags, 64))
+            Softmax(self.vocab.morphology.n_tags,
+                    token_vector_width))

+    def build_model(self, lang, width, embed_size=1000, **cfg):
+        cols = self.doc2feats.cols
+        with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
+            lower = get_col(cols.index(LOWER)) >> (HashEmbed(width, embed_size*3)
+                                                   +HashEmbed(width, embed_size*3))
+            prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
+            suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
+            shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
+            tok2vec = (
+                flatten
+                >> (lower | prefix | suffix | shape )
+                >> BN(Maxout(width, pieces=3))
+                >> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
+                >> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
+                >> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
+                >> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
+            )
+        return tok2vec
+
+    def pipe(self, docs):
+        docs = list(docs)
+        self.predict_tags(docs)
+        for doc in docs:
+            yield doc
+
     def __call__(self, doc):
-        doc.tensor = self.model([doc])[0]
         self.predict_tags([doc])

-    def begin_update(self, docs, drop=0.):
-        tensors, bp_tensors = self.model.begin_update(docs, drop=drop)
-        for i, doc in enumerate(docs):
-            doc.tensor = tensors[i]
-        self.predict_tags(docs)
-        return tensors, bp_tensors
+    def begin_update(self, feats, drop=0.):
+        tokvecs, bp_tokvecs = self.model.begin_update(feats, drop=drop)
+        return tokvecs, bp_tokvecs

     def predict_tags(self, docs, drop=0.):
         cdef Doc doc
-        scores, _ = self.tagger.begin_update(docs, drop=drop)
+        feats = self.doc2feats(docs)
+        scores, _ = self.tagger.begin_update(feats, drop=drop)
         idx = 0
+        guesses = scores.argmax(axis=1).get()
         for i, doc in enumerate(docs):
-            tag_ids = scores[idx:idx+len(doc)].argmax(axis=1)
+            tag_ids = guesses[idx:idx+len(doc)]
             for j, tag_id in enumerate(tag_ids):
                 doc.vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
                 idx += 1

-    def update(self, docs, golds, drop=0., sgd=None):
-        scores, finish_update = self.tagger.begin_update(docs, drop=drop)
-        losses = scores.copy()
+    def update(self, docs_feats, golds, drop=0., sgd=None):
+        cdef int i, j, idx
+        cdef GoldParse gold
+        docs, feats = docs_feats
+        scores, finish_update = self.tagger.begin_update(feats, drop=drop)
+        tag_index = {tag: i for i, tag in enumerate(docs[0].vocab.morphology.tag_names)}
         idx = 0
-        for i, gold in enumerate(golds):
-            if hasattr(self.tagger.ops.xp, 'scatter_add'):
-                ids = numpy.zeros((len(gold),), dtype='i')
-                start = idx
-                for j, tag in enumerate(gold.tags):
-                    ids[j] = docs[0].vocab.morphology.tag_names.index(tag)
-                    idx += 1
-                self.tagger.ops.xp.scatter_add(losses[start:idx], ids, -1.0)
-            else:
-                for j, tag in enumerate(gold.tags):
-                    tag_id = docs[0].vocab.morphology.tag_names.index(tag)
-                    losses[idx, tag_id] -= 1.
-                    idx += 1
-        finish_update(losses, sgd)
+        correct = numpy.zeros((scores.shape[0],), dtype='i')
+        for gold in golds:
+            for tag in gold.tags:
+                correct[idx] = tag_index[tag]
+                idx += 1
+        correct = self.model.ops.xp.array(correct)
+        d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
+        finish_update(d_scores, sgd)

 cdef class EntityRecognizer(Parser):
@@ -76,17 +112,10 @@ cdef class EntityRecognizer(Parser):
     """
     TransitionSystem = BiluoPushDown

+    feature_templates = get_feature_templates('ner')
+
     def add_label(self, label):
         Parser.add_label(self, label)
         if isinstance(label, basestring):
             label = self.vocab.strings[label]
-        # Set label into serializer. Super hacky :(
-        for attr, freqs in self.vocab.serializer_freqs:
-            if attr == ENT_TYPE and label not in freqs:
-                freqs.append([label, 1])
-        self.vocab._serializer = None

 #
 #cdef class BeamEntityRecognizer(BeamParser):
@@ -111,17 +140,10 @@ cdef class EntityRecognizer(Parser):
 cdef class DependencyParser(Parser):
     TransitionSystem = ArcEager

+    feature_templates = get_feature_templates('basic')
+
     def add_label(self, label):
         Parser.add_label(self, label)
         if isinstance(label, basestring):
             label = self.vocab.strings[label]
-        for attr, freqs in self.vocab.serializer_freqs:
-            if attr == DEP and label not in freqs:
-                freqs.append([label, 1])
-        # Super hacky :(
-        self.vocab._serializer = None

 #
 #cdef class BeamDependencyParser(BeamParser):


@@ -1,17 +1,23 @@
-# coding: utf-8
+"""
+MALT-style dependency parser
+"""
 # cython: infer_types=True
+# cython: profile=True
+# coding: utf-8
 from __future__ import unicode_literals
 from collections import Counter
 import ujson

+from chainer.functions.activation.softmax import Softmax as ChainerSoftmax
+from cupy.cuda.stream import Stream
+import cupy
+from libc.math cimport exp
 cimport cython
 cimport cython.parallel
+import cytoolz

 import numpy.random
+cimport numpy as np

 from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
 from cpython.exc cimport PyErr_CheckSignals
@@ -28,11 +34,11 @@ from murmurhash.mrmr cimport hash64
 from preshed.maps cimport MapStruct
 from preshed.maps cimport map_get
 from thinc.api import layerize, chain
-from thinc.neural import Model, Maxout
+from thinc.neural import Affine, Model, Maxout
+from thinc.neural.ops import NumpyOps

-from .._ml import PrecomputableAffine, PrecomputableMaxouts
+from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
 from . import _parse_features
 from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
@@ -58,94 +64,133 @@ def set_debug(val):
     DEBUG = val

-def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, lower_model):
-    cdef int[:, :] is_valid_
-    cdef float[:, :] costs_
-    lengths = [len(t) for t in tokvecs]
-    tokvecs = upper_model.ops.flatten(tokvecs)
-    is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i')
-    costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f')
-    token_ids = upper_model.ops.allocate((len(tokvecs), lower_model.nF), dtype='i')
-    cached, bp_features = lower_model.begin_update(tokvecs, drop=0.)
-    is_valid_ = is_valid
-    costs_ = costs
+def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None):
+    '''Allow a model to be "primed" by pre-computing input features in bulk.
+
+    This is used for the parser, where we want to take a batch of documents,
+    and compute vectors for each (token, position) pair. These vectors can then
+    be reused, especially for beam-search.
+
+    Let's say we're using 12 features for each state, e.g. word at start of
+    buffer, three words on stack, their children, etc. In the normal arc-eager
+    system, a document of length N is processed in 2*N states. This means we'll
+    create 2*N*12 feature vectors --- but if we pre-compute, we only need
+    N*12 vector computations. The saving for beam-search is much better:
+    if we have a beam of k, we'll normally make 2*N*12*K computations --
+    so we can save the factor k. This also gives a nice CPU/GPU division:
+    we can do all our hard maths up front, packed into large multiplications,
+    and do the hard-to-program parsing on the CPU.
+    '''
+    gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=0.)
+    cdef np.ndarray cached
+    if not isinstance(gpu_cached, numpy.ndarray):
+        cached = gpu_cached.get(stream=cuda_stream)
+    else:
+        cached = gpu_cached
+    nF = gpu_cached.shape[1]
+    nP = gpu_cached.shape[3]
+    ops = lower_model.ops
+    features = numpy.zeros((batch_size, cached.shape[2], nP), dtype='f')
+    synchronized = False

-    def forward(states_offsets, drop=0.):
-        nonlocal is_valid, costs, token_ids, moves
-        states, offsets = states_offsets
-        assert len(states) != 0
-        is_valid = is_valid[:len(states)]
-        costs = costs[:len(states)]
-        token_ids = token_ids[:len(states)]
-        is_valid = is_valid[:len(states)]
-        cdef StateClass state
-        cdef int i
-        for i, (offset, state) in enumerate(zip(offsets, states)):
-            state.set_context_tokens(token_ids[i])
-            moves.set_valid(&is_valid_[i, 0], state.c)
-        adjusted_ids = token_ids.copy()
-        for i, offset in enumerate(offsets):
-            adjusted_ids[i] *= token_ids[i] >= 0
-            adjusted_ids[i] += offset
-        features = upper_model.ops.allocate((len(states), lower_model.nO), dtype='f')
-        for i in range(len(states)):
-            for j, tok_i in enumerate(adjusted_ids[i]):
-                if tok_i >= 0:
-                    features[i] += cached[j, tok_i]
-        scores, bp_scores = upper_model.begin_update(features, drop=drop)
-        scores = upper_model.ops.relu(scores)
-        softmaxed = upper_model.ops.softmax(scores)
-        # Renormalize for invalid actions
-        softmaxed *= is_valid
-        totals = softmaxed.sum(axis=1)
-        for total in totals:
-            assert total > 0, (totals, scores, softmaxed)
-            assert total <= 1.1, totals
-        softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1))
+    def forward(token_ids, drop=0.):
+        nonlocal synchronized
+        if not synchronized and cuda_stream is not None:
+            cuda_stream.synchronize()
+            synchronized = True
+        # This is tricky, but:
+        # - Input to forward on CPU
+        # - Output from forward on CPU
+        # - Input to backward on GPU!
+        # - Output from backward on GPU
+        nonlocal features
+        features = features[:len(token_ids)]
+        features.fill(0)
+        cdef float[:, :, ::1] feats = features
+        cdef int[:, ::1] ids = token_ids
+        _sum_features(<float*>&feats[0,0,0],
+            <float*>cached.data, &ids[0,0],
+            token_ids.shape[0], nF, cached.shape[2]*nP)

-        def backward(golds, sgd=None):
-            nonlocal costs_, is_valid_, moves
-            cdef int i
-            for i, (state, gold) in enumerate(zip(states, golds)):
-                moves.set_costs(&is_valid_[i, 0], &costs_[i, 0],
-                    state, gold)
-            d_scores = scores.copy()
-            d_scores.fill(0)
-            set_log_loss(upper_model.ops, d_scores,
-                scores, is_valid, costs)
-            upper_model.ops.backprop_relu(d_scores, scores, inplace=True)
-            d_features = bp_scores(d_scores, sgd)
-            d_tokens = bp_features((d_features, adjusted_ids), sgd)
-            return (token_ids, d_tokens)
+        if nP >= 2:
+            best, which = ops.maxout(features)
+        else:
+            best = features.reshape((features.shape[0], features.shape[1]))
+            which = None

-        return softmaxed, backward
-    return layerize(forward)
+        def backward(d_best, sgd=None):
+            # This will usually be on GPU
+            if isinstance(d_best, numpy.ndarray):
+                d_best = ops.xp.array(d_best)
+            if nP >= 2:
+                d_features = ops.backprop_maxout(d_best, which, nP)
+            else:
+                d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1))
+            d_tokens = bp_features((d_features, token_ids), sgd)
+            return d_tokens

+        return best, backward
+    return forward

-def set_log_loss(ops, gradients, scores, is_valid, costs):
-    """Do multi-label log loss"""
-    n = gradients.shape[0]
-    scores = scores * is_valid
-    g_scores = scores * is_valid * (costs <= 0.)
-    exps = ops.xp.exp(scores - scores.max(axis=1).reshape((n, 1)))
-    exps *= is_valid
-    g_exps = ops.xp.exp(g_scores - g_scores.max(axis=1).reshape((n, 1)))
-    g_exps *= costs <= 0.
-    g_exps *= is_valid
-    gradients[:] = exps / exps.sum(axis=1).reshape((n, 1))
-    gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1))
+cdef void _sum_features(float* output,
+        const float* cached, const int* token_ids, int B, int F, int O) nogil:
+    cdef int idx, b, f, i
+    cdef const float* feature
+    for b in range(B):
+        for f in range(F):
+            if token_ids[f] < 0:
+                continue
+            idx = token_ids[f] * F * O + f*O
+            feature = &cached[idx]
+            for i in range(O):
+                output[i] += feature[i]
+        output += O
+        token_ids += F

-def transition_batch(TransitionSystem moves, states, scores):
+def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores):
     cdef StateClass state
-    cdef int guess
-    for state, guess in zip(states, scores.argmax(axis=1)):
-        action = moves.c[guess]
-        action.do(state.c, action.label)
+    cdef GoldParse gold
+    cdef Pool mem = Pool()
+    cdef int i
+    is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
+    costs = <float*>mem.alloc(moves.n_moves, sizeof(float))
+    cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f')
+    c_d_scores = <float*>d_scores.data
+    for i, (state, gold) in enumerate(zip(states, golds)):
+        memset(is_valid, 0, moves.n_moves * sizeof(int))
+        memset(costs, 0, moves.n_moves * sizeof(float))
+        moves.set_costs(is_valid, costs, state, gold)
+        cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1])
+        c_d_scores += d_scores.shape[1]
+    return d_scores

+cdef void cpu_log_loss(float* d_scores,
+        const float* costs, const int* is_valid, const float* scores,
+        int O) nogil:
+    """Do multi-label log loss"""
+    cdef double max_, gmax, Z, gZ
+    best = arg_max_if_gold(scores, costs, is_valid, O)
+    guess = arg_max_if_valid(scores, is_valid, O)
+    Z = 1e-10
+    gZ = 1e-10
+    max_ = scores[guess]
+    gmax = scores[best]
+    for i in range(O):
+        if is_valid[i]:
+            Z += exp(scores[i] - max_)
+            if costs[i] <= 0:
+                gZ += exp(scores[i] - gmax)
+    for i in range(O):
+        if not is_valid[i]:
+            d_scores[i] = 0.
+        elif costs[i] <= 0:
+            d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
+        else:
+            d_scores[i] = exp(scores[i]-max_) / Z

 def init_states(TransitionSystem moves, docs):
@@ -163,6 +208,18 @@ def init_states(TransitionSystem moves, docs):
     return states, offsets

+def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0):
+    cdef StateClass state
+    cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
+    ids = numpy.zeros((len(states), n_tokens), dtype='i')
+    if offsets is None:
+        offsets = [0] * len(states)
+    for i, (state, offset) in enumerate(zip(states, offsets)):
+        state.set_context_tokens(ids[i], nF, nB, nS, nL, nR)
+        ids[i] += (ids[i] >= 0) * offset
+    return ids
+

 cdef class Parser:
     """
     Base class of the DependencyParser and EntityRecognizer.
@@ -218,11 +275,19 @@ cdef class Parser:
     def __reduce__(self):
         return (Parser, (self.vocab, self.moves, self.model), None, None)

-    def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
+    def build_model(self,
+            hidden_width=128, token_vector_width=96, nr_vector=1000,
+            nF=1, nB=1, nS=1, nL=1, nR=1, **cfg):
         nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
-        upper = chain(Maxout(width, width), Maxout(self.moves.n_moves, width))
-        lower = PrecomputableMaxouts(width, nF=nr_context_tokens, nI=width)
+        with Model.use_device('cpu'):
+            upper = chain(
+                Maxout(token_vector_width),
+                zero_init(Affine(self.moves.n_moves, token_vector_width)))
+        assert isinstance(upper.ops, NumpyOps)
+        lower = PrecomputableMaxouts(token_vector_width, nF=nr_context_tokens, nI=token_vector_width,
+                                     pieces=cfg.get('maxout_pieces', 1))
+        upper.begin_training(upper.ops.allocate((500, token_vector_width)))
+        lower.begin_training(lower.ops.allocate((500, token_vector_width)))
         return upper, lower

     def __call__(self, Doc tokens):
@@ -248,12 +313,6 @@ cdef class Parser:
             The number of threads with which to work on the buffer in parallel.
         Yields (Doc): Documents, in order.
         """
-        cdef Pool mem = Pool()
-        cdef int* lengths = <int*>mem.alloc(batch_size, sizeof(int))
-        cdef Doc doc
-        cdef int i
-        cdef int nr_feat = self.model.nr_feat
-        cdef int status
         queue = []
         for doc in stream:
             queue.append(doc)
@@ -269,54 +328,110 @@ cdef class Parser:
                 self.moves.finalize_doc(doc)
                 yield doc

-    def parse_batch(self, docs):
-        cdef Doc doc
-        cdef StateClass state
-        model = get_greedy_model_for_batch([d.tensor for d in docs],
-            self.moves, self.model, self.feature_maps)
+    def parse_batch(self, docs_tokvecs):
+        cdef:
+            int nC
+            Doc doc
+            StateClass state
+            np.ndarray py_scores
+            int[500] is_valid # Hacks for now
+        cuda_stream = Stream()
+        docs, tokvecs = docs_tokvecs
+        lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps,
+            cuda_stream)
+        upper_model = self.model
         states, offsets = init_states(self.moves, docs)
         all_states = list(states)
-        todo = list(zip(states, offsets))
+        todo = [st for st in zip(states, offsets) if not st[0].py_is_final()]
         while todo:
-            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
-            if not todo:
-                break
             states, offsets = zip(*todo)
-            scores = model((states, offsets))
-            transition_batch(self.moves, states, scores)
+            token_ids = extract_token_ids(states, offsets=offsets)
+            py_scores = upper_model(lower_model(token_ids)[0])
+            scores = <float*>py_scores.data
+            nC = py_scores.shape[1]
+            for state, offset in zip(states, offsets):
+                self.moves.set_valid(is_valid, state.c)
+                guess = arg_max_if_valid(scores, is_valid, nC)
+                action = self.moves.c[guess]
+                action.do(state.c, action.label)
+                scores += nC
             todo = [st for st in todo if not st[0].py_is_final()]
         for state, doc in zip(all_states, docs):
             self.moves.finalize_state(state.c)
             for i in range(doc.length):
                 doc.c[i] = state.c._sent[i]
-            self.moves.finalize_doc(doc)
+        for doc in docs:
+            self.moves.finalize_doc(doc)

-    def update(self, docs, golds, drop=0., sgd=None):
+    def update(self, docs_tokvecs, golds, drop=0., sgd=None):
+        cdef:
+            int nC
+            int[500] is_valid # Hack for now
+            Doc doc
+            StateClass state
+            np.ndarray scores
+        docs, tokvecs = docs_tokvecs
+        cuda_stream = Stream()
         if isinstance(docs, Doc) and isinstance(golds, GoldParse):
-            return self.update([docs], [golds], drop=drop)
+            return self.update(([docs], tokvecs), [golds], drop=drop)
         for gold in golds:
             self.moves.preprocess_gold(gold)
-        model = get_greedy_model_for_batch([d.tensor for d in docs],
-            self.moves, self.model, self.feature_maps)
         states, offsets = init_states(self.moves, docs)
-        d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
-        output = list(d_tokens)
-        todo = zip(states, offsets, golds, d_tokens)
+        todo = zip(states, offsets, golds)
+        todo = filter(lambda sp: not sp[0].py_is_final(), todo)
+
+        lower_model = get_greedy_model_for_batch(len(todo),
+            tokvecs, self.feature_maps, cuda_stream=cuda_stream)
+        upper_model = self.model
+        d_tokens = self.feature_maps.ops.allocate(tokvecs.shape)
+        backprops = []
+        n_tokens = tokvecs.shape[0]
+        nF = self.feature_maps.nF
         while todo:
-            # Get unfinished states (and their matching gold and token gradients)
-            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
-            if not todo:
-                break
-            states, offsets, golds, d_tokens = zip(*todo)
-            scores, finish_update = model.begin_update((states, offsets))
-            (token_ids, d_state_features) = finish_update(golds, sgd=sgd)
-            for i, token_ids in enumerate(token_ids):
-                d_tokens[i][token_ids] += d_state_features[i]
-            transition_batch(self.moves, states, scores)
-        return output
+            states, offsets, golds = zip(*todo)
+            token_ids = extract_token_ids(states, offsets=offsets)
+            lower, bp_lower = lower_model(token_ids)
+            scores, bp_scores = upper_model.begin_update(lower)
+
+            d_scores = get_batch_loss(self.moves, states, golds, scores)
+            d_lower = bp_scores(d_scores, sgd=sgd)
+
+            gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i')
+            gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f')
+            gpu_tok_ids.set(token_ids, stream=cuda_stream)
+            gpu_d_lower.set(d_lower, stream=cuda_stream)
+            backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower))
+
+            c_scores = <float*>scores.data
+            for state in states:
+                self.moves.set_valid(is_valid, state.c)
+                guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
+                action = self.moves.c[guess]
+                action.do(state.c, action.label)
+                c_scores += scores.shape[1]
+
+            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
+        # This tells CUDA to block --- so we know our copies are complete.
+        cuda_stream.synchronize()
+        for token_ids, d_lower, bp_lower in backprops:
+            d_state_features = bp_lower(d_lower, sgd=sgd)
+            active_feats = token_ids * (token_ids >= 0)
+            active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1))
+            if hasattr(self.feature_maps.ops.xp, 'scatter_add'):
+                self.feature_maps.ops.xp.scatter_add(d_tokens,
+                    token_ids, d_state_features * active_feats)
+            else:
+                self.model.ops.xp.add.at(d_tokens,
+                    token_ids, d_state_features * active_feats)
+        return d_tokens

     def step_through(self, Doc doc, GoldParse gold=None):
         """
@@ -464,3 +579,39 @@ class ParserStateError(ValueError):
             "https://github.com/spacy-io/spaCy/issues/429\n"
             "Please include the text that the parser failed on, which is:\n"
             "%s" % repr(doc.text))

+cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, const int* is_valid, int n) nogil:
+    # Find minimum cost
+    cdef float cost = 1
+    for i in range(n):
+        if is_valid[i] and costs[i] < cost:
+            cost = costs[i]
+    # Now find best-scoring with that cost
+    cdef int best = -1
+    for i in range(n):
+        if costs[i] <= cost and is_valid[i]:
+            if best == -1 or scores[i] > scores[best]:
+                best = i
+    return best
+
+cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
+    cdef int best = -1
+    for i in range(n):
+        if is_valid[i] >= 1:
+            if best == -1 or scores[i] > scores[best]:
+                best = i
+    return best
+
+cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
+                       int nr_class) except -1:
+    cdef weight_t score = 0
+    cdef int mode = -1
+    cdef int i
+    for i in range(nr_class):
+        if actions[i].move == move and (mode == -1 or scores[i] >= score):
+            mode = i
+            score = scores[i]
+    return mode


@@ -48,7 +48,7 @@ cdef class StateClass:
     @classmethod
     def nr_context_tokens(cls, int nF, int nB, int nS, int nL, int nR):
-        return 11
+        return 13

     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
                            nL=2, nR=2):
@@ -56,14 +56,16 @@ cdef class StateClass:
         output[1] = self.B(1)
         output[2] = self.S(0)
         output[3] = self.S(1)
-        output[4] = self.L(self.S(0), 1)
-        output[5] = self.L(self.S(0), 2)
-        output[6] = self.R(self.S(0), 1)
-        output[7] = self.R(self.S(0), 2)
-        output[7] = self.L(self.S(1), 1)
-        output[8] = self.L(self.S(1), 2)
-        output[9] = self.R(self.S(1), 1)
-        output[10] = self.R(self.S(1), 2)
+        output[4] = self.S(2)
+        output[5] = self.L(self.S(0), 1)
+        output[6] = self.L(self.S(0), 2)
+        output[6] = self.R(self.S(0), 1)
+        output[7] = self.L(self.B(0), 1)
+        output[8] = self.R(self.S(0), 2)
+        output[9] = self.L(self.S(1), 1)
+        output[10] = self.L(self.S(1), 2)
+        output[11] = self.R(self.S(1), 1)
+        output[12] = self.R(self.S(1), 2)

     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
         cdef int i, j, tok_i