diff --git a/bin/parser/nn_train.py b/bin/parser/nn_train.py new file mode 100755 index 000000000..72c9e04f1 --- /dev/null +++ b/bin/parser/nn_train.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +from __future__ import division +from __future__ import unicode_literals + +import os +from os import path +import shutil +import codecs +import random + +import plac +import cProfile +import pstats +import re + +import spacy.util +from spacy.en import English +from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir + +from spacy.syntax.util import Config +from spacy.gold import read_json_file +from spacy.gold import GoldParse + +from spacy.scorer import Scorer + +from spacy.syntax.parser import Parser, get_templates +from spacy._theano import TheanoModel + +import theano +import theano.tensor as T + +from theano.printing import Print + +import numpy +from collections import OrderedDict, defaultdict + + +theano.config.profile = False +theano.config.floatX = 'float32' +floatX = theano.config.floatX + + +def L1(L1_reg, *weights): + return L1_reg * sum(abs(w).sum() for w in weights) + + +def L2(L2_reg, *weights): + return L2_reg * sum((w ** 2).sum() for w in weights) + + +def rms_prop(loss, params, eta=1.0, rho=0.9, eps=1e-6): + updates = OrderedDict() + for param in params: + value = param.get_value(borrow=True) + accu = theano.shared(np.zeros(value.shape, dtype=value.dtype), + broadcastable=param.broadcastable) + + grad = T.grad(loss, param) + accu_new = rho * accu + (1 - rho) * grad ** 2 + updates[accu] = accu_new + updates[param] = param - (eta * grad / T.sqrt(accu_new + eps)) + return updates + + +def relu(x): + return x * (x > 0) + + +def feed_layer(activation, weights, bias, input_): + return activation(T.dot(input_, weights) + bias) + + +def init_weights(n_in, n_out): + rng = numpy.random.RandomState(1235) + + weights = numpy.asarray( + rng.standard_normal(size=(n_in, n_out)) * numpy.sqrt(2.0 / n_in), + dtype=theano.config.floatX + ) + bias = numpy.zeros((n_out,), dtype=theano.config.floatX) + return [wrapper(weights, name='W'), wrapper(bias, name='b')] + + +def compile_model(n_classes, n_hidden, n_in, optimizer): + x = T.vector('x') + costs = T.ivector('costs') + loss = T.scalar('loss') + + maxent_W, maxent_b = init_weights(n_hidden, n_classes) + hidden_W, hidden_b = init_weights(n_in, n_hidden) + + # Feed the inputs forward through the network + p_y_given_x = feed_layer( + T.nnet.softmax, + maxent_W, + maxent_b, + feed_layer( + relu, + hidden_W, + hidden_b, + x)) + + loss = -T.log(T.sum(p_y_given_x[0] * T.eq(costs, 0)) + 1e-8) + + train_model = theano.function( + name='train_model', + inputs=[x, costs], + outputs=[p_y_given_x[0], T.grad(loss, x), loss], + updates=optimizer(loss, [maxent_W, maxent_b, hidden_W, hidden_b]), + on_unused_input='warn' + ) + + evaluate_model = theano.function( + name='evaluate_model', + inputs=[x], + outputs=[ + feed_layer( + T.nnet.softmax, + maxent_W, + maxent_b, + feed_layer( + relu, + hidden_W, + hidden_b, + x + ) + )[0] + ] + ) + return train_model, evaluate_model + + +def score_model(scorer, nlp, annot_tuples, verbose=False): + tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) + nlp.tagger(tokens) + nlp.parser(tokens) + gold = GoldParse(tokens, annot_tuples) + scorer.score(tokens, gold, verbose=verbose) + + +def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', + eta=0.01, mu=0.9, nv_hidden=100, nv_word=10, nv_tag=10, nv_label=10, + seed=0, n_sents=0, verbose=False): + + dep_model_dir = path.join(model_dir, 'deps') + pos_model_dir = path.join(model_dir, 'pos') + if path.exists(dep_model_dir): + shutil.rmtree(dep_model_dir) + if path.exists(pos_model_dir): + shutil.rmtree(pos_model_dir) + os.mkdir(dep_model_dir) + os.mkdir(pos_model_dir) + setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir) + + Config.write(dep_model_dir, 'config', + seed=seed, + templates=tuple(), + labels=Language.ParserTransitionSystem.get_labels(gold_tuples), + vector_lengths=(nv_word, nv_tag, nv_label), + hidden_nodes=nv_hidden, + eta=eta, + mu=mu + ) + + # Bake-in hyper-parameters + optimizer = lambda loss, params: rms_prop(loss, params, eta=eta, rho=rho, eps=eps) + nlp = Language(data_dir=model_dir) + n_classes = nlp.parser.model.n_classes + train, predict = compile_model(n_classes, nv_hidden, n_in, optimizer) + nlp.parser.model = TheanoModel(n_classes, input_spec, train, + predict, model_loc) + + if n_sents > 0: + gold_tuples = gold_tuples[:n_sents] + print "Itn.\tP.Loss\tUAS\tTag %\tToken %" + log_loc = path.join(model_dir, 'job.log') + for itn in range(n_iter): + scorer = Scorer() + loss = 0 + for _, sents in gold_tuples: + for annot_tuples, ctnt in sents: + if len(annot_tuples[1]) == 1: + continue + score_model(scorer, nlp, annot_tuples) + tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) + nlp.tagger(tokens) + gold = GoldParse(tokens, annot_tuples, make_projective=True) + assert gold.is_projective + loss += nlp.parser.train(tokens, gold) + nlp.tagger.train(tokens, gold.tags) + random.shuffle(gold_tuples) + logline = '%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, + scorer.tags_acc, + scorer.token_acc) + print logline + with open(log_loc, 'aw') as file_: + file_.write(logline + '\n') + nlp.parser.model.end_training() + nlp.tagger.model.end_training() + nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt')) + return nlp + + +def evaluate(nlp, gold_tuples, gold_preproc=True): + scorer = Scorer() + for raw_text, sents in gold_tuples: + for annot_tuples, brackets in sents: + tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) + nlp.tagger(tokens) + nlp.parser(tokens) + gold = GoldParse(tokens, annot_tuples) + scorer.score(tokens, gold) + return scorer + + +@plac.annotations( + train_loc=("Location of training file or directory"), + dev_loc=("Location of development file or directory"), + model_dir=("Location of output model directory",), + eval_only=("Skip training, and only evaluate", "flag", "e", bool), + n_sents=("Number of training sentences", "option", "n", int), + n_iter=("Number of training iterations", "option", "i", int), + verbose=("Verbose error reporting", "flag", "v", bool), + + nv_word=("Word vector length", "option", "W", int), + nv_tag=("Tag vector length", "option", "T", int), + nv_label=("Label vector length", "option", "L", int), + nv_hidden=("Hidden nodes length", "option", "H", int), + eta=("Learning rate", "option", "E", float), + mu=("Momentum", "option", "M", float), +) +def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, verbose=False, + nv_word=10, nv_tag=10, nv_label=10, nv_hidden=10, + eta=0.1, mu=0.9, eval_only=False): + + + + + gold_train = list(read_json_file(train_loc, lambda doc: 'wsj' in doc['id'])) + + nlp = train(English, gold_train, model_dir, + feat_set='embed', + eta=eta, mu=mu, + nv_word=nv_word, nv_tag=nv_tag, nv_label=nv_label, nv_hidden=nv_hidden, + n_sents=n_sents, n_iter=n_iter, + verbose=verbose) + + scorer = evaluate(nlp, list(read_json_file(dev_loc))) + + print 'TOK', 100-scorer.token_acc + print 'POS', scorer.tags_acc + print 'UAS', scorer.uas + print 'LAS', scorer.las + + print 'NER P', scorer.ents_p + print 'NER R', scorer.ents_r + print 'NER F', scorer.ents_f + + +if __name__ == '__main__': + plac.call(main) diff --git a/setup.py b/setup.py index b127f68c1..5cae257b4 100644 --- a/setup.py +++ b/setup.py @@ -152,7 +152,8 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings', 'spacy.lexeme', 'spacy.vocab', 'spacy.morphology', 'spacy.syntax.stateclass', - 'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs', + 'spacy._ml', 'spacy._theano', + 'spacy.tokenizer', 'spacy.en.attrs', 'spacy.en.pos', 'spacy.syntax.parser', 'spacy.syntax.transition_system', 'spacy.syntax.arc_eager', diff --git a/spacy/_bu_nn.pyx b/spacy/_bu_nn.pyx new file mode 100644 index 000000000..ae875b235 --- /dev/null +++ b/spacy/_bu_nn.pyx @@ -0,0 +1,490 @@ +"""Feed-forward neural network, using Thenao.""" + +import os +import sys +import time + +import numpy + +import theano +import theano.tensor as T +import gzip +import cPickle + + +def load_data(dataset): + ''' Loads the dataset + + :type dataset: string + :param dataset: the path to the dataset (here MNIST) + ''' + + ############# + # LOAD DATA # + ############# + + # Download the MNIST dataset if it is not present + data_dir, data_file = os.path.split(dataset) + if data_dir == "" and not os.path.isfile(dataset): + # Check if dataset is in the data directory. + new_path = os.path.join( + os.path.split(__file__)[0], + "..", + "data", + dataset + ) + if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz': + dataset = new_path + + if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz': + import urllib + origin = ( + 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' + ) + print 'Downloading data from %s' % origin + urllib.urlretrieve(origin, dataset) + + print '... loading data' + + # Load the dataset + f = gzip.open(dataset, 'rb') + train_set, valid_set, test_set = cPickle.load(f) + f.close() + #train_set, valid_set, test_set format: tuple(input, target) + #input is an numpy.ndarray of 2 dimensions (a matrix), + #each row corresponding to an example. target is a + #numpy.ndarray of 1 dimension (vector)) that have the same length as + #the number of rows in the input. It should give the target + #target to the example with the same index in the input. + + def shared_dataset(data_xy, borrow=True): + """ Function that loads the dataset into shared variables + + The reason we store our dataset in shared variables is to allow + Theano to copy it into the GPU memory (when code is run on GPU). + Since copying data into the GPU is slow, copying a minibatch everytime + is needed (the default behaviour if the data is not in a shared + variable) would lead to a large decrease in performance. + """ + data_x, data_y = data_xy + shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX), + borrow=borrow) + shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX), + borrow=borrow) + # When storing data on the GPU it has to be stored as floats + # therefore we will store the labels as ``floatX`` as well + # (``shared_y`` does exactly that). But during our computations + # we need them as ints (we use labels as index, and if they are + # floats it doesn't make sense) therefore instead of returning + # ``shared_y`` we will have to cast it to int. This little hack + # lets ous get around this issue + return shared_x, T.cast(shared_y, 'int32') + + test_set_x, test_set_y = shared_dataset(test_set) + valid_set_x, valid_set_y = shared_dataset(valid_set) + train_set_x, train_set_y = shared_dataset(train_set) + + rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), + (test_set_x, test_set_y)] + return rval + + +class LogisticRegression(object): + """Multi-class Logistic Regression Class + + The logistic regression is fully described by a weight matrix :math:`W` + and bias vector :math:`b`. Classification is done by projecting data + points onto a set of hyperplanes, the distance to which is used to + determine a class membership probability. + """ + + def __init__(self, input, n_in, n_out): + """ Initialize the parameters of the logistic regression + + :type input: theano.tensor.TensorType + :param input: symbolic variable that describes the input of the + architecture (one minibatch) + + :type n_in: int + :param n_in: number of input units, the dimension of the space in + which the datapoints lie + + :type n_out: int + :param n_out: number of output units, the dimension of the space in + which the labels lie + + """ + # start-snippet-1 + # initialize with 0 the weights W as a matrix of shape (n_in, n_out) + self.W = theano.shared( + value=numpy.zeros((n_in, n_out), + dtype=theano.config.floatX + ), + name='W', + borrow=True + ) + # initialize the baises b as a vector of n_out 0s + self.b = theano.shared( + value=numpy.zeros( + (n_out,), + dtype=theano.config.floatX + ), + name='b', + borrow=True + ) + + # symbolic expression for computing the matrix of class-membership + # probabilities + # Where: + # W is a matrix where column-k represent the separation hyper plain for + # class-k + # x is a matrix where row-j represents input training sample-j + # b is a vector where element-k represent the free parameter of hyper + # plain-k + self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) + + # symbolic description of how to compute prediction as class whose + # probability is maximal + self.y_pred = T.argmax(self.p_y_given_x, axis=1) + # end-snippet-1 + + # parameters of the model + self.params = [self.W, self.b] + + def neg_ll(self, y): + """Return the mean of the negative log-likelihood of the prediction + of this model under a given target distribution. + + .. math:: + + \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) = + \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|} + \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\ + \ell (\theta=\{W,b\}, \mathcal{D}) + + :type y: theano.tensor.TensorType + :param y: corresponds to a vector that gives for each example the + correct label + + Note: we use the mean instead of the sum so that + the learning rate is less dependent on the batch size + """ + # start-snippet-2 + # y.shape[0] is (symbolically) the number of rows in y, i.e., + # number of examples (call it n) in the minibatch + # T.arange(y.shape[0]) is a symbolic vector which will contain + # [0,1,2,... n-1] T.log(self.p_y_given_x) is a matrix of + # Log-Probabilities (call it LP) with one row per example and + # one column per class LP[T.arange(y.shape[0]),y] is a vector + # v containing [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ..., + # LP[n-1,y[n-1]]] and T.mean(LP[T.arange(y.shape[0]),y]) is + # the mean (across minibatch examples) of the elements in v, + # i.e., the mean log-likelihood across the minibatch. + return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) + # end-snippet-2 + + def errors(self, y): + """Return a float representing the number of errors in the minibatch + over the total number of examples of the minibatch ; zero one + loss over the size of the minibatch + + :type y: theano.tensor.TensorType + :param y: corresponds to a vector that gives for each example the + correct label + """ + + # check if y has same dimension of y_pred + if y.ndim != self.y_pred.ndim: + raise TypeError( + 'y should have the same shape as self.y_pred', + ('y', y.type, 'y_pred', self.y_pred.type) + ) + # check if y is of the correct datatype + if y.dtype.startswith('int'): + # the T.neq operator returns a vector of 0s and 1s, where 1 + # represents a mistake in prediction + return T.mean(T.neq(self.y_pred, y)) + else: + raise NotImplementedError() + + +# start-snippet-1 +class HiddenLayer(object): + def __init__(self, rng, input, n_in, n_out, W=None, b=None, + activation=T.tanh): + """ + Typical hidden layer of a MLP: units are fully-connected and have + sigmoidal activation function. Weight matrix W is of shape (n_in,n_out) + and the bias vector b is of shape (n_out,). + + NOTE : The nonlinearity used here is tanh + + Hidden unit activation is given by: tanh(dot(input,W) + b) + + :type rng: numpy.random.RandomState + :param rng: a random number generator used to initialize weights + + :type input: theano.tensor.dmatrix + :param input: a symbolic tensor of shape (n_examples, n_in) + + :type n_in: int + :param n_in: dimensionality of input + + :type n_out: int + :param n_out: number of hidden units + + :type activation: theano.Op or function + :param activation: Non linearity to be applied in the hidden + layer + """ + self.input = input + # end-snippet-1 + + # `W` is initialized with `W_values` which is uniformely sampled + # from sqrt(-6./(n_in+n_hidden)) and sqrt(6./(n_in+n_hidden)) + # for tanh activation function + # the output of uniform if converted using asarray to dtype + # theano.config.floatX so that the code is runable on GPU + # Note : optimal initialization of weights is dependent on the + # activation function used (among other things). + # For example, results presented in [Xavier10] suggest that you + # should use 4 times larger initial weights for sigmoid + # compared to tanh + # We have no info for other function, so we use the same as + # tanh. + if W is None: + W_values = numpy.asarray( + rng.uniform( + low=-numpy.sqrt(6. / (n_in + n_out)), + high=numpy.sqrt(6. / (n_in + n_out)), + size=(n_in, n_out) + ), + dtype=theano.config.floatX + ) + if activation == theano.tensor.nnet.sigmoid: + W_values *= 4 + + W = theano.shared(value=W_values, name='W', borrow=True) + + if b is None: + b_values = numpy.zeros((n_out,), dtype=theano.config.floatX) + b = theano.shared(value=b_values, name='b', borrow=True) + + self.W = W + self.b = b + + lin_output = T.dot(input, self.W) + self.b + self.output = ( + lin_output if activation is None + else activation(lin_output) + ) + # parameters of the model + self.params = [self.W, self.b] + + +# start-snippet-2 +class MLP(object): + """Multi-Layer Perceptron Class + + A multilayer perceptron is a feedforward artificial neural network model + that has one layer or more of hidden units and nonlinear activations. + Intermediate layers usually have as activation function tanh or the + sigmoid function (defined here by a ``HiddenLayer`` class) while the + top layer is a softmax layer (defined here by a ``LogisticRegression`` + class). + """ + + def __init__(self, rng, input, n_in, n_hidden, n_out): + """Initialize the parameters for the multilayer perceptron + + :type rng: numpy.random.RandomState + :param rng: a random number generator used to initialize weights + + :type input: theano.tensor.TensorType + :param input: symbolic variable that describes the input of the + architecture (one minibatch) + + :type n_in: int + :param n_in: number of input units, the dimension of the space in + which the datapoints lie + + :type n_hidden: int + :param n_hidden: number of hidden units + + :type n_out: int + :param n_out: number of output units, the dimension of the space in + which the labels lie + + """ + + # Since we are dealing with a one hidden layer MLP, this will translate + # into a HiddenLayer with a tanh activation function connected to the + # LogisticRegression layer; the activation function can be replaced by + # sigmoid or any other nonlinear function + self.hidden = HiddenLayer( + rng=rng, + input=input, + n_in=n_in, + n_out=n_hidden, + activation=T.tanh + ) + + # The logistic regression layer gets as input the hidden units + # of the hidden layer + self.maxent = LogisticRegression( + input=self.hidden.output, + n_in=n_hidden, + n_out=n_out + ) + # L1 norm ; one regularization option is to enforce L1 norm to + # be small + self.L1 = abs(self.hidden.W).sum() + abs(self.maxent.W).sum() + + # square of L2 norm ; one regularization option is to enforce + # square of L2 norm to be small + self.L2_sqr = (self.hidden.W ** 2).sum() + (self.maxent.W ** 2).sum() + + # negative log likelihood of the MLP is given by the negative + # log likelihood of the output of the model, computed in the + # logistic regression layer + self.neg_ll = self.maxent.neg_ll + # same holds for the function computing the number of errors + self.errors = self.maxent.errors + + # the parameters of the model are the parameters of the two layer it is + # made out of + self.params = self.hidden.params + self.maxent.params + + + + +def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000, + dataset='mnist.pkl.gz', batch_size=1, n_hidden=500): + """ + Demonstrate stochastic gradient descent optimization for a multilayer + perceptron + + This is demonstrated on MNIST. + + :type learning_rate: float + :param learning_rate: learning rate used (factor for the stochastic + gradient + + :type L1_reg: float + :param L1_reg: L1-norm's weight when added to the cost (see + regularization) + + :type L2_reg: float + :param L2_reg: L2-norm's weight when added to the cost (see + regularization) + + :type n_epochs: int + :param n_epochs: maximal number of epochs to run the optimizer + + :type dataset: string + :param dataset: the path of the MNIST dataset file from + http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz + """ + datasets = load_data(dataset) + + train_set_x, train_set_y = datasets[0] + valid_set_x, valid_set_y = datasets[1] + test_set_x, test_set_y = datasets[2] + + ###################### + # BUILD ACTUAL MODEL # + ###################### + print '... building the model' + + # allocate symbolic variables for the data + index = T.lscalar() # index to a [mini]batch + x = T.matrix('x') # the data is presented as rasterized images + y = T.ivector('y') # the labels are presented as 1D vector of + # [int] labels + + rng = numpy.random.RandomState(1234) + + # construct the MLP class + mlp = MLP( + rng=rng, + input=x, + n_in=28 * 28, + n_hidden=n_hidden, + n_out=10 + ) + + # the cost we minimize during training is the negative log likelihood of + # the model plus the regularization terms (L1 and L2); cost is expressed + # here symbolically + + # compiling a Theano function that computes the mistakes that are made + # by the model on a minibatch + test_model = theano.function( + inputs=[index], + outputs=mlp.maxent.errors(y), + givens={ + x: test_set_x[index:index+1], + y: test_set_y[index:index+1] + } + ) + + validate_model = theano.function( + inputs=[index], + outputs=mlp.maxent.errors(y), + givens={ + x: valid_set_x[index:index+1], + y: valid_set_y[index:index+1] + } + ) + + # compute the gradient of cost with respect to theta (sotred in params) + # the resulting gradients will be stored in a list gparams + cost = mlp.neg_ll(y) + L1_reg * mlp.L1 + L2_reg * mlp.L2_sqr + gparams = [T.grad(cost, param) for param in mlp.params] + + # specify how to update the parameters of the model as a list of + # (variable, update expression) pairs + + updates = [(mlp.params[i], mlp.params[i] - (learning_rate * gparams[i])) + for i in xrange(len(gparams))] + + # compiling a Theano function `train_model` that returns the cost, but + # in the same time updates the parameter of the model based on the rules + # defined in `updates` + train_model = theano.function( + inputs=[index], + outputs=cost, + updates=updates, + givens={ + x: train_set_x[index:index+1], + y: train_set_y[index:index+1] + } + ) + # end-snippet-5 + + ############### + # TRAIN MODEL # + ############### + print '... training' + + start_time = time.clock() + + n_examples = train_set_x.get_value(borrow=True).shape[0] + n_dev_examples = valid_set_x.get_value(borrow=True).shape[0] + n_test_examples = test_set_x.get_value(borrow=True).shape[0] + + for epoch in range(1, n_epochs+1): + for idx in xrange(n_examples): + train_model(idx) + # compute zero-one loss on validation set + error = numpy.mean(map(validate_model, xrange(n_dev_examples))) + print('epoch %i, validation error %f %%' % (epoch, error * 100)) + + end_time = time.clock() + print >> sys.stderr, ('The code for file ' + + os.path.split(__file__)[1] + + ' ran for %.2fm' % ((end_time - start_time) / 60.)) + + +if __name__ == '__main__': + test_mlp() diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd index e39b3a5e3..83fc0b405 100644 --- a/spacy/_ml.pxd +++ b/spacy/_ml.pxd @@ -5,6 +5,7 @@ from cymem.cymem cimport Pool from thinc.learner cimport LinearModel from thinc.features cimport Extractor, Feature from thinc.typedefs cimport atom_t, feat_t, weight_t, class_t +from thinc.api cimport ExampleC from preshed.maps cimport PreshMapArray @@ -13,9 +14,14 @@ from .typedefs cimport hash_t, id_t cdef int arg_max(const weight_t* scores, const int n_classes) nogil +cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, int n_classes) nogil + +cdef int arg_max_if_zero(const weight_t* scores, const int* costs, int n_classes) nogil + cdef class Model: - cdef int n_classes + cdef readonly int n_classes + cdef readonly int n_feats cdef const weight_t* score(self, atom_t* context) except NULL cdef int set_scores(self, weight_t* scores, atom_t* context) except -1 diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx index be647c2dd..f84068778 100644 --- a/spacy/_ml.pyx +++ b/spacy/_ml.pyx @@ -10,6 +10,7 @@ import cython import numpy.random from thinc.features cimport Feature, count_feats +from thinc.api cimport Example cdef int arg_max(const weight_t* scores, const int n_classes) nogil: @@ -23,17 +24,52 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil: return best +cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, + const int n_classes) nogil: + cdef int i + cdef int best = 0 + cdef weight_t mode = -900000 + for i in range(n_classes): + if is_valid[i] and scores[i] > mode: + mode = scores[i] + best = i + return best + + +cdef int arg_max_if_zero(const weight_t* scores, const int* costs, + const int n_classes) nogil: + cdef int i + cdef int best = 0 + cdef weight_t mode = -900000 + for i in range(n_classes): + if costs[i] == 0 and scores[i] > mode: + mode = scores[i] + best = i + return best + + cdef class Model: def __init__(self, n_classes, templates, model_loc=None): if model_loc is not None and path.isdir(model_loc): model_loc = path.join(model_loc, 'model') self.n_classes = n_classes self._extractor = Extractor(templates) + self.n_feats = self._extractor.n_templ self._model = LinearModel(n_classes, self._extractor.n_templ) self.model_loc = model_loc if self.model_loc and path.exists(self.model_loc): self._model.load(self.model_loc, freq_thresh=0) + def predict(self, Example eg): + self.set_scores(eg.c.scores, eg.c.atoms) + eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) + + def train(self, Example eg): + self.predict(eg) + eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes) + eg.c.cost = eg.c.costs[eg.c.guess] + self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost) + cdef const weight_t* score(self, atom_t* context) except NULL: cdef int n_feats feats = self._extractor.get_feats(context, &n_feats) diff --git a/spacy/_nn.py b/spacy/_nn.py new file mode 100644 index 000000000..48dca390c --- /dev/null +++ b/spacy/_nn.py @@ -0,0 +1,3 @@ +"""Feed-forward neural network, using Thenao.""" + + diff --git a/spacy/_nn.pyx b/spacy/_nn.pyx new file mode 100644 index 000000000..c47be1f49 --- /dev/null +++ b/spacy/_nn.pyx @@ -0,0 +1,146 @@ +"""Feed-forward neural network, using Thenao.""" + +import os +import sys +import time + +import numpy + +import theano +import theano.tensor as T +import plac + +from spacy.gold import read_json_file +from spacy.gold import GoldParse +from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir + + +def build_model(n_classes, n_vocab, n_hidden, n_word_embed, n_tag_embed): + # allocate symbolic variables for the data + words = T.vector('words') + tags = T.vector('tags') + + word_e = _init_embedding(n_words, n_word_embed) + tag_e = _init_embedding(n_tags, n_tag_embed) + label_e = _init_embedding(n_labels, n_label_embed) + maxent_W, maxent_b = _init_maxent_weights(n_hidden, n_classes) + hidden_W, hidden_b = _init_hidden_weights(28*28, n_hidden, T.tanh) + params = [hidden_W, hidden_b, maxent_W, maxent_b, word_e, tag_e, label_e] + + x = T.concatenate([ + T.flatten(word_e[word_indices], outdim=1), + T.flatten(tag_e[tag_indices], outdim=1)]) + + p_y_given_x = feed_layer( + T.nnet.softmax, + maxent_W, + maxent_b, + feed_layer( + T.tanh, + hidden_W, + hidden_b, + x))[0] + + guess = T.argmax(p_y_given_x) + + cost = ( + -T.log(p_y_given_x[y]) + + L1(L1_reg, maxent_W, hidden_W, word_e, tag_e) + + L2(L2_reg, maxent_W, hidden_W, wod_e, tag_e) + ) + + train_model = theano.function( + inputs=[words, tags, y], + outputs=guess, + updates=[update(learning_rate, param, cost) for param in params] + ) + + evaluate_model = theano.function( + inputs=[x, y], + outputs=T.neq(y, T.argmax(p_y_given_x[0])), + ) + return train_model, evaluate_model + + +def _init_embedding(vocab_size, n_dim): + embedding = 0.2 * numpy.random.uniform(-1.0, 1.0, (vocab_size+1, n_dim)) + return theano.shared(embedding).astype(theano.config.floatX) + + +def _init_maxent_weights(n_hidden, n_out): + weights = numpy.zeros((n_hidden, 10), dtype=theano.config.floatX) + bias = numpy.zeros((10,), dtype=theano.config.floatX) + return ( + theano.shared(name='W', borrow=True, value=weights), + theano.shared(name='b', borrow=True, value=bias) + ) + + +def _init_hidden_weights(n_in, n_out, activation=T.tanh): + rng = numpy.random.RandomState(1234) + weights = numpy.asarray( + rng.uniform( + low=-numpy.sqrt(6. / (n_in + n_out)), + high=numpy.sqrt(6. / (n_in + n_out)), + size=(n_in, n_out) + ), + dtype=theano.config.floatX + ) + + bias = numpy.zeros((n_out,), dtype=theano.config.floatX) + return ( + theano.shared(value=weights, name='W', borrow=True), + theano.shared(value=bias, name='b', borrow=True) + ) + + +def feed_layer(activation, weights, bias, input): + return activation(T.dot(input, weights) + bias) + + +def L1(L1_reg, w1, w2): + return L1_reg * (abs(w1).sum() + abs(w2).sum()) + + +def L2(L2_reg, w1, w2): + return L2_reg * ((w1 ** 2).sum() + (w2 ** 2).sum()) + + +def update(eta, param, cost): + return (param, param - (eta * T.grad(cost, param))) + + +def main(train_loc, eval_loc, model_dir): + learning_rate = 0.01 + L1_reg = 0.00 + L2_reg = 0.0001 + + print "... reading the data" + gold_train = list(read_json_file(train_loc)) + print '... building the model' + pos_model_dir = path.join(model_dir, 'pos') + if path.exists(pos_model_dir): + shutil.rmtree(pos_model_dir) + os.mkdir(pos_model_dir) + + setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir) + + train_model, evaluate_model = build_model(n_hidden, len(POS_TAGS), learning_rate, + L1_reg, L2_reg) + + print '... training' + for epoch in range(1, n_epochs+1): + for raw_text, sents in gold_tuples: + for (ids, words, tags, ner, heads, deps), _ in sents: + tokens = nlp.tokenizer.tokens_from_list(words) + for t in tokens: + guess = train_model([t.orth], [t.tag]) + loss += guess != t.tag + print loss + # compute zero-one loss on validation set + #error = numpy.mean([evaluate_model(x, y) for x, y in dev_examples]) + #print('epoch %i, validation error %f %%' % (epoch, error * 100)) + + +if __name__ == '__main__': + plac.call(main) diff --git a/spacy/_theano.pxd b/spacy/_theano.pxd new file mode 100644 index 000000000..cad0736c2 --- /dev/null +++ b/spacy/_theano.pxd @@ -0,0 +1,13 @@ +from ._ml cimport Model +from thinc.nn cimport InputLayer + + +cdef class TheanoModel(Model): + cdef InputLayer input_layer + cdef object train_func + cdef object predict_func + cdef object debug + + cdef public float eta + cdef public float mu + cdef public float t diff --git a/spacy/_theano.pyx b/spacy/_theano.pyx new file mode 100644 index 000000000..cc6886321 --- /dev/null +++ b/spacy/_theano.pyx @@ -0,0 +1,52 @@ +from thinc.api cimport Example, ExampleC +from thinc.typedefs cimport weight_t + +from ._ml cimport arg_max_if_true +from ._ml cimport arg_max_if_zero + +import numpy +from os import path + + +cdef class TheanoModel(Model): + def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None, + eta=0.001, mu=0.9, debug=None): + if model_loc is not None and path.isdir(model_loc): + model_loc = path.join(model_loc, 'model') + + self.eta = eta + self.mu = mu + self.t = 1 + initializer = lambda: 0.2 * numpy.random.uniform(-1.0, 1.0) + self.input_layer = InputLayer(input_spec, initializer) + self.train_func = train_func + self.predict_func = predict_func + self.debug = debug + + self.n_classes = n_classes + self.n_feats = len(self.input_layer) + self.model_loc = model_loc + + def predict(self, Example eg): + self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=True) + theano_scores = self.predict_func(eg.embeddings)[0] + cdef int i + for i in range(self.n_classes): + eg.c.scores[i] = theano_scores[i] + eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) + + def train(self, Example eg): + self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=False) + theano_scores, update, y, loss = self.train_func(eg.embeddings, eg.costs, + self.eta, self.mu) + self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu) + for i in range(self.n_classes): + eg.c.scores[i] = theano_scores[i] + eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) + eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes) + eg.c.cost = eg.c.costs[eg.c.guess] + eg.c.loss = loss + self.t += 1 + + def end_training(self): + pass diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 29e62cb4e..a83e19ec2 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -398,7 +398,8 @@ cdef class ArcEager(TransitionSystem): n_valid += output[i] assert n_valid >= 1 - cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1: + cdef int set_costs(self, bint* is_valid, int* costs, + StateClass stcls, GoldParse gold) except -1: cdef int i, move, label cdef label_cost_func_t[N_MOVES] label_cost_funcs cdef move_cost_func_t[N_MOVES] move_cost_funcs @@ -423,30 +424,14 @@ cdef class ArcEager(TransitionSystem): n_gold = 0 for i in range(self.n_moves): if self.c[i].is_valid(stcls, self.c[i].label): + is_valid[i] = True move = self.c[i].move label = self.c[i].label if move_costs[move] == -1: move_costs[move] = move_cost_funcs[move](stcls, &gold.c) - output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) - n_gold += output[i] == 0 + costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) + n_gold += costs[i] == 0 else: - output[i] = 9000 + is_valid[i] = False + costs[i] = 9000 assert n_gold >= 1 - - cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *: - cdef bint[N_MOVES] is_valid - is_valid[SHIFT] = Shift.is_valid(stcls, -1) - is_valid[REDUCE] = Reduce.is_valid(stcls, -1) - is_valid[LEFT] = LeftArc.is_valid(stcls, -1) - is_valid[RIGHT] = RightArc.is_valid(stcls, -1) - is_valid[BREAK] = Break.is_valid(stcls, -1) - cdef Transition best - cdef weight_t score = MIN_SCORE - cdef int i - for i in range(self.n_moves): - if scores[i] > score and is_valid[self.c[i].move]: - best = self.c[i] - score = scores[i] - assert best.clas < self.n_moves - assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length) - return best diff --git a/spacy/syntax/joint.pxd b/spacy/syntax/joint.pxd new file mode 100644 index 000000000..5b7a6e3db --- /dev/null +++ b/spacy/syntax/joint.pxd @@ -0,0 +1,17 @@ +from cymem.cymem cimport Pool + +from thinc.typedefs cimport weight_t + +from .stateclass cimport StateClass + +from .transition_system cimport TransitionSystem, Transition +from ..gold cimport GoldParseC + + +cdef class ArcEager(TransitionSystem): + pass + + +cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil +cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil + diff --git a/spacy/syntax/joint.pyx b/spacy/syntax/joint.pyx new file mode 100644 index 000000000..29e62cb4e --- /dev/null +++ b/spacy/syntax/joint.pyx @@ -0,0 +1,452 @@ +# cython: profile=True +from __future__ import unicode_literals + +import ctypes +import os + +from ..structs cimport TokenC + +from .transition_system cimport do_func_t, get_cost_func_t +from .transition_system cimport move_cost_func_t, label_cost_func_t +from ..gold cimport GoldParse +from ..gold cimport GoldParseC + +from libc.stdint cimport uint32_t +from libc.string cimport memcpy + +from cymem.cymem cimport Pool +from .stateclass cimport StateClass + + +DEF NON_MONOTONIC = True +DEF USE_BREAK = True +DEF USE_ROOT_ARC_SEGMENT = True + +cdef weight_t MIN_SCORE = -90000 + +# Break transition from here +# http://www.aclweb.org/anthology/P13-1074 +cdef enum: + SHIFT + REDUCE + LEFT + RIGHT + + BREAK + + N_MOVES + + +MOVE_NAMES = [None] * N_MOVES +MOVE_NAMES[SHIFT] = 'S' +MOVE_NAMES[REDUCE] = 'D' +MOVE_NAMES[LEFT] = 'L' +MOVE_NAMES[RIGHT] = 'R' +MOVE_NAMES[BREAK] = 'B' + + +# Helper functions for the arc-eager oracle + +cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: + cdef int cost = 0 + cdef int i, S_i + for i in range(stcls.stack_depth()): + S_i = stcls.S(i) + if gold.heads[target] == S_i: + cost += 1 + if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)): + cost += 1 + cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0 + return cost + + +cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: + cdef int cost = 0 + cdef int i, B_i + for i in range(stcls.buffer_length()): + B_i = stcls.B(i) + cost += gold.heads[B_i] == target + cost += gold.heads[target] == B_i + if gold.heads[B_i] == B_i or gold.heads[B_i] < target: + break + cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0 + return cost + + +cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil: + if arc_is_gold(gold, head, child): + return 0 + elif stcls.H(child) == gold.heads[child]: + return 1 + # Head in buffer + elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1: + return 1 + else: + return 0 + + +cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil: + if gold.labels[child] == -1: + return True + elif USE_ROOT_ARC_SEGMENT and _is_gold_root(gold, head) and _is_gold_root(gold, child): + return True + elif gold.heads[child] == head: + return True + else: + return False + + +cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil: + if gold.labels[child] == -1: + return True + elif label == -1: + return True + elif gold.labels[child] == label: + return True + else: + return False + + +cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: + return gold.labels[word] == -1 or gold.heads[word] == word + + +cdef class Shift: + @staticmethod + cdef bint is_valid(StateClass st, int label) nogil: + return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start + + @staticmethod + cdef int transition(StateClass st, int label) nogil: + st.push() + st.fast_forward() + + @staticmethod + cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil: + return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label) + + @staticmethod + cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: + return push_cost(s, gold, s.B(0)) + + @staticmethod + cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: + return 0 + + +cdef class Reduce: + @staticmethod + cdef bint is_valid(StateClass st, int label) nogil: + return st.stack_depth() >= 2 + + @staticmethod + cdef int transition(StateClass st, int label) nogil: + if st.has_head(st.S(0)): + st.pop() + else: + st.unshift() + st.fast_forward() + + @staticmethod + cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: + return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) + + @staticmethod + cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil: + return pop_cost(st, gold, st.S(0)) + + @staticmethod + cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: + return 0 + + +cdef class LeftArc: + @staticmethod + cdef bint is_valid(StateClass st, int label) nogil: + return not st.B_(0).sent_start + + @staticmethod + cdef int transition(StateClass st, int label) nogil: + st.add_arc(st.B(0), st.S(0), label) + st.pop() + st.fast_forward() + + @staticmethod + cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: + return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) + + @staticmethod + cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: + cdef int cost = 0 + if arc_is_gold(gold, s.B(0), s.S(0)): + return 0 + else: + # Account for deps we might lose between S0 and stack + if not s.has_head(s.S(0)): + for i in range(1, s.stack_depth()): + cost += gold.heads[s.S(i)] == s.S(0) + cost += gold.heads[s.S(0)] == s.S(i) + return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0)) + + @staticmethod + cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: + return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label) + + +cdef class RightArc: + @staticmethod + cdef bint is_valid(StateClass st, int label) nogil: + return not st.B_(0).sent_start + + @staticmethod + cdef int transition(StateClass st, int label) nogil: + st.add_arc(st.S(0), st.B(0), label) + st.push() + st.fast_forward() + + @staticmethod + cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil: + return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) + + @staticmethod + cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: + if arc_is_gold(gold, s.S(0), s.B(0)): + return 0 + elif s.shifted[s.B(0)]: + return push_cost(s, gold, s.B(0)) + else: + return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) + + @staticmethod + cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: + return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label) + + +cdef class Break: + @staticmethod + cdef bint is_valid(StateClass st, int label) nogil: + cdef int i + if not USE_BREAK: + return False + elif st.at_break(): + return False + elif st.B(0) == 0: + return False + elif st.stack_depth() < 1: + return False + elif (st.S(0) + 1) != st.B(0): + # Must break at the token boundary + return False + else: + return True + + @staticmethod + cdef int transition(StateClass st, int label) nogil: + st.set_break(st.B(0)) + st.fast_forward() + + @staticmethod + cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: + return Break.move_cost(s, gold) + Break.label_cost(s, gold, label) + + @staticmethod + cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: + cdef int cost = 0 + cdef int i, j, S_i, B_i + for i in range(s.stack_depth()): + S_i = s.S(i) + for j in range(s.buffer_length()): + B_i = s.B(j) + cost += gold.heads[S_i] == B_i + cost += gold.heads[B_i] == S_i + # Check for sentence boundary --- if it's here, we can't have any deps + # between stack and buffer, so rest of action is irrelevant. + s0_root = _get_root(s.S(0), gold) + b0_root = _get_root(s.B(0), gold) + if s0_root != b0_root or s0_root == -1 or b0_root == -1: + return cost + else: + return cost + 1 + + @staticmethod + cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: + return 0 + +cdef int _get_root(int word, const GoldParseC* gold) nogil: + while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0: + word = gold.heads[word] + if gold.labels[word] == -1: + return -1 + else: + return word + + +cdef class ArcEager(TransitionSystem): + @classmethod + def get_labels(cls, gold_parses): + move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'ROOT': True}, + LEFT: {'ROOT': True}, BREAK: {'ROOT': True}} + for raw_text, sents in gold_parses: + for (ids, words, tags, heads, labels, iob), ctnts in sents: + for child, head, label in zip(ids, heads, labels): + if label.upper() == 'ROOT': + label = 'ROOT' + if label != 'ROOT': + if head < child: + move_labels[RIGHT][label] = True + elif head > child: + move_labels[LEFT][label] = True + return move_labels + + cdef int preprocess_gold(self, GoldParse gold) except -1: + for i in range(gold.length): + if gold.heads[i] is None: # Missing values + gold.c.heads[i] = i + gold.c.labels[i] = -1 + else: + label = gold.labels[i] + if label.upper() == 'ROOT': + label = 'ROOT' + gold.c.heads[i] = gold.heads[i] + gold.c.labels[i] = self.strings[label] + for end, brackets in gold.brackets.items(): + for start, label_strs in brackets.items(): + gold.c.brackets[start][end] = 1 + for label_str in label_strs: + # Add the encoded label to the set + gold.brackets[end][start].add(self.strings[label_str]) + + cdef Transition lookup_transition(self, object name) except *: + if '-' in name: + move_str, label_str = name.split('-', 1) + label = self.label_ids[label_str] + else: + label = 0 + move = MOVE_NAMES.index(move_str) + for i in range(self.n_moves): + if self.c[i].move == move and self.c[i].label == label: + return self.c[i] + + def move_name(self, int move, int label): + label_str = self.strings[label] + if label_str: + return MOVE_NAMES[move] + '-' + label_str + else: + return MOVE_NAMES[move] + + cdef Transition init_transition(self, int clas, int move, int label) except *: + # TODO: Apparent Cython bug here when we try to use the Transition() + # constructor with the function pointers + cdef Transition t + t.score = 0 + t.clas = clas + t.move = move + t.label = label + if move == SHIFT: + t.is_valid = Shift.is_valid + t.do = Shift.transition + t.get_cost = Shift.cost + elif move == REDUCE: + t.is_valid = Reduce.is_valid + t.do = Reduce.transition + t.get_cost = Reduce.cost + elif move == LEFT: + t.is_valid = LeftArc.is_valid + t.do = LeftArc.transition + t.get_cost = LeftArc.cost + elif move == RIGHT: + t.is_valid = RightArc.is_valid + t.do = RightArc.transition + t.get_cost = RightArc.cost + elif move == BREAK: + t.is_valid = Break.is_valid + t.do = Break.transition + t.get_cost = Break.cost + else: + raise Exception(move) + return t + + cdef int initialize_state(self, StateClass st) except -1: + # Ensure sent_start is set to 0 throughout + for i in range(st.length): + st._sent[i].sent_start = False + st._sent[i].l_edge = i + st._sent[i].r_edge = i + st.fast_forward() + + cdef int finalize_state(self, StateClass st) except -1: + cdef int root_label = self.strings['ROOT'] + for i in range(st.length): + if st._sent[i].head == 0 and st._sent[i].dep == 0: + st._sent[i].dep = root_label + # If we're not using the Break transition, we segment via root-labelled + # arcs between the root words. + elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == root_label: + st._sent[i].head = 0 + + cdef int set_valid(self, bint* output, StateClass stcls) except -1: + cdef bint[N_MOVES] is_valid + is_valid[SHIFT] = Shift.is_valid(stcls, -1) + is_valid[REDUCE] = Reduce.is_valid(stcls, -1) + is_valid[LEFT] = LeftArc.is_valid(stcls, -1) + is_valid[RIGHT] = RightArc.is_valid(stcls, -1) + is_valid[BREAK] = Break.is_valid(stcls, -1) + cdef int i + n_valid = 0 + for i in range(self.n_moves): + output[i] = is_valid[self.c[i].move] + n_valid += output[i] + assert n_valid >= 1 + + cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1: + cdef int i, move, label + cdef label_cost_func_t[N_MOVES] label_cost_funcs + cdef move_cost_func_t[N_MOVES] move_cost_funcs + cdef int[N_MOVES] move_costs + for i in range(N_MOVES): + move_costs[i] = -1 + move_cost_funcs[SHIFT] = Shift.move_cost + move_cost_funcs[REDUCE] = Reduce.move_cost + move_cost_funcs[LEFT] = LeftArc.move_cost + move_cost_funcs[RIGHT] = RightArc.move_cost + move_cost_funcs[BREAK] = Break.move_cost + + label_cost_funcs[SHIFT] = Shift.label_cost + label_cost_funcs[REDUCE] = Reduce.label_cost + label_cost_funcs[LEFT] = LeftArc.label_cost + label_cost_funcs[RIGHT] = RightArc.label_cost + label_cost_funcs[BREAK] = Break.label_cost + + cdef int* labels = gold.c.labels + cdef int* heads = gold.c.heads + + n_gold = 0 + for i in range(self.n_moves): + if self.c[i].is_valid(stcls, self.c[i].label): + move = self.c[i].move + label = self.c[i].label + if move_costs[move] == -1: + move_costs[move] = move_cost_funcs[move](stcls, &gold.c) + output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) + n_gold += output[i] == 0 + else: + output[i] = 9000 + assert n_gold >= 1 + + cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *: + cdef bint[N_MOVES] is_valid + is_valid[SHIFT] = Shift.is_valid(stcls, -1) + is_valid[REDUCE] = Reduce.is_valid(stcls, -1) + is_valid[LEFT] = LeftArc.is_valid(stcls, -1) + is_valid[RIGHT] = RightArc.is_valid(stcls, -1) + is_valid[BREAK] = Break.is_valid(stcls, -1) + cdef Transition best + cdef weight_t score = MIN_SCORE + cdef int i + for i in range(self.n_moves): + if scores[i] > score and is_valid[self.c[i].move]: + best = self.c[i] + score = scores[i] + assert best.clas < self.n_moves + assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length) + return best diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 4a47a20a8..b145df7ac 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -128,27 +128,6 @@ cdef class BiluoPushDown(TransitionSystem): raise Exception(move) return t - cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *: - cdef int best = -1 - cdef weight_t score = -90000 - cdef const Transition* m - cdef int i - for i in range(self.n_moves): - m = &self.c[i] - if m.is_valid(stcls, m.label) and scores[i] > score: - best = i - score = scores[i] - assert best >= 0 - cdef Transition t = self.c[best] - t.score = score - return t - - cdef int set_valid(self, bint* output, StateClass stcls) except -1: - cdef int i - for i in range(self.n_moves): - m = &self.c[i] - output[i] = m.is_valid(stcls, m.label) - cdef class Missing: @staticmethod diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd index 2c17464e7..54497dd6e 100644 --- a/spacy/syntax/parser.pxd +++ b/spacy/syntax/parser.pxd @@ -12,6 +12,3 @@ cdef class Parser: cdef readonly object cfg cdef readonly Model model cdef readonly TransitionSystem moves - - cdef int _greedy_parse(self, Doc tokens) except -1 - cdef int _beam_parse(self, Doc tokens) except -1 diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index a3c2eb886..592bf0ac3 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -20,17 +20,10 @@ from cymem.cymem cimport Pool, Address from murmurhash.mrmr cimport hash64 from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t - from util import Config -from thinc.features cimport Extractor -from thinc.features cimport Feature -from thinc.features cimport count_feats +from thinc.api cimport Example -from thinc.learner cimport LinearModel - -from thinc.search cimport Beam -from thinc.search cimport MaxViolation from ..structs cimport TokenC @@ -61,6 +54,8 @@ def get_templates(name): return pf.ner elif name == 'debug': return pf.unigrams + elif name.startswith('embed'): + return (pf.words, pf.tags, pf.labels) else: return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \ pf.tree_shape + pf.trigrams) @@ -83,179 +78,43 @@ cdef class Parser: self.model = Model(self.moves.n_moves, templates, model_dir) def __call__(self, Doc tokens): - if self.model is not None: - if self.cfg.get('beam_width', 0) < 1: - self._greedy_parse(tokens) - else: - self._beam_parse(tokens) - - def train(self, Doc tokens, GoldParse gold): - self.moves.preprocess_gold(gold) - if self.cfg.get('beam_width', 0) < 1: - return self._greedy_train(tokens, gold) - else: - return self._beam_train(tokens, gold) - - cdef int _greedy_parse(self, Doc tokens) except -1: - cdef atom_t[CONTEXT_SIZE] context - cdef int n_feats - cdef Pool mem = Pool() cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) self.moves.initialize_state(stcls) - cdef Transition guess - words = [w.orth_ for w in tokens] + + cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE, + self.model.n_feats, self.model.n_feats) while not stcls.is_final(): - fill_context(context, stcls) - scores = self.model.score(context) - guess = self.moves.best_valid(scores, stcls) - #print self.moves.move_name(guess.move, guess.label), stcls.print_state(words) - guess.do(stcls, guess.label) - assert stcls._s_i >= 0 + memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t)) + + self.moves.set_valid(eg.c.is_valid, stcls) + fill_context(eg.c.atoms, stcls) + + self.model.predict(eg) + + self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label) self.moves.finalize_state(stcls) tokens.set_parse(stcls._sent) - cdef int _beam_parse(self, Doc tokens) except -1: - cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width) - words = [w.orth_ for w in tokens] - beam.initialize(_init_state, tokens.length, tokens.data) - beam.check_done(_check_final_state, NULL) - while not beam.is_done: - self._advance_beam(beam, None, False, words) - state = beam.at(0) - self.moves.finalize_state(state) - tokens.set_parse(state._sent) - _cleanup(beam) - - def _greedy_train(self, Doc tokens, GoldParse gold): - cdef Pool mem = Pool() + def train(self, Doc tokens, GoldParse gold): + self.moves.preprocess_gold(gold) cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) self.moves.initialize_state(stcls) - - cdef int cost - cdef const Feature* feats - cdef const weight_t* scores - cdef Transition guess - cdef Transition best - cdef atom_t[CONTEXT_SIZE] context - loss = 0 + cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE, + self.model.n_feats, self.model.n_feats) + cdef weight_t loss = 0 words = [w.orth_ for w in tokens] - history = [] + cdef Transition G while not stcls.is_final(): - fill_context(context, stcls) - scores = self.model.score(context) - guess = self.moves.best_valid(scores, stcls) - best = self.moves.best_gold(scores, stcls, gold) - cost = guess.get_cost(stcls, &gold.c, guess.label) - self.model.update(context, guess.clas, best.clas, cost) - guess.do(stcls, guess.label) - loss += cost + memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t)) + + self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold) + + fill_context(eg.c.atoms, stcls) + + self.model.train(eg) + + G = self.moves.c[eg.c.guess] + + self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label) + loss += eg.c.loss return loss - - def _beam_train(self, Doc tokens, GoldParse gold_parse): - cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width) - pred.initialize(_init_state, tokens.length, tokens.data) - pred.check_done(_check_final_state, NULL) - cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width) - gold.initialize(_init_state, tokens.length, tokens.data) - gold.check_done(_check_final_state, NULL) - - violn = MaxViolation() - words = [w.orth_ for w in tokens] - while not pred.is_done and not gold.is_done: - self._advance_beam(pred, gold_parse, False, words) - self._advance_beam(gold, gold_parse, True, words) - violn.check(pred, gold) - if pred.loss >= 1: - counts = {clas: {} for clas in range(self.model.n_classes)} - self._count_feats(counts, tokens, violn.g_hist, 1) - self._count_feats(counts, tokens, violn.p_hist, -1) - else: - counts = {} - self.model._model.update(counts) - _cleanup(pred) - _cleanup(gold) - return pred.loss - - def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words): - cdef atom_t[CONTEXT_SIZE] context - cdef int i, j, cost - cdef bint is_valid - cdef const Transition* move - for i in range(beam.size): - stcls = beam.at(i) - if not stcls.is_final(): - fill_context(context, stcls) - self.model.set_scores(beam.scores[i], context) - self.moves.set_valid(beam.is_valid[i], stcls) - if gold is not None: - for i in range(beam.size): - stcls = beam.at(i) - if not stcls.is_final(): - self.moves.set_costs(beam.costs[i], stcls, gold) - if follow_gold: - for j in range(self.moves.n_moves): - beam.is_valid[i][j] *= beam.costs[i][j] == 0 - beam.advance(_transition_state, _hash_state, self.moves.c) - beam.check_done(_check_final_state, NULL) - - def _count_feats(self, dict counts, Doc tokens, list hist, int inc): - cdef atom_t[CONTEXT_SIZE] context - cdef Pool mem = Pool() - cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) - self.moves.initialize_state(stcls) - - cdef class_t clas - cdef int n_feats - for clas in hist: - fill_context(context, stcls) - feats = self.model._extractor.get_feats(context, &n_feats) - count_feats(counts[clas], feats, n_feats, inc) - self.moves.c[clas].do(stcls, self.moves.c[clas].label) - - -# These are passed as callbacks to thinc.search.Beam - -cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: - dest = _dest - src = _src - moves = _moves - dest.clone(src) - moves[clas].do(dest, moves[clas].label) - - -cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: - cdef StateClass st = StateClass.init(tokens, length) - st.fast_forward() - Py_INCREF(st) - return st - - -cdef int _check_final_state(void* _state, void* extra_args) except -1: - return (_state).is_final() - - -def _cleanup(Beam beam): - for i in range(beam.width): - Py_XDECREF(beam._states[i].content) - Py_XDECREF(beam._parents[i].content) - -cdef hash_t _hash_state(void* _state, void* _) except 0: - return _state - - #state = _state - #cdef atom_t[10] rep - - #rep[0] = state.stack[0] if state.stack_len >= 1 else 0 - #rep[1] = state.stack[-1] if state.stack_len >= 2 else 0 - #rep[2] = state.stack[-2] if state.stack_len >= 3 else 0 - #rep[3] = state.i - #rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0 - #rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0 - #rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0 - #rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0 - #if get_left(state, get_n0(state), 1) != NULL: - # rep[8] = get_left(state, get_n0(state), 1).dep - #else: - # rep[8] = 0 - #rep[9] = state.sent[state.i].l_kids - #return hash64(rep, sizeof(atom_t) * 10, 0) diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 2ce87c79a..2569462a0 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -52,7 +52,11 @@ cdef class StateClass: cdef const TokenC* target = &self._sent[i] if target.l_kids < idx: return -1 +<<<<<<< HEAD cdef const TokenC* ptr = &self._sent[target.l_edge] +======= + cdef const TokenC* ptr = self._sent +>>>>>>> neuralnet while ptr < target: # If this head is still to the right of us, we can skip to it @@ -78,7 +82,11 @@ cdef class StateClass: cdef const TokenC* target = &self._sent[i] if target.r_kids < idx: return -1 +<<<<<<< HEAD cdef const TokenC* ptr = &self._sent[target.r_edge] +======= + cdef const TokenC* ptr = self._sent + (self.length - 1) +>>>>>>> neuralnet while ptr > target: # If this head is still to the right of us, we can skip to it # No token that's between this token and this head could be our diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index d9bd2b3e6..35f0ada30 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -46,9 +46,5 @@ cdef class TransitionSystem: cdef int set_valid(self, bint* output, StateClass state) except -1 - cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1 - - cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except * - - cdef Transition best_gold(self, const weight_t* scores, StateClass state, - GoldParse gold) except * + cdef int set_costs(self, bint* is_valid, int* costs, + StateClass state, GoldParse gold) except -1 diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 927498cba..b13c75ba3 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -43,30 +43,17 @@ cdef class TransitionSystem: cdef Transition init_transition(self, int clas, int move, int label) except *: raise NotImplementedError - cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *: - raise NotImplementedError - - cdef int set_valid(self, bint* output, StateClass state) except -1: - raise NotImplementedError - - cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1: + cdef int set_valid(self, bint* is_valid, StateClass stcls) except -1: cdef int i for i in range(self.n_moves): - if self.c[i].is_valid(stcls, self.c[i].label): - output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label) + is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label) + + cdef int set_costs(self, bint* is_valid, int* costs, + StateClass stcls, GoldParse gold) except -1: + cdef int i + self.set_valid(is_valid, stcls) + for i in range(self.n_moves): + if is_valid[i]: + costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label) else: - output[i] = 9000 - - cdef Transition best_gold(self, const weight_t* scores, StateClass stcls, - GoldParse gold) except *: - cdef Transition best - cdef weight_t score = MIN_SCORE - cdef int i - for i in range(self.n_moves): - if self.c[i].is_valid(stcls, self.c[i].label): - cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label) - if scores[i] > score and cost == 0: - best = self.c[i] - score = scores[i] - assert score > MIN_SCORE - return best + costs[i] = 9000