mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Update draft of parser neural network model
Model is good, but code is messy. Currently requires Chainer, which may cause the build to fail on machines without a GPU. Outline of the model: We first predict context-sensitive vectors for each word in the input: (embed_lower | embed_prefix | embed_suffix | embed_shape) >> Maxout(token_width) >> convolution ** 4 This convolutional layer is shared between the tagger and the parser. This prevents the parser from needing tag features. To boost the representation, we make a "super tag" with POS, morphology and dependency label. The tagger predicts this by adding a softmax layer onto the convolutional layer --- so, we're teaching the convolutional layer to give us a representation that's one affine transform from this informative lexical information. This is obviously good for the parser (which backprops to the convolutions too). The parser model makes a state vector by concatenating the vector representations for its context tokens. Current results suggest few context tokens works well. Maybe this is a bug. The current context tokens: * S0, S1, S2: Top three words on the stack * B0, B1: First two words of the buffer * S0L1, S0L2: Leftmost and second leftmost children of S0 * S0R1, S0R2: Rightmost and second rightmost children of S0 * S1L1, S1L2, S1R2, S1R, B0L1, B0L2: Likewise for S1 and B0 This makes the state vector quite long: 13*T, where T is the token vector width (128 is working well). Fortunately, there's a way to structure the computation to save some expense (and make it more GPU friendly). The parser typically visits 2*N states for a sentence of length N (although it may visit more, if it back-tracks with a non-monotonic transition). A naive implementation would require 2*N (B, 13*T) @ (13*T, H) matrix multiplications for a batch of size B. We can instead perform one (B*N, T) @ (T, 13*H) multiplication, to pre-compute the hidden weights for each positional feature wrt the words in the batch. (Note that our token vectors come from the CNN -- so we can't play this trick over the vocabulary. That's how Stanford's NN parser works --- and why its model is so big.) This pre-computation strategy allows a nice compromise between GPU-friendliness and implementation simplicity. The CNN and the wide lower layer are computed on the GPU, and then the precomputed hidden weights are moved to the CPU, before we start the transition-based parsing process. This makes a lot of things much easier. We don't have to worry about variable-length batch sizes, and we don't have to implement the dynamic oracle in CUDA to train. Currently the parser's loss function is multilabel log loss, as the dynamic oracle allows multiple states to be 0 cost. This is defined as: (exp(score) / Z) - (exp(score) / gZ) Where gZ is the sum of the scores assigned to gold classes. I'm very interested in regressing on the cost directly, but so far this isn't working well. Machinery is in place for beam-search, which has been working well for the linear model. Beam search should benefit greatly from the pre-computation trick.
This commit is contained in:
parent
b44f7e259c
commit
827b5af697
137
spacy/_ml.py
137
spacy/_ml.py
|
@ -1,12 +1,12 @@
|
|||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.neural import Model, ReLu, Maxout, Softmax, Affine
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
from thinc.neural._classes.hash_embed import HashEmbed
|
||||
from thinc.neural.ops import NumpyOps, CupyOps
|
||||
|
||||
from thinc.neural._classes.convolution import ExtractWindow
|
||||
from thinc.neural._classes.static_vectors import StaticVectors
|
||||
from thinc.neural._classes.batchnorm import BatchNorm
|
||||
from thinc.neural._classes.resnet import Residual
|
||||
|
||||
from thinc import describe
|
||||
from thinc.describe import Dimension, Synapses, Biases, Gradient
|
||||
from thinc.neural._classes.affine import _set_dimensions_if_needed
|
||||
|
@ -78,7 +78,7 @@ class PrecomputableAffine(Model):
|
|||
d_b=Gradient("b")
|
||||
)
|
||||
class PrecomputableMaxouts(Model):
|
||||
def __init__(self, nO=None, nI=None, nF=None, pieces=2, **kwargs):
|
||||
def __init__(self, nO=None, nI=None, nF=None, pieces=3, **kwargs):
|
||||
Model.__init__(self, **kwargs)
|
||||
self.nO = nO
|
||||
self.nP = pieces
|
||||
|
@ -87,88 +87,71 @@ class PrecomputableMaxouts(Model):
|
|||
|
||||
def begin_update(self, X, drop=0.):
|
||||
# X: (b, i)
|
||||
# Yfp: (f, b, o, p)
|
||||
# Yf: (f, b, o)
|
||||
# Xf: (b, f, i)
|
||||
# dY: (b, o)
|
||||
# Yfp: (b, f, o, p)
|
||||
# Xf: (f, b, i)
|
||||
# dYp: (b, o, p)
|
||||
# W: (f, o, p, i)
|
||||
# b: (o, p)
|
||||
|
||||
# Yfp = numpy.einsum('bi,fopi->fbop', X, self.W)
|
||||
Yfp = self.ops.xp.tensordot(X, self.W,
|
||||
axes=[[1], [3]]).transpose((1, 0, 2, 3))
|
||||
Yfp = self.ops.xp.ascontiguousarray(Yfp)
|
||||
# bi,opfi->bfop
|
||||
# bop,fopi->bfi
|
||||
# bop,fbi->opfi : fopi
|
||||
|
||||
tensordot = self.ops.xp.tensordot
|
||||
ascontiguous = self.ops.xp.ascontiguousarray
|
||||
|
||||
Yfp = tensordot(X, self.W, axes=[[1], [3]])
|
||||
Yfp += self.b
|
||||
Yf = self.ops.allocate((self.nF, X.shape[0], self.nO))
|
||||
which = self.ops.allocate((self.nF, X.shape[0], self.nO), dtype='i')
|
||||
for i in range(self.nF):
|
||||
Yf[i], which[i] = self.ops.maxout(Yfp[i])
|
||||
def backward(dY_ids, sgd=None):
|
||||
dY, ids = dY_ids
|
||||
|
||||
def backward(dYp_ids, sgd=None):
|
||||
dYp, ids = dYp_ids
|
||||
Xf = X[ids]
|
||||
dYp = self.ops.allocate((dY.shape[0], self.nO, self.nP))
|
||||
for i in range(self.nF):
|
||||
dYp += self.ops.backprop_maxout(dY, which[i], self.nP)
|
||||
|
||||
#dXf = numpy.einsum('bop,fopi->bfi', dYp, self.W)
|
||||
dXf = self.ops.xp.tensordot(dYp, self.W, axes=[[1,2], [1,2]])
|
||||
#dW = numpy.einsum('bfi,bop->fopi', Xf, dYp)
|
||||
dW = self.ops.xp.tensordot(Xf, dYp, axes=[[0], [0]])
|
||||
dW = dW.transpose((0, 2, 3, 1))
|
||||
db = dYp.sum(axis=0)
|
||||
dXf = tensordot(dYp, self.W, axes=[[1, 2], [1,2]])
|
||||
dW = tensordot(dYp, Xf, axes=[[0], [0]])
|
||||
|
||||
self.d_W += dW
|
||||
self.d_b += db
|
||||
self.d_W += dW.transpose((2, 0, 1, 3))
|
||||
self.d_b += dYp.sum(axis=0)
|
||||
|
||||
if sgd is not None:
|
||||
sgd(self._mem.weights, self._mem.gradient, key=self.id)
|
||||
return dXf
|
||||
return Yf, backward
|
||||
return Yfp, backward
|
||||
|
||||
|
||||
def get_col(idx):
|
||||
def forward(X, drop=0.):
|
||||
if isinstance(X, numpy.ndarray):
|
||||
ops = NumpyOps()
|
||||
else:
|
||||
ops = CupyOps()
|
||||
assert len(X.shape) <= 3
|
||||
output = Model.ops.xp.ascontiguousarray(X[:, idx])
|
||||
output = ops.xp.ascontiguousarray(X[:, idx])
|
||||
def backward(y, sgd=None):
|
||||
dX = Model.ops.allocate(X.shape)
|
||||
dX = ops.allocate(X.shape)
|
||||
dX[:, idx] += y
|
||||
return dX
|
||||
return output, backward
|
||||
return layerize(forward)
|
||||
|
||||
|
||||
def build_tok2vec(lang, width, depth=2, embed_size=1000):
|
||||
def zero_init(model):
|
||||
def _hook(self, X, y=None):
|
||||
self.W.fill(0)
|
||||
model.on_data_hooks.append(_hook)
|
||||
return model
|
||||
|
||||
|
||||
def doc2feats(cols=None):
|
||||
cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
|
||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
||||
#static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
|
||||
lower = get_col(cols.index(LOWER)) >> HashEmbed(width, embed_size)
|
||||
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width//4, embed_size)
|
||||
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width//4, embed_size)
|
||||
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width//4, embed_size)
|
||||
tok2vec = (
|
||||
doc2feats(cols)
|
||||
>> with_flatten(
|
||||
#(static | prefix | suffix | shape)
|
||||
(lower | prefix | suffix | shape)
|
||||
>> Maxout(width)
|
||||
>> (ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||
>> (ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||
)
|
||||
)
|
||||
return tok2vec
|
||||
|
||||
|
||||
def doc2feats(cols):
|
||||
def forward(docs, drop=0.):
|
||||
feats = [doc.to_array(cols) for doc in docs]
|
||||
feats = [model.ops.asarray(f, dtype='uint64') for f in feats]
|
||||
return feats, None
|
||||
model = layerize(forward)
|
||||
model.cols = cols
|
||||
return model
|
||||
|
||||
|
||||
def print_shape(prefix):
|
||||
def forward(X, drop=0.):
|
||||
return X, lambda dX, **kwargs: dX
|
||||
|
@ -186,52 +169,12 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
|
|||
|
||||
@layerize
|
||||
def flatten(seqs, drop=0.):
|
||||
ops = Model.ops
|
||||
if isinstance(seqs[0], numpy.ndarray):
|
||||
ops = NumpyOps()
|
||||
else:
|
||||
ops = CupyOps()
|
||||
lengths = [len(seq) for seq in seqs]
|
||||
def finish_update(d_X, sgd=None):
|
||||
return ops.unflatten(d_X, lengths)
|
||||
X = ops.xp.vstack(seqs)
|
||||
return X, finish_update
|
||||
|
||||
|
||||
#def build_feature_precomputer(model, feat_maps):
|
||||
# '''Allow a model to be "primed" by pre-computing input features in bulk.
|
||||
#
|
||||
# This is used for the parser, where we want to take a batch of documents,
|
||||
# and compute vectors for each (token, position) pair. These vectors can then
|
||||
# be reused, especially for beam-search.
|
||||
#
|
||||
# Let's say we're using 12 features for each state, e.g. word at start of
|
||||
# buffer, three words on stack, their children, etc. In the normal arc-eager
|
||||
# system, a document of length N is processed in 2*N states. This means we'll
|
||||
# create 2*N*12 feature vectors --- but if we pre-compute, we only need
|
||||
# N*12 vector computations. The saving for beam-search is much better:
|
||||
# if we have a beam of k, we'll normally make 2*N*12*K computations --
|
||||
# so we can save the factor k. This also gives a nice CPU/GPU division:
|
||||
# we can do all our hard maths up front, packed into large multiplications,
|
||||
# and do the hard-to-program parsing on the CPU.
|
||||
# '''
|
||||
# def precompute(input_vectors):
|
||||
# cached, backprops = zip(*[lyr.begin_update(input_vectors)
|
||||
# for lyr in feat_maps)
|
||||
# def forward(batch_token_ids, drop=0.):
|
||||
# output = ops.allocate((batch_size, output_width))
|
||||
# # i: batch index
|
||||
# # j: position index (i.e. N0, S0, etc
|
||||
# # tok_i: Index of the token within its document
|
||||
# for i, token_ids in enumerate(batch_token_ids):
|
||||
# for j, tok_i in enumerate(token_ids):
|
||||
# output[i] += cached[j][tok_i]
|
||||
# def backward(d_vector, sgd=None):
|
||||
# d_inputs = ops.allocate((batch_size, n_feat, vec_width))
|
||||
# for i, token_ids in enumerate(batch_token_ids):
|
||||
# for j in range(len(token_ids)):
|
||||
# d_inputs[i][j] = backprops[j](d_vector, sgd)
|
||||
# # Return the IDs, so caller can associate to correct token
|
||||
# return (batch_token_ids, d_inputs)
|
||||
# return vector, backward
|
||||
# return chain(layerize(forward), model)
|
||||
# return precompute
|
||||
#
|
||||
#
|
||||
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from thinc.api import chain, layerize, with_getitem
|
||||
from thinc.neural import Model, Softmax
|
||||
import numpy
|
||||
cimport numpy as np
|
||||
|
||||
from .tokens.doc cimport Doc
|
||||
from .syntax.parser cimport Parser
|
||||
|
@ -11,63 +14,96 @@ from .syntax.parser cimport Parser
|
|||
from .syntax.ner cimport BiluoPushDown
|
||||
from .syntax.arc_eager cimport ArcEager
|
||||
from .tagger import Tagger
|
||||
from ._ml import build_tok2vec, flatten
|
||||
from .gold cimport GoldParse
|
||||
|
||||
from thinc.api import add, layerize, chain, clone, concatenate
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
from thinc.neural._classes.hash_embed import HashEmbed
|
||||
from thinc.neural.util import to_categorical
|
||||
|
||||
from thinc.neural._classes.convolution import ExtractWindow
|
||||
from thinc.neural._classes.resnet import Residual
|
||||
from thinc.neural._classes.batchnorm import BatchNorm as BN
|
||||
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from ._ml import flatten, get_col, doc2feats
|
||||
|
||||
# TODO: The disorganization here is pretty embarrassing. At least it's only
|
||||
# internals.
|
||||
from .syntax.parser import get_templates as get_feature_templates
|
||||
from .attrs import DEP, ENT_TYPE
|
||||
|
||||
|
||||
class TokenVectorEncoder(object):
|
||||
'''Assign position-sensitive vectors to tokens, using a CNN or RNN.'''
|
||||
def __init__(self, vocab, **cfg):
|
||||
def __init__(self, vocab, token_vector_width, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = build_tok2vec(vocab.lang, **cfg)
|
||||
self.doc2feats = doc2feats()
|
||||
self.model = self.build_model(vocab.lang, token_vector_width, **cfg)
|
||||
self.tagger = chain(
|
||||
self.model,
|
||||
flatten,
|
||||
Softmax(self.vocab.morphology.n_tags, 64))
|
||||
Softmax(self.vocab.morphology.n_tags,
|
||||
token_vector_width))
|
||||
|
||||
def build_model(self, lang, width, embed_size=1000, **cfg):
|
||||
cols = self.doc2feats.cols
|
||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
|
||||
lower = get_col(cols.index(LOWER)) >> (HashEmbed(width, embed_size*3)
|
||||
+HashEmbed(width, embed_size*3))
|
||||
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
||||
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
||||
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
||||
|
||||
tok2vec = (
|
||||
flatten
|
||||
>> (lower | prefix | suffix | shape )
|
||||
>> BN(Maxout(width, pieces=3))
|
||||
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
|
||||
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
|
||||
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
|
||||
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
|
||||
)
|
||||
return tok2vec
|
||||
|
||||
def pipe(self, docs):
|
||||
docs = list(docs)
|
||||
self.predict_tags(docs)
|
||||
for doc in docs:
|
||||
yield doc
|
||||
|
||||
def __call__(self, doc):
|
||||
doc.tensor = self.model([doc])[0]
|
||||
self.predict_tags([doc])
|
||||
|
||||
def begin_update(self, docs, drop=0.):
|
||||
tensors, bp_tensors = self.model.begin_update(docs, drop=drop)
|
||||
for i, doc in enumerate(docs):
|
||||
doc.tensor = tensors[i]
|
||||
self.predict_tags(docs)
|
||||
return tensors, bp_tensors
|
||||
def begin_update(self, feats, drop=0.):
|
||||
tokvecs, bp_tokvecs = self.model.begin_update(feats, drop=drop)
|
||||
return tokvecs, bp_tokvecs
|
||||
|
||||
def predict_tags(self, docs, drop=0.):
|
||||
cdef Doc doc
|
||||
scores, _ = self.tagger.begin_update(docs, drop=drop)
|
||||
feats = self.doc2feats(docs)
|
||||
scores, finish_update = self.tagger.begin_update(feats, drop=drop)
|
||||
scores, _ = self.tagger.begin_update(feats, drop=drop)
|
||||
idx = 0
|
||||
guesses = scores.argmax(axis=1).get()
|
||||
for i, doc in enumerate(docs):
|
||||
tag_ids = scores[idx:idx+len(doc)].argmax(axis=1)
|
||||
tag_ids = guesses[idx:idx+len(doc)]
|
||||
for j, tag_id in enumerate(tag_ids):
|
||||
doc.vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
||||
idx += 1
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None):
|
||||
scores, finish_update = self.tagger.begin_update(docs, drop=drop)
|
||||
losses = scores.copy()
|
||||
def update(self, docs_feats, golds, drop=0., sgd=None):
|
||||
cdef int i, j, idx
|
||||
cdef GoldParse gold
|
||||
docs, feats = docs_feats
|
||||
scores, finish_update = self.tagger.begin_update(feats, drop=drop)
|
||||
|
||||
tag_index = {tag: i for i, tag in enumerate(docs[0].vocab.morphology.tag_names)}
|
||||
|
||||
idx = 0
|
||||
for i, gold in enumerate(golds):
|
||||
if hasattr(self.tagger.ops.xp, 'scatter_add'):
|
||||
ids = numpy.zeros((len(gold),), dtype='i')
|
||||
start = idx
|
||||
for j, tag in enumerate(gold.tags):
|
||||
ids[j] = docs[0].vocab.morphology.tag_names.index(tag)
|
||||
idx += 1
|
||||
self.tagger.ops.xp.scatter_add(losses[start:idx], ids, -1.0)
|
||||
else:
|
||||
for j, tag in enumerate(gold.tags):
|
||||
tag_id = docs[0].vocab.morphology.tag_names.index(tag)
|
||||
losses[idx, tag_id] -= 1.
|
||||
idx += 1
|
||||
finish_update(losses, sgd)
|
||||
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
||||
for gold in golds:
|
||||
for tag in gold.tags:
|
||||
correct[idx] = tag_index[tag]
|
||||
idx += 1
|
||||
correct = self.model.ops.xp.array(correct)
|
||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
finish_update(d_scores, sgd)
|
||||
|
||||
|
||||
cdef class EntityRecognizer(Parser):
|
||||
|
@ -76,17 +112,10 @@ cdef class EntityRecognizer(Parser):
|
|||
"""
|
||||
TransitionSystem = BiluoPushDown
|
||||
|
||||
feature_templates = get_feature_templates('ner')
|
||||
|
||||
def add_label(self, label):
|
||||
Parser.add_label(self, label)
|
||||
if isinstance(label, basestring):
|
||||
label = self.vocab.strings[label]
|
||||
# Set label into serializer. Super hacky :(
|
||||
for attr, freqs in self.vocab.serializer_freqs:
|
||||
if attr == ENT_TYPE and label not in freqs:
|
||||
freqs.append([label, 1])
|
||||
self.vocab._serializer = None
|
||||
|
||||
#
|
||||
#cdef class BeamEntityRecognizer(BeamParser):
|
||||
|
@ -111,17 +140,10 @@ cdef class EntityRecognizer(Parser):
|
|||
cdef class DependencyParser(Parser):
|
||||
TransitionSystem = ArcEager
|
||||
|
||||
feature_templates = get_feature_templates('basic')
|
||||
|
||||
def add_label(self, label):
|
||||
Parser.add_label(self, label)
|
||||
if isinstance(label, basestring):
|
||||
label = self.vocab.strings[label]
|
||||
for attr, freqs in self.vocab.serializer_freqs:
|
||||
if attr == DEP and label not in freqs:
|
||||
freqs.append([label, 1])
|
||||
# Super hacky :(
|
||||
self.vocab._serializer = None
|
||||
|
||||
#
|
||||
#cdef class BeamDependencyParser(BeamParser):
|
||||
|
|
|
@ -1,17 +1,23 @@
|
|||
"""
|
||||
MALT-style dependency parser
|
||||
"""
|
||||
# coding: utf-8
|
||||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from collections import Counter
|
||||
import ujson
|
||||
|
||||
from chainer.functions.activation.softmax import Softmax as ChainerSoftmax
|
||||
|
||||
from cupy.cuda.stream import Stream
|
||||
import cupy
|
||||
|
||||
from libc.math cimport exp
|
||||
cimport cython
|
||||
cimport cython.parallel
|
||||
import cytoolz
|
||||
|
||||
import numpy.random
|
||||
cimport numpy as np
|
||||
|
||||
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
|
||||
from cpython.exc cimport PyErr_CheckSignals
|
||||
|
@ -28,11 +34,11 @@ from murmurhash.mrmr cimport hash64
|
|||
from preshed.maps cimport MapStruct
|
||||
from preshed.maps cimport map_get
|
||||
|
||||
|
||||
from thinc.api import layerize, chain
|
||||
from thinc.neural import Model, Maxout
|
||||
from thinc.neural import Affine, Model, Maxout
|
||||
from thinc.neural.ops import NumpyOps
|
||||
|
||||
from .._ml import PrecomputableAffine, PrecomputableMaxouts
|
||||
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
|
||||
from . import _parse_features
|
||||
from ._parse_features cimport CONTEXT_SIZE
|
||||
from ._parse_features cimport fill_context
|
||||
|
@ -58,94 +64,133 @@ def set_debug(val):
|
|||
DEBUG = val
|
||||
|
||||
|
||||
def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, lower_model):
|
||||
cdef int[:, :] is_valid_
|
||||
cdef float[:, :] costs_
|
||||
lengths = [len(t) for t in tokvecs]
|
||||
tokvecs = upper_model.ops.flatten(tokvecs)
|
||||
is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i')
|
||||
costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f')
|
||||
token_ids = upper_model.ops.allocate((len(tokvecs), lower_model.nF), dtype='i')
|
||||
def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None):
|
||||
'''Allow a model to be "primed" by pre-computing input features in bulk.
|
||||
|
||||
cached, bp_features = lower_model.begin_update(tokvecs, drop=0.)
|
||||
This is used for the parser, where we want to take a batch of documents,
|
||||
and compute vectors for each (token, position) pair. These vectors can then
|
||||
be reused, especially for beam-search.
|
||||
|
||||
is_valid_ = is_valid
|
||||
costs_ = costs
|
||||
Let's say we're using 12 features for each state, e.g. word at start of
|
||||
buffer, three words on stack, their children, etc. In the normal arc-eager
|
||||
system, a document of length N is processed in 2*N states. This means we'll
|
||||
create 2*N*12 feature vectors --- but if we pre-compute, we only need
|
||||
N*12 vector computations. The saving for beam-search is much better:
|
||||
if we have a beam of k, we'll normally make 2*N*12*K computations --
|
||||
so we can save the factor k. This also gives a nice CPU/GPU division:
|
||||
we can do all our hard maths up front, packed into large multiplications,
|
||||
and do the hard-to-program parsing on the CPU.
|
||||
'''
|
||||
gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=0.)
|
||||
cdef np.ndarray cached
|
||||
if not isinstance(gpu_cached, numpy.ndarray):
|
||||
cached = gpu_cached.get(stream=cuda_stream)
|
||||
else:
|
||||
cached = gpu_cached
|
||||
nF = gpu_cached.shape[1]
|
||||
nP = gpu_cached.shape[3]
|
||||
ops = lower_model.ops
|
||||
features = numpy.zeros((batch_size, cached.shape[2], nP), dtype='f')
|
||||
synchronized = False
|
||||
|
||||
def forward(states_offsets, drop=0.):
|
||||
nonlocal is_valid, costs, token_ids, moves
|
||||
states, offsets = states_offsets
|
||||
assert len(states) != 0
|
||||
is_valid = is_valid[:len(states)]
|
||||
costs = costs[:len(states)]
|
||||
token_ids = token_ids[:len(states)]
|
||||
is_valid = is_valid[:len(states)]
|
||||
cdef StateClass state
|
||||
cdef int i
|
||||
for i, (offset, state) in enumerate(zip(offsets, states)):
|
||||
state.set_context_tokens(token_ids[i])
|
||||
moves.set_valid(&is_valid_[i, 0], state.c)
|
||||
adjusted_ids = token_ids.copy()
|
||||
for i, offset in enumerate(offsets):
|
||||
adjusted_ids[i] *= token_ids[i] >= 0
|
||||
adjusted_ids[i] += offset
|
||||
features = upper_model.ops.allocate((len(states), lower_model.nO), dtype='f')
|
||||
for i in range(len(states)):
|
||||
for j, tok_i in enumerate(adjusted_ids[i]):
|
||||
if tok_i >= 0:
|
||||
features[i] += cached[j, tok_i]
|
||||
def forward(token_ids, drop=0.):
|
||||
nonlocal synchronized
|
||||
if not synchronized and cuda_stream is not None:
|
||||
cuda_stream.synchronize()
|
||||
synchronized = True
|
||||
# This is tricky, but:
|
||||
# - Input to forward on CPU
|
||||
# - Output from forward on CPU
|
||||
# - Input to backward on GPU!
|
||||
# - Output from backward on GPU
|
||||
nonlocal features
|
||||
features = features[:len(token_ids)]
|
||||
features.fill(0)
|
||||
cdef float[:, :, ::1] feats = features
|
||||
cdef int[:, ::1] ids = token_ids
|
||||
_sum_features(<float*>&feats[0,0,0],
|
||||
<float*>cached.data, &ids[0,0],
|
||||
token_ids.shape[0], nF, cached.shape[2]*nP)
|
||||
|
||||
scores, bp_scores = upper_model.begin_update(features, drop=drop)
|
||||
scores = upper_model.ops.relu(scores)
|
||||
softmaxed = upper_model.ops.softmax(scores)
|
||||
# Renormalize for invalid actions
|
||||
softmaxed *= is_valid
|
||||
totals = softmaxed.sum(axis=1)
|
||||
for total in totals:
|
||||
assert total > 0, (totals, scores, softmaxed)
|
||||
assert total <= 1.1, totals
|
||||
softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1))
|
||||
if nP >= 2:
|
||||
best, which = ops.maxout(features)
|
||||
else:
|
||||
best = features.reshape((features.shape[0], features.shape[1]))
|
||||
which = None
|
||||
|
||||
def backward(golds, sgd=None):
|
||||
nonlocal costs_, is_valid_, moves
|
||||
cdef int i
|
||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||
moves.set_costs(&is_valid_[i, 0], &costs_[i, 0],
|
||||
state, gold)
|
||||
d_scores = scores.copy()
|
||||
d_scores.fill(0)
|
||||
set_log_loss(upper_model.ops, d_scores,
|
||||
scores, is_valid, costs)
|
||||
upper_model.ops.backprop_relu(d_scores, scores, inplace=True)
|
||||
d_features = bp_scores(d_scores, sgd)
|
||||
d_tokens = bp_features((d_features, adjusted_ids), sgd)
|
||||
return (token_ids, d_tokens)
|
||||
def backward(d_best, sgd=None):
|
||||
# This will usually be on GPU
|
||||
if isinstance(d_best, numpy.ndarray):
|
||||
d_best = ops.xp.array(d_best)
|
||||
if nP >= 2:
|
||||
d_features = ops.backprop_maxout(d_best, which, nP)
|
||||
else:
|
||||
d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1))
|
||||
d_tokens = bp_features((d_features, token_ids), sgd)
|
||||
return d_tokens
|
||||
|
||||
return softmaxed, backward
|
||||
return best, backward
|
||||
|
||||
return layerize(forward)
|
||||
return forward
|
||||
|
||||
|
||||
def set_log_loss(ops, gradients, scores, is_valid, costs):
|
||||
"""Do multi-label log loss"""
|
||||
n = gradients.shape[0]
|
||||
scores = scores * is_valid
|
||||
g_scores = scores * is_valid * (costs <= 0.)
|
||||
exps = ops.xp.exp(scores - scores.max(axis=1).reshape((n, 1)))
|
||||
exps *= is_valid
|
||||
g_exps = ops.xp.exp(g_scores - g_scores.max(axis=1).reshape((n, 1)))
|
||||
g_exps *= costs <= 0.
|
||||
g_exps *= is_valid
|
||||
gradients[:] = exps / exps.sum(axis=1).reshape((n, 1))
|
||||
gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1))
|
||||
cdef void _sum_features(float* output,
|
||||
const float* cached, const int* token_ids, int B, int F, int O) nogil:
|
||||
cdef int idx, b, f, i
|
||||
cdef const float* feature
|
||||
for b in range(B):
|
||||
for f in range(F):
|
||||
if token_ids[f] < 0:
|
||||
continue
|
||||
idx = token_ids[f] * F * O + f*O
|
||||
feature = &cached[idx]
|
||||
for i in range(O):
|
||||
output[i] += feature[i]
|
||||
output += O
|
||||
token_ids += F
|
||||
|
||||
|
||||
def transition_batch(TransitionSystem moves, states, scores):
|
||||
def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores):
|
||||
cdef StateClass state
|
||||
cdef int guess
|
||||
for state, guess in zip(states, scores.argmax(axis=1)):
|
||||
action = moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
cdef GoldParse gold
|
||||
cdef Pool mem = Pool()
|
||||
cdef int i
|
||||
is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
|
||||
costs = <float*>mem.alloc(moves.n_moves, sizeof(float))
|
||||
cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f')
|
||||
c_d_scores = <float*>d_scores.data
|
||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||
memset(is_valid, 0, moves.n_moves * sizeof(int))
|
||||
memset(costs, 0, moves.n_moves * sizeof(float))
|
||||
moves.set_costs(is_valid, costs, state, gold)
|
||||
cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||
c_d_scores += d_scores.shape[1]
|
||||
return d_scores
|
||||
|
||||
|
||||
cdef void cpu_log_loss(float* d_scores,
|
||||
const float* costs, const int* is_valid, const float* scores,
|
||||
int O) nogil:
|
||||
"""Do multi-label log loss"""
|
||||
cdef double max_, gmax, Z, gZ
|
||||
best = arg_max_if_gold(scores, costs, is_valid, O)
|
||||
guess = arg_max_if_valid(scores, is_valid, O)
|
||||
Z = 1e-10
|
||||
gZ = 1e-10
|
||||
max_ = scores[guess]
|
||||
gmax = scores[best]
|
||||
for i in range(O):
|
||||
if is_valid[i]:
|
||||
Z += exp(scores[i] - max_)
|
||||
if costs[i] <= 0:
|
||||
gZ += exp(scores[i] - gmax)
|
||||
for i in range(O):
|
||||
if not is_valid[i]:
|
||||
d_scores[i] = 0.
|
||||
elif costs[i] <= 0:
|
||||
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
|
||||
else:
|
||||
d_scores[i] = exp(scores[i]-max_) / Z
|
||||
|
||||
|
||||
def init_states(TransitionSystem moves, docs):
|
||||
|
@ -163,6 +208,18 @@ def init_states(TransitionSystem moves, docs):
|
|||
return states, offsets
|
||||
|
||||
|
||||
def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0):
|
||||
cdef StateClass state
|
||||
cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
|
||||
ids = numpy.zeros((len(states), n_tokens), dtype='i')
|
||||
if offsets is None:
|
||||
offsets = [0] * len(states)
|
||||
for i, (state, offset) in enumerate(zip(states, offsets)):
|
||||
state.set_context_tokens(ids[i], nF, nB, nS, nL, nR)
|
||||
ids[i] += (ids[i] >= 0) * offset
|
||||
return ids
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
"""
|
||||
Base class of the DependencyParser and EntityRecognizer.
|
||||
|
@ -218,11 +275,19 @@ cdef class Parser:
|
|||
def __reduce__(self):
|
||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
||||
def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
|
||||
def build_model(self,
|
||||
hidden_width=128, token_vector_width=96, nr_vector=1000,
|
||||
nF=1, nB=1, nS=1, nL=1, nR=1, **cfg):
|
||||
nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
|
||||
|
||||
upper = chain(Maxout(width, width), Maxout(self.moves.n_moves, width))
|
||||
lower = PrecomputableMaxouts(width, nF=nr_context_tokens, nI=width)
|
||||
with Model.use_device('cpu'):
|
||||
upper = chain(
|
||||
Maxout(token_vector_width),
|
||||
zero_init(Affine(self.moves.n_moves, token_vector_width)))
|
||||
assert isinstance(upper.ops, NumpyOps)
|
||||
lower = PrecomputableMaxouts(token_vector_width, nF=nr_context_tokens, nI=token_vector_width,
|
||||
pieces=cfg.get('maxout_pieces', 1))
|
||||
upper.begin_training(upper.ops.allocate((500, token_vector_width)))
|
||||
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
||||
return upper, lower
|
||||
|
||||
def __call__(self, Doc tokens):
|
||||
|
@ -248,12 +313,6 @@ cdef class Parser:
|
|||
The number of threads with which to work on the buffer in parallel.
|
||||
Yields (Doc): Documents, in order.
|
||||
"""
|
||||
cdef Pool mem = Pool()
|
||||
cdef int* lengths = <int*>mem.alloc(batch_size, sizeof(int))
|
||||
cdef Doc doc
|
||||
cdef int i
|
||||
cdef int nr_feat = self.model.nr_feat
|
||||
cdef int status
|
||||
queue = []
|
||||
for doc in stream:
|
||||
queue.append(doc)
|
||||
|
@ -269,54 +328,110 @@ cdef class Parser:
|
|||
self.moves.finalize_doc(doc)
|
||||
yield doc
|
||||
|
||||
def parse_batch(self, docs):
|
||||
cdef Doc doc
|
||||
cdef StateClass state
|
||||
model = get_greedy_model_for_batch([d.tensor for d in docs],
|
||||
self.moves, self.model, self.feature_maps)
|
||||
def parse_batch(self, docs_tokvecs):
|
||||
cdef:
|
||||
int nC
|
||||
Doc doc
|
||||
StateClass state
|
||||
np.ndarray py_scores
|
||||
int[500] is_valid # Hacks for now
|
||||
|
||||
cuda_stream = Stream()
|
||||
docs, tokvecs = docs_tokvecs
|
||||
lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps,
|
||||
cuda_stream)
|
||||
upper_model = self.model
|
||||
|
||||
states, offsets = init_states(self.moves, docs)
|
||||
all_states = list(states)
|
||||
todo = list(zip(states, offsets))
|
||||
todo = [st for st in zip(states, offsets) if not st[0].py_is_final()]
|
||||
|
||||
while todo:
|
||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||
if not todo:
|
||||
break
|
||||
states, offsets = zip(*todo)
|
||||
scores = model((states, offsets))
|
||||
transition_batch(self.moves, states, scores)
|
||||
token_ids = extract_token_ids(states, offsets=offsets)
|
||||
|
||||
py_scores = upper_model(lower_model(token_ids)[0])
|
||||
scores = <float*>py_scores.data
|
||||
nC = py_scores.shape[1]
|
||||
for state, offset in zip(states, offsets):
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
guess = arg_max_if_valid(scores, is_valid, nC)
|
||||
action = self.moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
scores += nC
|
||||
todo = [st for st in todo if not st[0].py_is_final()]
|
||||
|
||||
for state, doc in zip(all_states, docs):
|
||||
self.moves.finalize_state(state.c)
|
||||
for i in range(doc.length):
|
||||
doc.c[i] = state.c._sent[i]
|
||||
for doc in docs:
|
||||
self.moves.finalize_doc(doc)
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None):
|
||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None):
|
||||
cdef:
|
||||
int nC
|
||||
int[500] is_valid # Hack for now
|
||||
Doc doc
|
||||
StateClass state
|
||||
np.ndarray scores
|
||||
|
||||
docs, tokvecs = docs_tokvecs
|
||||
cuda_stream = Stream()
|
||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||
return self.update([docs], [golds], drop=drop)
|
||||
return self.update(([docs], tokvecs), [golds], drop=drop)
|
||||
for gold in golds:
|
||||
self.moves.preprocess_gold(gold)
|
||||
|
||||
model = get_greedy_model_for_batch([d.tensor for d in docs],
|
||||
self.moves, self.model, self.feature_maps)
|
||||
states, offsets = init_states(self.moves, docs)
|
||||
|
||||
d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
|
||||
output = list(d_tokens)
|
||||
todo = zip(states, offsets, golds, d_tokens)
|
||||
todo = zip(states, offsets, golds)
|
||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||
|
||||
lower_model = get_greedy_model_for_batch(len(todo),
|
||||
tokvecs, self.feature_maps, cuda_stream=cuda_stream)
|
||||
upper_model = self.model
|
||||
d_tokens = self.feature_maps.ops.allocate(tokvecs.shape)
|
||||
backprops = []
|
||||
n_tokens = tokvecs.shape[0]
|
||||
nF = self.feature_maps.nF
|
||||
while todo:
|
||||
# Get unfinished states (and their matching gold and token gradients)
|
||||
states, offsets, golds = zip(*todo)
|
||||
|
||||
token_ids = extract_token_ids(states, offsets=offsets)
|
||||
lower, bp_lower = lower_model(token_ids)
|
||||
scores, bp_scores = upper_model.begin_update(lower)
|
||||
|
||||
d_scores = get_batch_loss(self.moves, states, golds, scores)
|
||||
d_lower = bp_scores(d_scores, sgd=sgd)
|
||||
|
||||
gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i')
|
||||
gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f')
|
||||
gpu_tok_ids.set(token_ids, stream=cuda_stream)
|
||||
gpu_d_lower.set(d_lower, stream=cuda_stream)
|
||||
backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower))
|
||||
|
||||
c_scores = <float*>scores.data
|
||||
for state in states:
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
|
||||
action = self.moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
c_scores += scores.shape[1]
|
||||
|
||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||
if not todo:
|
||||
break
|
||||
states, offsets, golds, d_tokens = zip(*todo)
|
||||
scores, finish_update = model.begin_update((states, offsets))
|
||||
(token_ids, d_state_features) = finish_update(golds, sgd=sgd)
|
||||
for i, token_ids in enumerate(token_ids):
|
||||
d_tokens[i][token_ids] += d_state_features[i]
|
||||
transition_batch(self.moves, states, scores)
|
||||
return output
|
||||
# This tells CUDA to block --- so we know our copies are complete.
|
||||
cuda_stream.synchronize()
|
||||
for token_ids, d_lower, bp_lower in backprops:
|
||||
d_state_features = bp_lower(d_lower, sgd=sgd)
|
||||
active_feats = token_ids * (token_ids >= 0)
|
||||
active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1))
|
||||
if hasattr(self.feature_maps.ops.xp, 'scatter_add'):
|
||||
self.feature_maps.ops.xp.scatter_add(d_tokens,
|
||||
token_ids, d_state_features * active_feats)
|
||||
else:
|
||||
self.model.ops.xp.add.at(d_tokens,
|
||||
token_ids, d_state_features * active_feats)
|
||||
return d_tokens
|
||||
|
||||
def step_through(self, Doc doc, GoldParse gold=None):
|
||||
"""
|
||||
|
@ -464,3 +579,39 @@ class ParserStateError(ValueError):
|
|||
"https://github.com/spacy-io/spaCy/issues/429\n"
|
||||
"Please include the text that the parser failed on, which is:\n"
|
||||
"%s" % repr(doc.text))
|
||||
|
||||
|
||||
cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, const int* is_valid, int n) nogil:
|
||||
# Find minimum cost
|
||||
cdef float cost = 1
|
||||
for i in range(n):
|
||||
if is_valid[i] and costs[i] < cost:
|
||||
cost = costs[i]
|
||||
# Now find best-scoring with that cost
|
||||
cdef int best = -1
|
||||
for i in range(n):
|
||||
if costs[i] <= cost and is_valid[i]:
|
||||
if best == -1 or scores[i] > scores[best]:
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
|
||||
cdef int best = -1
|
||||
for i in range(n):
|
||||
if is_valid[i] >= 1:
|
||||
if best == -1 or scores[i] > scores[best]:
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
|
||||
int nr_class) except -1:
|
||||
cdef weight_t score = 0
|
||||
cdef int mode = -1
|
||||
cdef int i
|
||||
for i in range(nr_class):
|
||||
if actions[i].move == move and (mode == -1 or scores[i] >= score):
|
||||
mode = i
|
||||
score = scores[i]
|
||||
return mode
|
||||
|
|
|
@ -48,7 +48,7 @@ cdef class StateClass:
|
|||
|
||||
@classmethod
|
||||
def nr_context_tokens(cls, int nF, int nB, int nS, int nL, int nR):
|
||||
return 11
|
||||
return 13
|
||||
|
||||
def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
|
||||
nL=2, nR=2):
|
||||
|
@ -56,14 +56,16 @@ cdef class StateClass:
|
|||
output[1] = self.B(1)
|
||||
output[2] = self.S(0)
|
||||
output[3] = self.S(1)
|
||||
output[4] = self.L(self.S(0), 1)
|
||||
output[5] = self.L(self.S(0), 2)
|
||||
output[4] = self.S(2)
|
||||
output[5] = self.L(self.S(0), 1)
|
||||
output[6] = self.L(self.S(0), 2)
|
||||
output[6] = self.R(self.S(0), 1)
|
||||
output[7] = self.R(self.S(0), 2)
|
||||
output[7] = self.L(self.S(1), 1)
|
||||
output[8] = self.L(self.S(1), 2)
|
||||
output[9] = self.R(self.S(1), 1)
|
||||
output[10] = self.R(self.S(1), 2)
|
||||
output[7] = self.L(self.B(0), 1)
|
||||
output[8] = self.R(self.S(0), 2)
|
||||
output[9] = self.L(self.S(1), 1)
|
||||
output[10] = self.L(self.S(1), 2)
|
||||
output[11] = self.R(self.S(1), 1)
|
||||
output[12] = self.R(self.S(1), 2)
|
||||
|
||||
def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
|
||||
cdef int i, j, tok_i
|
||||
|
|
Loading…
Reference in New Issue
Block a user