mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Merge branch 'neuralnet' into refactor
Mostly refactors parser, to use new thinc3.2 Example class. Aim is to remove use of shared memory, so that we can parallelize over documents easily. Conflicts: setup.py spacy/syntax/parser.pxd spacy/syntax/parser.pyx spacy/syntax/stateclass.pyx
This commit is contained in:
commit
38ca0c33f5
261
bin/parser/nn_train.py
Executable file
261
bin/parser/nn_train.py
Executable file
|
@ -0,0 +1,261 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from os import path
|
||||
import shutil
|
||||
import codecs
|
||||
import random
|
||||
|
||||
import plac
|
||||
import cProfile
|
||||
import pstats
|
||||
import re
|
||||
|
||||
import spacy.util
|
||||
from spacy.en import English
|
||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||
|
||||
from spacy.syntax.util import Config
|
||||
from spacy.gold import read_json_file
|
||||
from spacy.gold import GoldParse
|
||||
|
||||
from spacy.scorer import Scorer
|
||||
|
||||
from spacy.syntax.parser import Parser, get_templates
|
||||
from spacy._theano import TheanoModel
|
||||
|
||||
import theano
|
||||
import theano.tensor as T
|
||||
|
||||
from theano.printing import Print
|
||||
|
||||
import numpy
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
||||
theano.config.profile = False
|
||||
theano.config.floatX = 'float32'
|
||||
floatX = theano.config.floatX
|
||||
|
||||
|
||||
def L1(L1_reg, *weights):
|
||||
return L1_reg * sum(abs(w).sum() for w in weights)
|
||||
|
||||
|
||||
def L2(L2_reg, *weights):
|
||||
return L2_reg * sum((w ** 2).sum() for w in weights)
|
||||
|
||||
|
||||
def rms_prop(loss, params, eta=1.0, rho=0.9, eps=1e-6):
|
||||
updates = OrderedDict()
|
||||
for param in params:
|
||||
value = param.get_value(borrow=True)
|
||||
accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
|
||||
broadcastable=param.broadcastable)
|
||||
|
||||
grad = T.grad(loss, param)
|
||||
accu_new = rho * accu + (1 - rho) * grad ** 2
|
||||
updates[accu] = accu_new
|
||||
updates[param] = param - (eta * grad / T.sqrt(accu_new + eps))
|
||||
return updates
|
||||
|
||||
|
||||
def relu(x):
|
||||
return x * (x > 0)
|
||||
|
||||
|
||||
def feed_layer(activation, weights, bias, input_):
|
||||
return activation(T.dot(input_, weights) + bias)
|
||||
|
||||
|
||||
def init_weights(n_in, n_out):
|
||||
rng = numpy.random.RandomState(1235)
|
||||
|
||||
weights = numpy.asarray(
|
||||
rng.standard_normal(size=(n_in, n_out)) * numpy.sqrt(2.0 / n_in),
|
||||
dtype=theano.config.floatX
|
||||
)
|
||||
bias = numpy.zeros((n_out,), dtype=theano.config.floatX)
|
||||
return [wrapper(weights, name='W'), wrapper(bias, name='b')]
|
||||
|
||||
|
||||
def compile_model(n_classes, n_hidden, n_in, optimizer):
|
||||
x = T.vector('x')
|
||||
costs = T.ivector('costs')
|
||||
loss = T.scalar('loss')
|
||||
|
||||
maxent_W, maxent_b = init_weights(n_hidden, n_classes)
|
||||
hidden_W, hidden_b = init_weights(n_in, n_hidden)
|
||||
|
||||
# Feed the inputs forward through the network
|
||||
p_y_given_x = feed_layer(
|
||||
T.nnet.softmax,
|
||||
maxent_W,
|
||||
maxent_b,
|
||||
feed_layer(
|
||||
relu,
|
||||
hidden_W,
|
||||
hidden_b,
|
||||
x))
|
||||
|
||||
loss = -T.log(T.sum(p_y_given_x[0] * T.eq(costs, 0)) + 1e-8)
|
||||
|
||||
train_model = theano.function(
|
||||
name='train_model',
|
||||
inputs=[x, costs],
|
||||
outputs=[p_y_given_x[0], T.grad(loss, x), loss],
|
||||
updates=optimizer(loss, [maxent_W, maxent_b, hidden_W, hidden_b]),
|
||||
on_unused_input='warn'
|
||||
)
|
||||
|
||||
evaluate_model = theano.function(
|
||||
name='evaluate_model',
|
||||
inputs=[x],
|
||||
outputs=[
|
||||
feed_layer(
|
||||
T.nnet.softmax,
|
||||
maxent_W,
|
||||
maxent_b,
|
||||
feed_layer(
|
||||
relu,
|
||||
hidden_W,
|
||||
hidden_b,
|
||||
x
|
||||
)
|
||||
)[0]
|
||||
]
|
||||
)
|
||||
return train_model, evaluate_model
|
||||
|
||||
|
||||
def score_model(scorer, nlp, annot_tuples, verbose=False):
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
|
||||
|
||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||
eta=0.01, mu=0.9, nv_hidden=100, nv_word=10, nv_tag=10, nv_label=10,
|
||||
seed=0, n_sents=0, verbose=False):
|
||||
|
||||
dep_model_dir = path.join(model_dir, 'deps')
|
||||
pos_model_dir = path.join(model_dir, 'pos')
|
||||
if path.exists(dep_model_dir):
|
||||
shutil.rmtree(dep_model_dir)
|
||||
if path.exists(pos_model_dir):
|
||||
shutil.rmtree(pos_model_dir)
|
||||
os.mkdir(dep_model_dir)
|
||||
os.mkdir(pos_model_dir)
|
||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
||||
|
||||
Config.write(dep_model_dir, 'config',
|
||||
seed=seed,
|
||||
templates=tuple(),
|
||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||
vector_lengths=(nv_word, nv_tag, nv_label),
|
||||
hidden_nodes=nv_hidden,
|
||||
eta=eta,
|
||||
mu=mu
|
||||
)
|
||||
|
||||
# Bake-in hyper-parameters
|
||||
optimizer = lambda loss, params: rms_prop(loss, params, eta=eta, rho=rho, eps=eps)
|
||||
nlp = Language(data_dir=model_dir)
|
||||
n_classes = nlp.parser.model.n_classes
|
||||
train, predict = compile_model(n_classes, nv_hidden, n_in, optimizer)
|
||||
nlp.parser.model = TheanoModel(n_classes, input_spec, train,
|
||||
predict, model_loc)
|
||||
|
||||
if n_sents > 0:
|
||||
gold_tuples = gold_tuples[:n_sents]
|
||||
print "Itn.\tP.Loss\tUAS\tTag %\tToken %"
|
||||
log_loc = path.join(model_dir, 'job.log')
|
||||
for itn in range(n_iter):
|
||||
scorer = Scorer()
|
||||
loss = 0
|
||||
for _, sents in gold_tuples:
|
||||
for annot_tuples, ctnt in sents:
|
||||
if len(annot_tuples[1]) == 1:
|
||||
continue
|
||||
score_model(scorer, nlp, annot_tuples)
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
||||
assert gold.is_projective
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
random.shuffle(gold_tuples)
|
||||
logline = '%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
|
||||
scorer.tags_acc,
|
||||
scorer.token_acc)
|
||||
print logline
|
||||
with open(log_loc, 'aw') as file_:
|
||||
file_.write(logline + '\n')
|
||||
nlp.parser.model.end_training()
|
||||
nlp.tagger.model.end_training()
|
||||
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
|
||||
return nlp
|
||||
|
||||
|
||||
def evaluate(nlp, gold_tuples, gold_preproc=True):
|
||||
scorer = Scorer()
|
||||
for raw_text, sents in gold_tuples:
|
||||
for annot_tuples, brackets in sents:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold)
|
||||
return scorer
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
train_loc=("Location of training file or directory"),
|
||||
dev_loc=("Location of development file or directory"),
|
||||
model_dir=("Location of output model directory",),
|
||||
eval_only=("Skip training, and only evaluate", "flag", "e", bool),
|
||||
n_sents=("Number of training sentences", "option", "n", int),
|
||||
n_iter=("Number of training iterations", "option", "i", int),
|
||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||
|
||||
nv_word=("Word vector length", "option", "W", int),
|
||||
nv_tag=("Tag vector length", "option", "T", int),
|
||||
nv_label=("Label vector length", "option", "L", int),
|
||||
nv_hidden=("Hidden nodes length", "option", "H", int),
|
||||
eta=("Learning rate", "option", "E", float),
|
||||
mu=("Momentum", "option", "M", float),
|
||||
)
|
||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, verbose=False,
|
||||
nv_word=10, nv_tag=10, nv_label=10, nv_hidden=10,
|
||||
eta=0.1, mu=0.9, eval_only=False):
|
||||
|
||||
|
||||
|
||||
|
||||
gold_train = list(read_json_file(train_loc, lambda doc: 'wsj' in doc['id']))
|
||||
|
||||
nlp = train(English, gold_train, model_dir,
|
||||
feat_set='embed',
|
||||
eta=eta, mu=mu,
|
||||
nv_word=nv_word, nv_tag=nv_tag, nv_label=nv_label, nv_hidden=nv_hidden,
|
||||
n_sents=n_sents, n_iter=n_iter,
|
||||
verbose=verbose)
|
||||
|
||||
scorer = evaluate(nlp, list(read_json_file(dev_loc)))
|
||||
|
||||
print 'TOK', 100-scorer.token_acc
|
||||
print 'POS', scorer.tags_acc
|
||||
print 'UAS', scorer.uas
|
||||
print 'LAS', scorer.las
|
||||
|
||||
print 'NER P', scorer.ents_p
|
||||
print 'NER R', scorer.ents_r
|
||||
print 'NER F', scorer.ents_f
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
3
setup.py
3
setup.py
|
@ -152,7 +152,8 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
|
|||
'spacy.lexeme', 'spacy.vocab',
|
||||
'spacy.morphology',
|
||||
'spacy.syntax.stateclass',
|
||||
'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs',
|
||||
'spacy._ml', 'spacy._theano',
|
||||
'spacy.tokenizer', 'spacy.en.attrs',
|
||||
'spacy.en.pos', 'spacy.syntax.parser',
|
||||
'spacy.syntax.transition_system',
|
||||
'spacy.syntax.arc_eager',
|
||||
|
|
490
spacy/_bu_nn.pyx
Normal file
490
spacy/_bu_nn.pyx
Normal file
|
@ -0,0 +1,490 @@
|
|||
"""Feed-forward neural network, using Thenao."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy
|
||||
|
||||
import theano
|
||||
import theano.tensor as T
|
||||
import gzip
|
||||
import cPickle
|
||||
|
||||
|
||||
def load_data(dataset):
|
||||
''' Loads the dataset
|
||||
|
||||
:type dataset: string
|
||||
:param dataset: the path to the dataset (here MNIST)
|
||||
'''
|
||||
|
||||
#############
|
||||
# LOAD DATA #
|
||||
#############
|
||||
|
||||
# Download the MNIST dataset if it is not present
|
||||
data_dir, data_file = os.path.split(dataset)
|
||||
if data_dir == "" and not os.path.isfile(dataset):
|
||||
# Check if dataset is in the data directory.
|
||||
new_path = os.path.join(
|
||||
os.path.split(__file__)[0],
|
||||
"..",
|
||||
"data",
|
||||
dataset
|
||||
)
|
||||
if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz':
|
||||
dataset = new_path
|
||||
|
||||
if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':
|
||||
import urllib
|
||||
origin = (
|
||||
'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
|
||||
)
|
||||
print 'Downloading data from %s' % origin
|
||||
urllib.urlretrieve(origin, dataset)
|
||||
|
||||
print '... loading data'
|
||||
|
||||
# Load the dataset
|
||||
f = gzip.open(dataset, 'rb')
|
||||
train_set, valid_set, test_set = cPickle.load(f)
|
||||
f.close()
|
||||
#train_set, valid_set, test_set format: tuple(input, target)
|
||||
#input is an numpy.ndarray of 2 dimensions (a matrix),
|
||||
#each row corresponding to an example. target is a
|
||||
#numpy.ndarray of 1 dimension (vector)) that have the same length as
|
||||
#the number of rows in the input. It should give the target
|
||||
#target to the example with the same index in the input.
|
||||
|
||||
def shared_dataset(data_xy, borrow=True):
|
||||
""" Function that loads the dataset into shared variables
|
||||
|
||||
The reason we store our dataset in shared variables is to allow
|
||||
Theano to copy it into the GPU memory (when code is run on GPU).
|
||||
Since copying data into the GPU is slow, copying a minibatch everytime
|
||||
is needed (the default behaviour if the data is not in a shared
|
||||
variable) would lead to a large decrease in performance.
|
||||
"""
|
||||
data_x, data_y = data_xy
|
||||
shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX),
|
||||
borrow=borrow)
|
||||
shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX),
|
||||
borrow=borrow)
|
||||
# When storing data on the GPU it has to be stored as floats
|
||||
# therefore we will store the labels as ``floatX`` as well
|
||||
# (``shared_y`` does exactly that). But during our computations
|
||||
# we need them as ints (we use labels as index, and if they are
|
||||
# floats it doesn't make sense) therefore instead of returning
|
||||
# ``shared_y`` we will have to cast it to int. This little hack
|
||||
# lets ous get around this issue
|
||||
return shared_x, T.cast(shared_y, 'int32')
|
||||
|
||||
test_set_x, test_set_y = shared_dataset(test_set)
|
||||
valid_set_x, valid_set_y = shared_dataset(valid_set)
|
||||
train_set_x, train_set_y = shared_dataset(train_set)
|
||||
|
||||
rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
|
||||
(test_set_x, test_set_y)]
|
||||
return rval
|
||||
|
||||
|
||||
class LogisticRegression(object):
|
||||
"""Multi-class Logistic Regression Class
|
||||
|
||||
The logistic regression is fully described by a weight matrix :math:`W`
|
||||
and bias vector :math:`b`. Classification is done by projecting data
|
||||
points onto a set of hyperplanes, the distance to which is used to
|
||||
determine a class membership probability.
|
||||
"""
|
||||
|
||||
def __init__(self, input, n_in, n_out):
|
||||
""" Initialize the parameters of the logistic regression
|
||||
|
||||
:type input: theano.tensor.TensorType
|
||||
:param input: symbolic variable that describes the input of the
|
||||
architecture (one minibatch)
|
||||
|
||||
:type n_in: int
|
||||
:param n_in: number of input units, the dimension of the space in
|
||||
which the datapoints lie
|
||||
|
||||
:type n_out: int
|
||||
:param n_out: number of output units, the dimension of the space in
|
||||
which the labels lie
|
||||
|
||||
"""
|
||||
# start-snippet-1
|
||||
# initialize with 0 the weights W as a matrix of shape (n_in, n_out)
|
||||
self.W = theano.shared(
|
||||
value=numpy.zeros((n_in, n_out),
|
||||
dtype=theano.config.floatX
|
||||
),
|
||||
name='W',
|
||||
borrow=True
|
||||
)
|
||||
# initialize the baises b as a vector of n_out 0s
|
||||
self.b = theano.shared(
|
||||
value=numpy.zeros(
|
||||
(n_out,),
|
||||
dtype=theano.config.floatX
|
||||
),
|
||||
name='b',
|
||||
borrow=True
|
||||
)
|
||||
|
||||
# symbolic expression for computing the matrix of class-membership
|
||||
# probabilities
|
||||
# Where:
|
||||
# W is a matrix where column-k represent the separation hyper plain for
|
||||
# class-k
|
||||
# x is a matrix where row-j represents input training sample-j
|
||||
# b is a vector where element-k represent the free parameter of hyper
|
||||
# plain-k
|
||||
self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)
|
||||
|
||||
# symbolic description of how to compute prediction as class whose
|
||||
# probability is maximal
|
||||
self.y_pred = T.argmax(self.p_y_given_x, axis=1)
|
||||
# end-snippet-1
|
||||
|
||||
# parameters of the model
|
||||
self.params = [self.W, self.b]
|
||||
|
||||
def neg_ll(self, y):
|
||||
"""Return the mean of the negative log-likelihood of the prediction
|
||||
of this model under a given target distribution.
|
||||
|
||||
.. math::
|
||||
|
||||
\frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
|
||||
\frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|}
|
||||
\log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\
|
||||
\ell (\theta=\{W,b\}, \mathcal{D})
|
||||
|
||||
:type y: theano.tensor.TensorType
|
||||
:param y: corresponds to a vector that gives for each example the
|
||||
correct label
|
||||
|
||||
Note: we use the mean instead of the sum so that
|
||||
the learning rate is less dependent on the batch size
|
||||
"""
|
||||
# start-snippet-2
|
||||
# y.shape[0] is (symbolically) the number of rows in y, i.e.,
|
||||
# number of examples (call it n) in the minibatch
|
||||
# T.arange(y.shape[0]) is a symbolic vector which will contain
|
||||
# [0,1,2,... n-1] T.log(self.p_y_given_x) is a matrix of
|
||||
# Log-Probabilities (call it LP) with one row per example and
|
||||
# one column per class LP[T.arange(y.shape[0]),y] is a vector
|
||||
# v containing [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ...,
|
||||
# LP[n-1,y[n-1]]] and T.mean(LP[T.arange(y.shape[0]),y]) is
|
||||
# the mean (across minibatch examples) of the elements in v,
|
||||
# i.e., the mean log-likelihood across the minibatch.
|
||||
return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])
|
||||
# end-snippet-2
|
||||
|
||||
def errors(self, y):
|
||||
"""Return a float representing the number of errors in the minibatch
|
||||
over the total number of examples of the minibatch ; zero one
|
||||
loss over the size of the minibatch
|
||||
|
||||
:type y: theano.tensor.TensorType
|
||||
:param y: corresponds to a vector that gives for each example the
|
||||
correct label
|
||||
"""
|
||||
|
||||
# check if y has same dimension of y_pred
|
||||
if y.ndim != self.y_pred.ndim:
|
||||
raise TypeError(
|
||||
'y should have the same shape as self.y_pred',
|
||||
('y', y.type, 'y_pred', self.y_pred.type)
|
||||
)
|
||||
# check if y is of the correct datatype
|
||||
if y.dtype.startswith('int'):
|
||||
# the T.neq operator returns a vector of 0s and 1s, where 1
|
||||
# represents a mistake in prediction
|
||||
return T.mean(T.neq(self.y_pred, y))
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# start-snippet-1
|
||||
class HiddenLayer(object):
|
||||
def __init__(self, rng, input, n_in, n_out, W=None, b=None,
|
||||
activation=T.tanh):
|
||||
"""
|
||||
Typical hidden layer of a MLP: units are fully-connected and have
|
||||
sigmoidal activation function. Weight matrix W is of shape (n_in,n_out)
|
||||
and the bias vector b is of shape (n_out,).
|
||||
|
||||
NOTE : The nonlinearity used here is tanh
|
||||
|
||||
Hidden unit activation is given by: tanh(dot(input,W) + b)
|
||||
|
||||
:type rng: numpy.random.RandomState
|
||||
:param rng: a random number generator used to initialize weights
|
||||
|
||||
:type input: theano.tensor.dmatrix
|
||||
:param input: a symbolic tensor of shape (n_examples, n_in)
|
||||
|
||||
:type n_in: int
|
||||
:param n_in: dimensionality of input
|
||||
|
||||
:type n_out: int
|
||||
:param n_out: number of hidden units
|
||||
|
||||
:type activation: theano.Op or function
|
||||
:param activation: Non linearity to be applied in the hidden
|
||||
layer
|
||||
"""
|
||||
self.input = input
|
||||
# end-snippet-1
|
||||
|
||||
# `W` is initialized with `W_values` which is uniformely sampled
|
||||
# from sqrt(-6./(n_in+n_hidden)) and sqrt(6./(n_in+n_hidden))
|
||||
# for tanh activation function
|
||||
# the output of uniform if converted using asarray to dtype
|
||||
# theano.config.floatX so that the code is runable on GPU
|
||||
# Note : optimal initialization of weights is dependent on the
|
||||
# activation function used (among other things).
|
||||
# For example, results presented in [Xavier10] suggest that you
|
||||
# should use 4 times larger initial weights for sigmoid
|
||||
# compared to tanh
|
||||
# We have no info for other function, so we use the same as
|
||||
# tanh.
|
||||
if W is None:
|
||||
W_values = numpy.asarray(
|
||||
rng.uniform(
|
||||
low=-numpy.sqrt(6. / (n_in + n_out)),
|
||||
high=numpy.sqrt(6. / (n_in + n_out)),
|
||||
size=(n_in, n_out)
|
||||
),
|
||||
dtype=theano.config.floatX
|
||||
)
|
||||
if activation == theano.tensor.nnet.sigmoid:
|
||||
W_values *= 4
|
||||
|
||||
W = theano.shared(value=W_values, name='W', borrow=True)
|
||||
|
||||
if b is None:
|
||||
b_values = numpy.zeros((n_out,), dtype=theano.config.floatX)
|
||||
b = theano.shared(value=b_values, name='b', borrow=True)
|
||||
|
||||
self.W = W
|
||||
self.b = b
|
||||
|
||||
lin_output = T.dot(input, self.W) + self.b
|
||||
self.output = (
|
||||
lin_output if activation is None
|
||||
else activation(lin_output)
|
||||
)
|
||||
# parameters of the model
|
||||
self.params = [self.W, self.b]
|
||||
|
||||
|
||||
# start-snippet-2
|
||||
class MLP(object):
|
||||
"""Multi-Layer Perceptron Class
|
||||
|
||||
A multilayer perceptron is a feedforward artificial neural network model
|
||||
that has one layer or more of hidden units and nonlinear activations.
|
||||
Intermediate layers usually have as activation function tanh or the
|
||||
sigmoid function (defined here by a ``HiddenLayer`` class) while the
|
||||
top layer is a softmax layer (defined here by a ``LogisticRegression``
|
||||
class).
|
||||
"""
|
||||
|
||||
def __init__(self, rng, input, n_in, n_hidden, n_out):
|
||||
"""Initialize the parameters for the multilayer perceptron
|
||||
|
||||
:type rng: numpy.random.RandomState
|
||||
:param rng: a random number generator used to initialize weights
|
||||
|
||||
:type input: theano.tensor.TensorType
|
||||
:param input: symbolic variable that describes the input of the
|
||||
architecture (one minibatch)
|
||||
|
||||
:type n_in: int
|
||||
:param n_in: number of input units, the dimension of the space in
|
||||
which the datapoints lie
|
||||
|
||||
:type n_hidden: int
|
||||
:param n_hidden: number of hidden units
|
||||
|
||||
:type n_out: int
|
||||
:param n_out: number of output units, the dimension of the space in
|
||||
which the labels lie
|
||||
|
||||
"""
|
||||
|
||||
# Since we are dealing with a one hidden layer MLP, this will translate
|
||||
# into a HiddenLayer with a tanh activation function connected to the
|
||||
# LogisticRegression layer; the activation function can be replaced by
|
||||
# sigmoid or any other nonlinear function
|
||||
self.hidden = HiddenLayer(
|
||||
rng=rng,
|
||||
input=input,
|
||||
n_in=n_in,
|
||||
n_out=n_hidden,
|
||||
activation=T.tanh
|
||||
)
|
||||
|
||||
# The logistic regression layer gets as input the hidden units
|
||||
# of the hidden layer
|
||||
self.maxent = LogisticRegression(
|
||||
input=self.hidden.output,
|
||||
n_in=n_hidden,
|
||||
n_out=n_out
|
||||
)
|
||||
# L1 norm ; one regularization option is to enforce L1 norm to
|
||||
# be small
|
||||
self.L1 = abs(self.hidden.W).sum() + abs(self.maxent.W).sum()
|
||||
|
||||
# square of L2 norm ; one regularization option is to enforce
|
||||
# square of L2 norm to be small
|
||||
self.L2_sqr = (self.hidden.W ** 2).sum() + (self.maxent.W ** 2).sum()
|
||||
|
||||
# negative log likelihood of the MLP is given by the negative
|
||||
# log likelihood of the output of the model, computed in the
|
||||
# logistic regression layer
|
||||
self.neg_ll = self.maxent.neg_ll
|
||||
# same holds for the function computing the number of errors
|
||||
self.errors = self.maxent.errors
|
||||
|
||||
# the parameters of the model are the parameters of the two layer it is
|
||||
# made out of
|
||||
self.params = self.hidden.params + self.maxent.params
|
||||
|
||||
|
||||
|
||||
|
||||
def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000,
|
||||
dataset='mnist.pkl.gz', batch_size=1, n_hidden=500):
|
||||
"""
|
||||
Demonstrate stochastic gradient descent optimization for a multilayer
|
||||
perceptron
|
||||
|
||||
This is demonstrated on MNIST.
|
||||
|
||||
:type learning_rate: float
|
||||
:param learning_rate: learning rate used (factor for the stochastic
|
||||
gradient
|
||||
|
||||
:type L1_reg: float
|
||||
:param L1_reg: L1-norm's weight when added to the cost (see
|
||||
regularization)
|
||||
|
||||
:type L2_reg: float
|
||||
:param L2_reg: L2-norm's weight when added to the cost (see
|
||||
regularization)
|
||||
|
||||
:type n_epochs: int
|
||||
:param n_epochs: maximal number of epochs to run the optimizer
|
||||
|
||||
:type dataset: string
|
||||
:param dataset: the path of the MNIST dataset file from
|
||||
http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz
|
||||
"""
|
||||
datasets = load_data(dataset)
|
||||
|
||||
train_set_x, train_set_y = datasets[0]
|
||||
valid_set_x, valid_set_y = datasets[1]
|
||||
test_set_x, test_set_y = datasets[2]
|
||||
|
||||
######################
|
||||
# BUILD ACTUAL MODEL #
|
||||
######################
|
||||
print '... building the model'
|
||||
|
||||
# allocate symbolic variables for the data
|
||||
index = T.lscalar() # index to a [mini]batch
|
||||
x = T.matrix('x') # the data is presented as rasterized images
|
||||
y = T.ivector('y') # the labels are presented as 1D vector of
|
||||
# [int] labels
|
||||
|
||||
rng = numpy.random.RandomState(1234)
|
||||
|
||||
# construct the MLP class
|
||||
mlp = MLP(
|
||||
rng=rng,
|
||||
input=x,
|
||||
n_in=28 * 28,
|
||||
n_hidden=n_hidden,
|
||||
n_out=10
|
||||
)
|
||||
|
||||
# the cost we minimize during training is the negative log likelihood of
|
||||
# the model plus the regularization terms (L1 and L2); cost is expressed
|
||||
# here symbolically
|
||||
|
||||
# compiling a Theano function that computes the mistakes that are made
|
||||
# by the model on a minibatch
|
||||
test_model = theano.function(
|
||||
inputs=[index],
|
||||
outputs=mlp.maxent.errors(y),
|
||||
givens={
|
||||
x: test_set_x[index:index+1],
|
||||
y: test_set_y[index:index+1]
|
||||
}
|
||||
)
|
||||
|
||||
validate_model = theano.function(
|
||||
inputs=[index],
|
||||
outputs=mlp.maxent.errors(y),
|
||||
givens={
|
||||
x: valid_set_x[index:index+1],
|
||||
y: valid_set_y[index:index+1]
|
||||
}
|
||||
)
|
||||
|
||||
# compute the gradient of cost with respect to theta (sotred in params)
|
||||
# the resulting gradients will be stored in a list gparams
|
||||
cost = mlp.neg_ll(y) + L1_reg * mlp.L1 + L2_reg * mlp.L2_sqr
|
||||
gparams = [T.grad(cost, param) for param in mlp.params]
|
||||
|
||||
# specify how to update the parameters of the model as a list of
|
||||
# (variable, update expression) pairs
|
||||
|
||||
updates = [(mlp.params[i], mlp.params[i] - (learning_rate * gparams[i]))
|
||||
for i in xrange(len(gparams))]
|
||||
|
||||
# compiling a Theano function `train_model` that returns the cost, but
|
||||
# in the same time updates the parameter of the model based on the rules
|
||||
# defined in `updates`
|
||||
train_model = theano.function(
|
||||
inputs=[index],
|
||||
outputs=cost,
|
||||
updates=updates,
|
||||
givens={
|
||||
x: train_set_x[index:index+1],
|
||||
y: train_set_y[index:index+1]
|
||||
}
|
||||
)
|
||||
# end-snippet-5
|
||||
|
||||
###############
|
||||
# TRAIN MODEL #
|
||||
###############
|
||||
print '... training'
|
||||
|
||||
start_time = time.clock()
|
||||
|
||||
n_examples = train_set_x.get_value(borrow=True).shape[0]
|
||||
n_dev_examples = valid_set_x.get_value(borrow=True).shape[0]
|
||||
n_test_examples = test_set_x.get_value(borrow=True).shape[0]
|
||||
|
||||
for epoch in range(1, n_epochs+1):
|
||||
for idx in xrange(n_examples):
|
||||
train_model(idx)
|
||||
# compute zero-one loss on validation set
|
||||
error = numpy.mean(map(validate_model, xrange(n_dev_examples)))
|
||||
print('epoch %i, validation error %f %%' % (epoch, error * 100))
|
||||
|
||||
end_time = time.clock()
|
||||
print >> sys.stderr, ('The code for file ' +
|
||||
os.path.split(__file__)[1] +
|
||||
' ran for %.2fm' % ((end_time - start_time) / 60.))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mlp()
|
|
@ -5,6 +5,7 @@ from cymem.cymem cimport Pool
|
|||
from thinc.learner cimport LinearModel
|
||||
from thinc.features cimport Extractor, Feature
|
||||
from thinc.typedefs cimport atom_t, feat_t, weight_t, class_t
|
||||
from thinc.api cimport ExampleC
|
||||
|
||||
from preshed.maps cimport PreshMapArray
|
||||
|
||||
|
@ -13,9 +14,14 @@ from .typedefs cimport hash_t, id_t
|
|||
|
||||
cdef int arg_max(const weight_t* scores, const int n_classes) nogil
|
||||
|
||||
cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, int n_classes) nogil
|
||||
|
||||
cdef int arg_max_if_zero(const weight_t* scores, const int* costs, int n_classes) nogil
|
||||
|
||||
|
||||
cdef class Model:
|
||||
cdef int n_classes
|
||||
cdef readonly int n_classes
|
||||
cdef readonly int n_feats
|
||||
|
||||
cdef const weight_t* score(self, atom_t* context) except NULL
|
||||
cdef int set_scores(self, weight_t* scores, atom_t* context) except -1
|
||||
|
|
|
@ -10,6 +10,7 @@ import cython
|
|||
import numpy.random
|
||||
|
||||
from thinc.features cimport Feature, count_feats
|
||||
from thinc.api cimport Example
|
||||
|
||||
|
||||
cdef int arg_max(const weight_t* scores, const int n_classes) nogil:
|
||||
|
@ -23,17 +24,52 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil:
|
|||
return best
|
||||
|
||||
|
||||
cdef int arg_max_if_true(const weight_t* scores, const int* is_valid,
|
||||
const int n_classes) nogil:
|
||||
cdef int i
|
||||
cdef int best = 0
|
||||
cdef weight_t mode = -900000
|
||||
for i in range(n_classes):
|
||||
if is_valid[i] and scores[i] > mode:
|
||||
mode = scores[i]
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
cdef int arg_max_if_zero(const weight_t* scores, const int* costs,
|
||||
const int n_classes) nogil:
|
||||
cdef int i
|
||||
cdef int best = 0
|
||||
cdef weight_t mode = -900000
|
||||
for i in range(n_classes):
|
||||
if costs[i] == 0 and scores[i] > mode:
|
||||
mode = scores[i]
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
cdef class Model:
|
||||
def __init__(self, n_classes, templates, model_loc=None):
|
||||
if model_loc is not None and path.isdir(model_loc):
|
||||
model_loc = path.join(model_loc, 'model')
|
||||
self.n_classes = n_classes
|
||||
self._extractor = Extractor(templates)
|
||||
self.n_feats = self._extractor.n_templ
|
||||
self._model = LinearModel(n_classes, self._extractor.n_templ)
|
||||
self.model_loc = model_loc
|
||||
if self.model_loc and path.exists(self.model_loc):
|
||||
self._model.load(self.model_loc, freq_thresh=0)
|
||||
|
||||
def predict(self, Example eg):
|
||||
self.set_scores(eg.c.scores, eg.c.atoms)
|
||||
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
|
||||
def train(self, Example eg):
|
||||
self.predict(eg)
|
||||
eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
|
||||
eg.c.cost = eg.c.costs[eg.c.guess]
|
||||
self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost)
|
||||
|
||||
cdef const weight_t* score(self, atom_t* context) except NULL:
|
||||
cdef int n_feats
|
||||
feats = self._extractor.get_feats(context, &n_feats)
|
||||
|
|
3
spacy/_nn.py
Normal file
3
spacy/_nn.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Feed-forward neural network, using Thenao."""
|
||||
|
||||
|
146
spacy/_nn.pyx
Normal file
146
spacy/_nn.pyx
Normal file
|
@ -0,0 +1,146 @@
|
|||
"""Feed-forward neural network, using Thenao."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy
|
||||
|
||||
import theano
|
||||
import theano.tensor as T
|
||||
import plac
|
||||
|
||||
from spacy.gold import read_json_file
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||
|
||||
|
||||
def build_model(n_classes, n_vocab, n_hidden, n_word_embed, n_tag_embed):
|
||||
# allocate symbolic variables for the data
|
||||
words = T.vector('words')
|
||||
tags = T.vector('tags')
|
||||
|
||||
word_e = _init_embedding(n_words, n_word_embed)
|
||||
tag_e = _init_embedding(n_tags, n_tag_embed)
|
||||
label_e = _init_embedding(n_labels, n_label_embed)
|
||||
maxent_W, maxent_b = _init_maxent_weights(n_hidden, n_classes)
|
||||
hidden_W, hidden_b = _init_hidden_weights(28*28, n_hidden, T.tanh)
|
||||
params = [hidden_W, hidden_b, maxent_W, maxent_b, word_e, tag_e, label_e]
|
||||
|
||||
x = T.concatenate([
|
||||
T.flatten(word_e[word_indices], outdim=1),
|
||||
T.flatten(tag_e[tag_indices], outdim=1)])
|
||||
|
||||
p_y_given_x = feed_layer(
|
||||
T.nnet.softmax,
|
||||
maxent_W,
|
||||
maxent_b,
|
||||
feed_layer(
|
||||
T.tanh,
|
||||
hidden_W,
|
||||
hidden_b,
|
||||
x))[0]
|
||||
|
||||
guess = T.argmax(p_y_given_x)
|
||||
|
||||
cost = (
|
||||
-T.log(p_y_given_x[y])
|
||||
+ L1(L1_reg, maxent_W, hidden_W, word_e, tag_e)
|
||||
+ L2(L2_reg, maxent_W, hidden_W, wod_e, tag_e)
|
||||
)
|
||||
|
||||
train_model = theano.function(
|
||||
inputs=[words, tags, y],
|
||||
outputs=guess,
|
||||
updates=[update(learning_rate, param, cost) for param in params]
|
||||
)
|
||||
|
||||
evaluate_model = theano.function(
|
||||
inputs=[x, y],
|
||||
outputs=T.neq(y, T.argmax(p_y_given_x[0])),
|
||||
)
|
||||
return train_model, evaluate_model
|
||||
|
||||
|
||||
def _init_embedding(vocab_size, n_dim):
|
||||
embedding = 0.2 * numpy.random.uniform(-1.0, 1.0, (vocab_size+1, n_dim))
|
||||
return theano.shared(embedding).astype(theano.config.floatX)
|
||||
|
||||
|
||||
def _init_maxent_weights(n_hidden, n_out):
|
||||
weights = numpy.zeros((n_hidden, 10), dtype=theano.config.floatX)
|
||||
bias = numpy.zeros((10,), dtype=theano.config.floatX)
|
||||
return (
|
||||
theano.shared(name='W', borrow=True, value=weights),
|
||||
theano.shared(name='b', borrow=True, value=bias)
|
||||
)
|
||||
|
||||
|
||||
def _init_hidden_weights(n_in, n_out, activation=T.tanh):
|
||||
rng = numpy.random.RandomState(1234)
|
||||
weights = numpy.asarray(
|
||||
rng.uniform(
|
||||
low=-numpy.sqrt(6. / (n_in + n_out)),
|
||||
high=numpy.sqrt(6. / (n_in + n_out)),
|
||||
size=(n_in, n_out)
|
||||
),
|
||||
dtype=theano.config.floatX
|
||||
)
|
||||
|
||||
bias = numpy.zeros((n_out,), dtype=theano.config.floatX)
|
||||
return (
|
||||
theano.shared(value=weights, name='W', borrow=True),
|
||||
theano.shared(value=bias, name='b', borrow=True)
|
||||
)
|
||||
|
||||
|
||||
def feed_layer(activation, weights, bias, input):
|
||||
return activation(T.dot(input, weights) + bias)
|
||||
|
||||
|
||||
def L1(L1_reg, w1, w2):
|
||||
return L1_reg * (abs(w1).sum() + abs(w2).sum())
|
||||
|
||||
|
||||
def L2(L2_reg, w1, w2):
|
||||
return L2_reg * ((w1 ** 2).sum() + (w2 ** 2).sum())
|
||||
|
||||
|
||||
def update(eta, param, cost):
|
||||
return (param, param - (eta * T.grad(cost, param)))
|
||||
|
||||
|
||||
def main(train_loc, eval_loc, model_dir):
|
||||
learning_rate = 0.01
|
||||
L1_reg = 0.00
|
||||
L2_reg = 0.0001
|
||||
|
||||
print "... reading the data"
|
||||
gold_train = list(read_json_file(train_loc))
|
||||
print '... building the model'
|
||||
pos_model_dir = path.join(model_dir, 'pos')
|
||||
if path.exists(pos_model_dir):
|
||||
shutil.rmtree(pos_model_dir)
|
||||
os.mkdir(pos_model_dir)
|
||||
|
||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
||||
|
||||
train_model, evaluate_model = build_model(n_hidden, len(POS_TAGS), learning_rate,
|
||||
L1_reg, L2_reg)
|
||||
|
||||
print '... training'
|
||||
for epoch in range(1, n_epochs+1):
|
||||
for raw_text, sents in gold_tuples:
|
||||
for (ids, words, tags, ner, heads, deps), _ in sents:
|
||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
||||
for t in tokens:
|
||||
guess = train_model([t.orth], [t.tag])
|
||||
loss += guess != t.tag
|
||||
print loss
|
||||
# compute zero-one loss on validation set
|
||||
#error = numpy.mean([evaluate_model(x, y) for x, y in dev_examples])
|
||||
#print('epoch %i, validation error %f %%' % (epoch, error * 100))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
13
spacy/_theano.pxd
Normal file
13
spacy/_theano.pxd
Normal file
|
@ -0,0 +1,13 @@
|
|||
from ._ml cimport Model
|
||||
from thinc.nn cimport InputLayer
|
||||
|
||||
|
||||
cdef class TheanoModel(Model):
|
||||
cdef InputLayer input_layer
|
||||
cdef object train_func
|
||||
cdef object predict_func
|
||||
cdef object debug
|
||||
|
||||
cdef public float eta
|
||||
cdef public float mu
|
||||
cdef public float t
|
52
spacy/_theano.pyx
Normal file
52
spacy/_theano.pyx
Normal file
|
@ -0,0 +1,52 @@
|
|||
from thinc.api cimport Example, ExampleC
|
||||
from thinc.typedefs cimport weight_t
|
||||
|
||||
from ._ml cimport arg_max_if_true
|
||||
from ._ml cimport arg_max_if_zero
|
||||
|
||||
import numpy
|
||||
from os import path
|
||||
|
||||
|
||||
cdef class TheanoModel(Model):
|
||||
def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None,
|
||||
eta=0.001, mu=0.9, debug=None):
|
||||
if model_loc is not None and path.isdir(model_loc):
|
||||
model_loc = path.join(model_loc, 'model')
|
||||
|
||||
self.eta = eta
|
||||
self.mu = mu
|
||||
self.t = 1
|
||||
initializer = lambda: 0.2 * numpy.random.uniform(-1.0, 1.0)
|
||||
self.input_layer = InputLayer(input_spec, initializer)
|
||||
self.train_func = train_func
|
||||
self.predict_func = predict_func
|
||||
self.debug = debug
|
||||
|
||||
self.n_classes = n_classes
|
||||
self.n_feats = len(self.input_layer)
|
||||
self.model_loc = model_loc
|
||||
|
||||
def predict(self, Example eg):
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=True)
|
||||
theano_scores = self.predict_func(eg.embeddings)[0]
|
||||
cdef int i
|
||||
for i in range(self.n_classes):
|
||||
eg.c.scores[i] = theano_scores[i]
|
||||
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
|
||||
def train(self, Example eg):
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=False)
|
||||
theano_scores, update, y, loss = self.train_func(eg.embeddings, eg.costs,
|
||||
self.eta, self.mu)
|
||||
self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
|
||||
for i in range(self.n_classes):
|
||||
eg.c.scores[i] = theano_scores[i]
|
||||
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
|
||||
eg.c.cost = eg.c.costs[eg.c.guess]
|
||||
eg.c.loss = loss
|
||||
self.t += 1
|
||||
|
||||
def end_training(self):
|
||||
pass
|
|
@ -398,7 +398,8 @@ cdef class ArcEager(TransitionSystem):
|
|||
n_valid += output[i]
|
||||
assert n_valid >= 1
|
||||
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int set_costs(self, bint* is_valid, int* costs,
|
||||
StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i, move, label
|
||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||
|
@ -423,30 +424,14 @@ cdef class ArcEager(TransitionSystem):
|
|||
n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
is_valid[i] = True
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
n_gold += output[i] == 0
|
||||
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
n_gold += costs[i] == 0
|
||||
else:
|
||||
output[i] = 9000
|
||||
is_valid[i] = False
|
||||
costs[i] = 9000
|
||||
assert n_gold >= 1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls, -1)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if scores[i] > score and is_valid[self.c[i].move]:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
assert best.clas < self.n_moves
|
||||
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
|
||||
return best
|
||||
|
|
17
spacy/syntax/joint.pxd
Normal file
17
spacy/syntax/joint.pxd
Normal file
|
@ -0,0 +1,17 @@
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
from thinc.typedefs cimport weight_t
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
pass
|
||||
|
||||
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil
|
||||
|
452
spacy/syntax/joint.pyx
Normal file
452
spacy/syntax/joint.pyx
Normal file
|
@ -0,0 +1,452 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
from ..structs cimport TokenC
|
||||
|
||||
from .transition_system cimport do_func_t, get_cost_func_t
|
||||
from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
from libc.stdint cimport uint32_t
|
||||
from libc.string cimport memcpy
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
DEF USE_ROOT_ARC_SEGMENT = True
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
|
||||
# Break transition from here
|
||||
# http://www.aclweb.org/anthology/P13-1074
|
||||
cdef enum:
|
||||
SHIFT
|
||||
REDUCE
|
||||
LEFT
|
||||
RIGHT
|
||||
|
||||
BREAK
|
||||
|
||||
N_MOVES
|
||||
|
||||
|
||||
MOVE_NAMES = [None] * N_MOVES
|
||||
MOVE_NAMES[SHIFT] = 'S'
|
||||
MOVE_NAMES[REDUCE] = 'D'
|
||||
MOVE_NAMES[LEFT] = 'L'
|
||||
MOVE_NAMES[RIGHT] = 'R'
|
||||
MOVE_NAMES[BREAK] = 'B'
|
||||
|
||||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef int cost = 0
|
||||
cdef int i, S_i
|
||||
for i in range(stcls.stack_depth()):
|
||||
S_i = stcls.S(i)
|
||||
if gold.heads[target] == S_i:
|
||||
cost += 1
|
||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
||||
cost += 1
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef int cost = 0
|
||||
cdef int i, B_i
|
||||
for i in range(stcls.buffer_length()):
|
||||
B_i = stcls.B(i)
|
||||
cost += gold.heads[B_i] == target
|
||||
cost += gold.heads[target] == B_i
|
||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||
break
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
|
||||
if arc_is_gold(gold, head, child):
|
||||
return 0
|
||||
elif stcls.H(child) == gold.heads[child]:
|
||||
return 1
|
||||
# Head in buffer
|
||||
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
elif USE_ROOT_ARC_SEGMENT and _is_gold_root(gold, head) and _is_gold_root(gold, child):
|
||||
return True
|
||||
elif gold.heads[child] == head:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
elif label == -1:
|
||||
return True
|
||||
elif gold.labels[child] == label:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||
return gold.labels[word] == -1 or gold.heads[word] == word
|
||||
|
||||
|
||||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
|
||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class Reduce:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.stack_depth() >= 2
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
if st.has_head(st.S(0)):
|
||||
st.pop()
|
||||
else:
|
||||
st.unshift()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||
return pop_cost(st, gold, st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
st.pop()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
cdef int cost = 0
|
||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||
return 0
|
||||
else:
|
||||
# Account for deps we might lose between S0 and stack
|
||||
if not s.has_head(s.S(0)):
|
||||
for i in range(1, s.stack_depth()):
|
||||
cost += gold.heads[s.S(i)] == s.S(0)
|
||||
cost += gold.heads[s.S(0)] == s.S(i)
|
||||
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
elif s.shifted[s.B(0)]:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
else:
|
||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||
|
||||
|
||||
cdef class Break:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef int i
|
||||
if not USE_BREAK:
|
||||
return False
|
||||
elif st.at_break():
|
||||
return False
|
||||
elif st.B(0) == 0:
|
||||
return False
|
||||
elif st.stack_depth() < 1:
|
||||
return False
|
||||
elif (st.S(0) + 1) != st.B(0):
|
||||
# Must break at the token boundary
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.set_break(st.B(0))
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
cdef int cost = 0
|
||||
cdef int i, j, S_i, B_i
|
||||
for i in range(s.stack_depth()):
|
||||
S_i = s.S(i)
|
||||
for j in range(s.buffer_length()):
|
||||
B_i = s.B(j)
|
||||
cost += gold.heads[S_i] == B_i
|
||||
cost += gold.heads[B_i] == S_i
|
||||
# Check for sentence boundary --- if it's here, we can't have any deps
|
||||
# between stack and buffer, so rest of action is irrelevant.
|
||||
s0_root = _get_root(s.S(0), gold)
|
||||
b0_root = _get_root(s.B(0), gold)
|
||||
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
|
||||
return cost
|
||||
else:
|
||||
return cost + 1
|
||||
|
||||
@staticmethod
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||
while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
|
||||
word = gold.heads[word]
|
||||
if gold.labels[word] == -1:
|
||||
return -1
|
||||
else:
|
||||
return word
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
@classmethod
|
||||
def get_labels(cls, gold_parses):
|
||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'ROOT': True},
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
|
||||
for raw_text, sents in gold_parses:
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label.upper() == 'ROOT':
|
||||
label = 'ROOT'
|
||||
if label != 'ROOT':
|
||||
if head < child:
|
||||
move_labels[RIGHT][label] = True
|
||||
elif head > child:
|
||||
move_labels[LEFT][label] = True
|
||||
return move_labels
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
for i in range(gold.length):
|
||||
if gold.heads[i] is None: # Missing values
|
||||
gold.c.heads[i] = i
|
||||
gold.c.labels[i] = -1
|
||||
else:
|
||||
label = gold.labels[i]
|
||||
if label.upper() == 'ROOT':
|
||||
label = 'ROOT'
|
||||
gold.c.heads[i] = gold.heads[i]
|
||||
gold.c.labels[i] = self.strings[label]
|
||||
for end, brackets in gold.brackets.items():
|
||||
for start, label_strs in brackets.items():
|
||||
gold.c.brackets[start][end] = 1
|
||||
for label_str in label_strs:
|
||||
# Add the encoded label to the set
|
||||
gold.brackets[end][start].add(self.strings[label_str])
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
if '-' in name:
|
||||
move_str, label_str = name.split('-', 1)
|
||||
label = self.label_ids[label_str]
|
||||
else:
|
||||
label = 0
|
||||
move = MOVE_NAMES.index(move_str)
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].move == move and self.c[i].label == label:
|
||||
return self.c[i]
|
||||
|
||||
def move_name(self, int move, int label):
|
||||
label_str = self.strings[label]
|
||||
if label_str:
|
||||
return MOVE_NAMES[move] + '-' + label_str
|
||||
else:
|
||||
return MOVE_NAMES[move]
|
||||
|
||||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||
# TODO: Apparent Cython bug here when we try to use the Transition()
|
||||
# constructor with the function pointers
|
||||
cdef Transition t
|
||||
t.score = 0
|
||||
t.clas = clas
|
||||
t.move = move
|
||||
t.label = label
|
||||
if move == SHIFT:
|
||||
t.is_valid = Shift.is_valid
|
||||
t.do = Shift.transition
|
||||
t.get_cost = Shift.cost
|
||||
elif move == REDUCE:
|
||||
t.is_valid = Reduce.is_valid
|
||||
t.do = Reduce.transition
|
||||
t.get_cost = Reduce.cost
|
||||
elif move == LEFT:
|
||||
t.is_valid = LeftArc.is_valid
|
||||
t.do = LeftArc.transition
|
||||
t.get_cost = LeftArc.cost
|
||||
elif move == RIGHT:
|
||||
t.is_valid = RightArc.is_valid
|
||||
t.do = RightArc.transition
|
||||
t.get_cost = RightArc.cost
|
||||
elif move == BREAK:
|
||||
t.is_valid = Break.is_valid
|
||||
t.do = Break.transition
|
||||
t.get_cost = Break.cost
|
||||
else:
|
||||
raise Exception(move)
|
||||
return t
|
||||
|
||||
cdef int initialize_state(self, StateClass st) except -1:
|
||||
# Ensure sent_start is set to 0 throughout
|
||||
for i in range(st.length):
|
||||
st._sent[i].sent_start = False
|
||||
st._sent[i].l_edge = i
|
||||
st._sent[i].r_edge = i
|
||||
st.fast_forward()
|
||||
|
||||
cdef int finalize_state(self, StateClass st) except -1:
|
||||
cdef int root_label = self.strings['ROOT']
|
||||
for i in range(st.length):
|
||||
if st._sent[i].head == 0 and st._sent[i].dep == 0:
|
||||
st._sent[i].dep = root_label
|
||||
# If we're not using the Break transition, we segment via root-labelled
|
||||
# arcs between the root words.
|
||||
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == root_label:
|
||||
st._sent[i].head = 0
|
||||
|
||||
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls, -1)
|
||||
cdef int i
|
||||
n_valid = 0
|
||||
for i in range(self.n_moves):
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
n_valid += output[i]
|
||||
assert n_valid >= 1
|
||||
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i, move, label
|
||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||
cdef int[N_MOVES] move_costs
|
||||
for i in range(N_MOVES):
|
||||
move_costs[i] = -1
|
||||
move_cost_funcs[SHIFT] = Shift.move_cost
|
||||
move_cost_funcs[REDUCE] = Reduce.move_cost
|
||||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||
move_cost_funcs[RIGHT] = RightArc.move_cost
|
||||
move_cost_funcs[BREAK] = Break.move_cost
|
||||
|
||||
label_cost_funcs[SHIFT] = Shift.label_cost
|
||||
label_cost_funcs[REDUCE] = Reduce.label_cost
|
||||
label_cost_funcs[LEFT] = LeftArc.label_cost
|
||||
label_cost_funcs[RIGHT] = RightArc.label_cost
|
||||
label_cost_funcs[BREAK] = Break.label_cost
|
||||
|
||||
cdef int* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
|
||||
n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
n_gold += output[i] == 0
|
||||
else:
|
||||
output[i] = 9000
|
||||
assert n_gold >= 1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls, -1)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if scores[i] > score and is_valid[self.c[i].move]:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
assert best.clas < self.n_moves
|
||||
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
|
||||
return best
|
|
@ -128,27 +128,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
raise Exception(move)
|
||||
return t
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
cdef int best = -1
|
||||
cdef weight_t score = -90000
|
||||
cdef const Transition* m
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
m = &self.c[i]
|
||||
if m.is_valid(stcls, m.label) and scores[i] > score:
|
||||
best = i
|
||||
score = scores[i]
|
||||
assert best >= 0
|
||||
cdef Transition t = self.c[best]
|
||||
t.score = score
|
||||
return t
|
||||
|
||||
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
m = &self.c[i]
|
||||
output[i] = m.is_valid(stcls, m.label)
|
||||
|
||||
|
||||
cdef class Missing:
|
||||
@staticmethod
|
||||
|
|
|
@ -12,6 +12,3 @@ cdef class Parser:
|
|||
cdef readonly object cfg
|
||||
cdef readonly Model model
|
||||
cdef readonly TransitionSystem moves
|
||||
|
||||
cdef int _greedy_parse(self, Doc tokens) except -1
|
||||
cdef int _beam_parse(self, Doc tokens) except -1
|
||||
|
|
|
@ -20,17 +20,10 @@ from cymem.cymem cimport Pool, Address
|
|||
from murmurhash.mrmr cimport hash64
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
||||
|
||||
|
||||
from util import Config
|
||||
|
||||
from thinc.features cimport Extractor
|
||||
from thinc.features cimport Feature
|
||||
from thinc.features cimport count_feats
|
||||
from thinc.api cimport Example
|
||||
|
||||
from thinc.learner cimport LinearModel
|
||||
|
||||
from thinc.search cimport Beam
|
||||
from thinc.search cimport MaxViolation
|
||||
|
||||
from ..structs cimport TokenC
|
||||
|
||||
|
@ -61,6 +54,8 @@ def get_templates(name):
|
|||
return pf.ner
|
||||
elif name == 'debug':
|
||||
return pf.unigrams
|
||||
elif name.startswith('embed'):
|
||||
return (pf.words, pf.tags, pf.labels)
|
||||
else:
|
||||
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
|
||||
pf.tree_shape + pf.trigrams)
|
||||
|
@ -83,179 +78,43 @@ cdef class Parser:
|
|||
self.model = Model(self.moves.n_moves, templates, model_dir)
|
||||
|
||||
def __call__(self, Doc tokens):
|
||||
if self.model is not None:
|
||||
if self.cfg.get('beam_width', 0) < 1:
|
||||
self._greedy_parse(tokens)
|
||||
else:
|
||||
self._beam_parse(tokens)
|
||||
|
||||
def train(self, Doc tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
if self.cfg.get('beam_width', 0) < 1:
|
||||
return self._greedy_train(tokens, gold)
|
||||
else:
|
||||
return self._beam_train(tokens, gold)
|
||||
|
||||
cdef int _greedy_parse(self, Doc tokens) except -1:
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
cdef Transition guess
|
||||
words = [w.orth_ for w in tokens]
|
||||
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||
self.model.n_feats, self.model.n_feats)
|
||||
while not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
#print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
|
||||
guess.do(stcls, guess.label)
|
||||
assert stcls._s_i >= 0
|
||||
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
|
||||
|
||||
self.moves.set_valid(<bint*>eg.c.is_valid, stcls)
|
||||
fill_context(eg.c.atoms, stcls)
|
||||
|
||||
self.model.predict(eg)
|
||||
|
||||
self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
|
||||
self.moves.finalize_state(stcls)
|
||||
tokens.set_parse(stcls._sent)
|
||||
|
||||
cdef int _beam_parse(self, Doc tokens) except -1:
|
||||
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
words = [w.orth_ for w in tokens]
|
||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
self._advance_beam(beam, None, False, words)
|
||||
state = <StateClass>beam.at(0)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state._sent)
|
||||
_cleanup(beam)
|
||||
|
||||
def _greedy_train(self, Doc tokens, GoldParse gold):
|
||||
cdef Pool mem = Pool()
|
||||
def train(self, Doc tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef int cost
|
||||
cdef const Feature* feats
|
||||
cdef const weight_t* scores
|
||||
cdef Transition guess
|
||||
cdef Transition best
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
loss = 0
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||
self.model.n_feats, self.model.n_feats)
|
||||
cdef weight_t loss = 0
|
||||
words = [w.orth_ for w in tokens]
|
||||
history = []
|
||||
cdef Transition G
|
||||
while not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
best = self.moves.best_gold(scores, stcls, gold)
|
||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
guess.do(stcls, guess.label)
|
||||
loss += cost
|
||||
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
|
||||
|
||||
self.moves.set_costs(<bint*>eg.c.is_valid, eg.c.costs, stcls, gold)
|
||||
|
||||
fill_context(eg.c.atoms, stcls)
|
||||
|
||||
self.model.train(eg)
|
||||
|
||||
G = self.moves.c[eg.c.guess]
|
||||
|
||||
self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
|
||||
loss += eg.c.loss
|
||||
return loss
|
||||
|
||||
def _beam_train(self, Doc tokens, GoldParse gold_parse):
|
||||
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
pred.initialize(_init_state, tokens.length, tokens.data)
|
||||
pred.check_done(_check_final_state, NULL)
|
||||
cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
gold.initialize(_init_state, tokens.length, tokens.data)
|
||||
gold.check_done(_check_final_state, NULL)
|
||||
|
||||
violn = MaxViolation()
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not pred.is_done and not gold.is_done:
|
||||
self._advance_beam(pred, gold_parse, False, words)
|
||||
self._advance_beam(gold, gold_parse, True, words)
|
||||
violn.check(pred, gold)
|
||||
if pred.loss >= 1:
|
||||
counts = {clas: {} for clas in range(self.model.n_classes)}
|
||||
self._count_feats(counts, tokens, violn.g_hist, 1)
|
||||
self._count_feats(counts, tokens, violn.p_hist, -1)
|
||||
else:
|
||||
counts = {}
|
||||
self.model._model.update(counts)
|
||||
_cleanup(pred)
|
||||
_cleanup(gold)
|
||||
return pred.loss
|
||||
|
||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int i, j, cost
|
||||
cdef bint is_valid
|
||||
cdef const Transition* move
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls)
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
self.moves.set_costs(beam.costs[i], stcls, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
def _count_feats(self, dict counts, Doc tokens, list hist, int inc):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef Pool mem = Pool()
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef class_t clas
|
||||
cdef int n_feats
|
||||
for clas in hist:
|
||||
fill_context(context, stcls)
|
||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||
count_feats(counts[clas], feats, n_feats, inc)
|
||||
self.moves.c[clas].do(stcls, self.moves.c[clas].label)
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <StateClass>_dest
|
||||
src = <StateClass>_src
|
||||
moves = <const Transition*>_moves
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest, moves[clas].label)
|
||||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||
st.fast_forward()
|
||||
Py_INCREF(st)
|
||||
return <void*>st
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
return (<StateClass>_state).is_final()
|
||||
|
||||
|
||||
def _cleanup(Beam beam):
|
||||
for i in range(beam.width):
|
||||
Py_XDECREF(<PyObject*>beam._states[i].content)
|
||||
Py_XDECREF(<PyObject*>beam._parents[i].content)
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
return <hash_t>_state
|
||||
|
||||
#state = <const State*>_state
|
||||
#cdef atom_t[10] rep
|
||||
|
||||
#rep[0] = state.stack[0] if state.stack_len >= 1 else 0
|
||||
#rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
|
||||
#rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
|
||||
#rep[3] = state.i
|
||||
#rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
|
||||
#rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
|
||||
#rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
|
||||
#rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
|
||||
#if get_left(state, get_n0(state), 1) != NULL:
|
||||
# rep[8] = get_left(state, get_n0(state), 1).dep
|
||||
#else:
|
||||
# rep[8] = 0
|
||||
#rep[9] = state.sent[state.i].l_kids
|
||||
#return hash64(rep, sizeof(atom_t) * 10, 0)
|
||||
|
|
|
@ -52,7 +52,11 @@ cdef class StateClass:
|
|||
cdef const TokenC* target = &self._sent[i]
|
||||
if target.l_kids < idx:
|
||||
return -1
|
||||
<<<<<<< HEAD
|
||||
cdef const TokenC* ptr = &self._sent[target.l_edge]
|
||||
=======
|
||||
cdef const TokenC* ptr = self._sent
|
||||
>>>>>>> neuralnet
|
||||
|
||||
while ptr < target:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
|
@ -78,7 +82,11 @@ cdef class StateClass:
|
|||
cdef const TokenC* target = &self._sent[i]
|
||||
if target.r_kids < idx:
|
||||
return -1
|
||||
<<<<<<< HEAD
|
||||
cdef const TokenC* ptr = &self._sent[target.r_edge]
|
||||
=======
|
||||
cdef const TokenC* ptr = self._sent + (self.length - 1)
|
||||
>>>>>>> neuralnet
|
||||
while ptr > target:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
|
|
|
@ -46,9 +46,5 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef int set_valid(self, bint* output, StateClass state) except -1
|
||||
|
||||
cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, StateClass state,
|
||||
GoldParse gold) except *
|
||||
cdef int set_costs(self, bint* is_valid, int* costs,
|
||||
StateClass state, GoldParse gold) except -1
|
||||
|
|
|
@ -43,30 +43,17 @@ cdef class TransitionSystem:
|
|||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef int set_valid(self, bint* output, StateClass state) except -1:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int set_valid(self, bint* is_valid, StateClass stcls) except -1:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
|
||||
|
||||
cdef int set_costs(self, bint* is_valid, int* costs,
|
||||
StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i
|
||||
self.set_valid(is_valid, stcls)
|
||||
for i in range(self.n_moves):
|
||||
if is_valid[i]:
|
||||
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
else:
|
||||
output[i] = 9000
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, StateClass stcls,
|
||||
GoldParse gold) except *:
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
if scores[i] > score and cost == 0:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
assert score > MIN_SCORE
|
||||
return best
|
||||
costs[i] = 9000
|
||||
|
|
Loading…
Reference in New Issue
Block a user