# cython: infer_types=True
# cython: profile=True
# coding: utf-8
"""
Draft of the parser neural network model.

The model is good, but the code is messy. It currently requires Chainer,
which may cause the build to fail on machines without a GPU.

Outline of the model:

We first predict context-sensitive vectors for each word in the input:

    (embed_lower | embed_prefix | embed_suffix | embed_shape)
    >> Maxout(token_width)
    >> convolution ** 4

This convolutional layer is shared between the tagger and the parser, so the
parser doesn't need tag features.

To boost the representation, we make a "super tag" from the POS tag, the
morphology and the dependency label. The tagger predicts this by adding a
softmax layer onto the convolutional layer -- so we're teaching the
convolutional layer to give us a representation that's one affine transform
away from this informative lexical information. This is obviously good for the
parser (which backprops to the convolutions too).

The parser model makes a state vector by concatenating the vector
representations of its context tokens. Current results suggest that few
context tokens work well. Maybe this is a bug.

The current context tokens:

* S0, S1, S2: Top three words on the stack
* B0, B1: First two words of the buffer
* S0L1, S0L2: Leftmost and second leftmost children of S0
* S0R1, S0R2: Rightmost and second rightmost children of S0
* S1L1, S1L2, S1R1, S1R2, B0L1, B0L2: Likewise for S1 and B0

This makes the state vector quite long: 13*T, where T is the token vector
width (128 is working well). Fortunately, there's a way to structure the
computation to save some expense (and make it more GPU-friendly).

The parser typically visits 2*N states for a sentence of length N (although it
may visit more if it back-tracks with a non-monotonic transition). A naive
implementation would require 2*N (B, 13*T) @ (13*T, H) matrix multiplications
for a batch of size B. We can instead perform one (B*N, T) @ (T, 13*H)
multiplication, to pre-compute the hidden weights for each positional feature
with respect to the words in the batch. (Note that our token vectors come from
the CNN -- so we can't play this trick over the vocabulary. That's how
Stanford's NN parser works, and why its model is so big.)

This pre-computation strategy allows a nice compromise between GPU-friendliness
and implementation simplicity. The CNN and the wide lower layer are computed on
the GPU, and then the pre-computed hidden weights are moved to the CPU before
we start the transition-based parsing process. This makes a lot of things much
easier. We don't have to worry about variable-length batch sizes, and we don't
have to implement the dynamic oracle in CUDA to train.

Currently the parser's loss function is multilabel log loss, as the dynamic
oracle allows multiple actions to be 0 cost. The gradient of this loss with
respect to the score of a gold (zero-cost) class is:

    (exp(score) / Z) - (exp(score) / gZ)

where Z is the sum of exp(score) over the valid classes and gZ is the sum of
exp(score) over the gold classes. I'm very interested in regressing on the
cost directly, but so far this isn't working well.

Machinery is in place for beam-search, which has been working well for the
linear model. Beam search should benefit greatly from the pre-computation
trick.
"""
from __future__ import unicode_literals, print_function
from collections import Counter
import ujson

from cupy.cuda.stream import Stream
import cupy

from libc.math cimport exp
cimport cython
cimport cython.parallel
import cytoolz
import numpy.random
cimport numpy as np

from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from cpython.exc cimport PyErr_CheckSignals
from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
from libc.stdlib cimport malloc, calloc, free
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
from thinc.extra.eg cimport Example
from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from preshed.maps cimport MapStruct
from preshed.maps cimport map_get

from thinc.api import layerize, chain
from thinc.neural import Affine, Model, Maxout
from thinc.neural.ops import NumpyOps

from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
from .stateclass cimport StateClass
from ._state cimport StateC
from .nonproj import PseudoProjectivity
from .transition_system import OracleError
from .transition_system cimport TransitionSystem, Transition
from ..structs cimport TokenC
from ..tokens.doc cimport Doc
from ..strings cimport StringStore
from ..gold cimport GoldParse
from ..attrs cimport TAG, DEP

def get_templates(*args, **kwargs):
    return []


USE_FTRL = True
DEBUG = False
def set_debug(val):
    global DEBUG
    DEBUG = val
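

# Illustrative sketch only (an assumption about what thinc's ops.maxout /
# ops.backprop_maxout provide, not part of the parser): a maxout layer with
# nP pieces computes nP candidate activations per output unit, keeps the
# element-wise best, and remembers which piece won so that the gradient can
# be routed back through it.
def _maxout_demo(features):
    # features: (batch, n_out, nP) array of piece activations
    which = features.argmax(axis=-1)    # index of the winning piece per unit
    best = features.max(axis=-1)        # value of the winning piece per unit
    return best, which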


def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None):
    '''Allow a model to be "primed" by pre-computing input features in bulk.

    This is used for the parser, where we want to take a batch of documents,
    and compute vectors for each (token, position) pair. These vectors can then
    be reused, especially for beam-search.

    Let's say we're using 12 features for each state, e.g. word at start of
    buffer, three words on stack, their children, etc. In the normal arc-eager
    system, a document of length N is processed in 2*N states. This means we'll
    create 2*N*12 feature vectors --- but if we pre-compute, we only need
    N*12 vector computations. The saving for beam-search is much better:
    if we have a beam of k, we'll normally make 2*N*12*k computations --
    so we can save the factor k. This also gives a nice CPU/GPU division:
    we can do all our hard maths up front, packed into large multiplications,
    and do the hard-to-program parsing on the CPU.
    '''
    gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=0.)
    cdef np.ndarray cached
    if not isinstance(gpu_cached, numpy.ndarray):
        cached = gpu_cached.get(stream=cuda_stream)
    else:
        cached = gpu_cached
    nF = gpu_cached.shape[1]
    nP = gpu_cached.shape[3]
    ops = lower_model.ops
    features = numpy.zeros((batch_size, cached.shape[2], nP), dtype='f')
    synchronized = False

    def forward(token_ids, drop=0.):
        nonlocal synchronized
        if not synchronized and cuda_stream is not None:
            cuda_stream.synchronize()
            synchronized = True
        # This is tricky, but:
        # - Input to forward on CPU
        # - Output from forward on CPU
        # - Input to backward on GPU!
        # - Output from backward on GPU
        nonlocal features
        features = features[:len(token_ids)]
        features.fill(0)
        cdef float[:, :, ::1] feats = features
        cdef int[:, ::1] ids = token_ids
        _sum_features(<float*>&feats[0,0,0],
            <float*>cached.data, &ids[0,0],
            token_ids.shape[0], nF, cached.shape[2]*nP)

        if nP >= 2:
            best, which = ops.maxout(features)
        else:
            best = features.reshape((features.shape[0], features.shape[1]))
            which = None

        def backward(d_best, sgd=None):
            # This will usually be on GPU
            if isinstance(d_best, numpy.ndarray):
                d_best = ops.xp.array(d_best)
            if nP >= 2:
                d_features = ops.backprop_maxout(d_best, which, nP)
            else:
                d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1))
            d_tokens = bp_features((d_features, token_ids), sgd)
            return d_tokens

        return best, backward

    return forward
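

# Illustrative sketch only (plain numpy, single-piece case, a hypothetical
# helper that is not used by the parser): the pre-computation trick described
# in the module docstring. We multiply every token vector by the full (T, F*H)
# weight matrix once; building a state's hidden vector is then just F row
# lookups and a sum, with no per-state matrix multiplication.
def _precompute_hiddens_demo(token_vectors, W, nF):
    # token_vectors: (N, T); W: (T, nF*H)
    N = token_vectors.shape[0]
    H = W.shape[1] // nF
    cached = numpy.dot(token_vectors, W).reshape((N, nF, H))

    def state_vector(token_ids):
        # token_ids: the nF context token indices for one state (-1 = missing)
        summed = numpy.zeros((H,), dtype=cached.dtype)
        for f, idx in enumerate(token_ids):
            if idx >= 0:
                summed += cached[idx, f]
        return summed
    return state_vector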


cdef void _sum_features(float* output,
        const float* cached, const int* token_ids, int B, int F, int O) nogil:
    cdef int idx, b, f, i
    cdef const float* feature
    for b in range(B):
        for f in range(F):
            if token_ids[f] < 0:
                continue
            idx = token_ids[f] * F * O + f*O
            feature = &cached[idx]
            for i in range(O):
                output[i] += feature[i]
        output += O
        token_ids += F
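
# Reference note (illustrative): viewing `cached` as an (n_tokens, F, O) array
# and taking ids_b as the F token ids for one state, the loop above computes
#     output[b, :] = sum over f (with ids_b[f] >= 0) of cached[ids_b[f], f, :]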


def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores):
    cdef StateClass state
    cdef GoldParse gold
    cdef Pool mem = Pool()
    cdef int i
    is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
    costs = <float*>mem.alloc(moves.n_moves, sizeof(float))
    cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f')
    c_d_scores = <float*>d_scores.data
    for i, (state, gold) in enumerate(zip(states, golds)):
        memset(is_valid, 0, moves.n_moves * sizeof(int))
        memset(costs, 0, moves.n_moves * sizeof(float))
        moves.set_costs(is_valid, costs, state, gold)
        cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1])
        c_d_scores += d_scores.shape[1]
    return d_scores


cdef void cpu_log_loss(float* d_scores,
        const float* costs, const int* is_valid, const float* scores,
        int O) nogil:
    """Do multi-label log loss"""
    cdef double max_, gmax, Z, gZ
    best = arg_max_if_gold(scores, costs, is_valid, O)
    guess = arg_max_if_valid(scores, is_valid, O)
    Z = 1e-10
    gZ = 1e-10
    max_ = scores[guess]
    gmax = scores[best]
    for i in range(O):
        if is_valid[i]:
            Z += exp(scores[i] - max_)
            if costs[i] <= costs[best]:
                gZ += exp(scores[i] - gmax)
    for i in range(O):
        if not is_valid[i]:
            d_scores[i] = 0.
        elif costs[i] <= costs[best]:
            d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
        else:
            d_scores[i] = exp(scores[i]-max_) / Z
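

# Illustrative numpy version of the gradient computed by cpu_log_loss above
# (a reference sketch only; it ignores the epsilon terms and the separate
# max-shifts the C code uses for numerical stability). `is_valid` marks the
# legal classes and `is_gold` the zero-cost classes among them, both as 0/1
# float arrays.
def _log_loss_gradient_demo(scores, is_valid, is_gold):
    exp_scores = numpy.exp(scores - scores.max())
    Z = (exp_scores * is_valid).sum()    # partition over valid classes
    gZ = (exp_scores * is_gold).sum()    # partition over gold classes
    return (exp_scores / Z) * is_valid - (exp_scores / gZ) * is_gold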


cdef void cpu_regression_loss(float* d_scores,
        const float* costs, const int* is_valid, const float* scores,
        int O) nogil:
    cdef float eps = 2.
    best = arg_max_if_gold(scores, costs, is_valid, O)
    for i in range(O):
        if not is_valid[i]:
            d_scores[i] = 0.
        elif scores[i] < scores[best]:
            d_scores[i] = 0.
        else:
            # I doubt this is correct?
            # Looking for something like Huber loss
            diff = scores[i] - -costs[i]
            if diff > eps:
                d_scores[i] = eps
            elif diff < -eps:
                d_scores[i] = -eps
            else:
                d_scores[i] = diff


def init_states(TransitionSystem moves, docs):
    cdef Doc doc
    cdef StateClass state
    offsets = []
    states = []
    offset = 0
    for i, doc in enumerate(docs):
        state = StateClass.init(doc.c, doc.length)
        moves.initialize_state(state.c)
        states.append(state)
        offsets.append(offset)
        offset += len(doc)
    return states, offsets


def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0):
    cdef StateClass state
    cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
    ids = numpy.zeros((len(states), n_tokens), dtype='i')
    if offsets is None:
        offsets = [0] * len(states)
    for i, (state, offset) in enumerate(zip(states, offsets)):
        state.set_context_tokens(ids[i], nF, nB, nS, nL, nR)
        ids[i] += (ids[i] >= 0) * offset
    return ids
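
# Illustrative note: the offset trick above shifts each document's token ids
# into the row space of the batch's concatenated token vectors, while leaving
# "missing token" ids (-1) untouched. For a state from a document starting at
# offset 5, for example:
#     ids_i = numpy.array([0, 2, -1], dtype='i')
#     ids_i += (ids_i >= 0) * 5    # -> array([5, 7, -1])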


cdef class Parser:
    """
    Base class of the DependencyParser and EntityRecognizer.
    """
    @classmethod
    def load(cls, path, Vocab vocab, TransitionSystem=None, require=False, **cfg):
        """
        Load the statistical model from the supplied path.

        Arguments:
            path (Path):
                The path to load from.
            vocab (Vocab):
                The vocabulary. Must be shared by the documents to be processed.
            require (bool):
                Whether to raise an error if the files are not found.
        Returns (Parser):
            The newly constructed object.
        """
        with (path / 'config.json').open() as file_:
            cfg = ujson.load(file_)
        self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
        if (path / 'model').exists():
            self.model.load(str(path / 'model'))
        elif require:
            raise IOError(
                "Required file %s/model not found when loading" % str(path))
        return self
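
    # Usage sketch (illustrative; the concrete subclass, path and vocab come
    # from the caller, e.g. the DependencyParser subclass mentioned above):
    #     parser = DependencyParser.load(path, vocab, require=False)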
    def __init__(self, Vocab vocab, TransitionSystem=None, model=None, **cfg):
        """
        Create a Parser.

        Arguments:
            vocab (Vocab):
                The vocabulary object. Must be shared with documents to be processed.
            model (thinc Model):
                The statistical model.
        Returns (Parser):
            The newly constructed object.
        """
        if TransitionSystem is None:
            TransitionSystem = self.TransitionSystem
        self.vocab = vocab
        cfg['actions'] = TransitionSystem.get_actions(**cfg)
        self.moves = TransitionSystem(vocab.strings, cfg['actions'])
        if model is None:
            self.model, self.feature_maps = self.build_model(**cfg)
        else:
            self.model, self.feature_maps = model
        self.cfg = cfg
    def __reduce__(self):
        return (Parser, (self.vocab, self.moves, self.model), None, None)
    def build_model(self,
                    hidden_width=128, token_vector_width=96, nr_vector=1000,
                    nF=1, nB=1, nS=1, nL=1, nR=1, **cfg):
        nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
        with Model.use_device('cpu'):
            upper = chain(
                Maxout(token_vector_width),
                zero_init(Affine(self.moves.n_moves, token_vector_width)))
        assert isinstance(upper.ops, NumpyOps)
        lower = PrecomputableMaxouts(token_vector_width, nF=nr_context_tokens,
                                     nI=token_vector_width,
                                     pieces=cfg.get('maxout_pieces', 1))
        upper.begin_training(upper.ops.allocate((500, token_vector_width)))
        lower.begin_training(lower.ops.allocate((500, token_vector_width)))
        return upper, lower
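
    # Note on the division of labour (see the module docstring): `lower` is the
    # wide pre-computable layer that maps token vectors to per-feature hidden
    # weights (and can run on the GPU), while `upper` is the small per-state
    # Maxout + Affine network that stays on the CPU -- hence the NumpyOps
    # assertion above.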
    def __call__(self, Doc tokens):
        """
        Apply the parser or entity recognizer, setting the annotations onto
        the Doc object.

        Arguments:
            doc (Doc): The document to be processed.
        Returns:
            None
        """
        self.parse_batch([tokens])
    def pipe(self, stream, int batch_size=1000, int n_threads=2):
        """
        Process a stream of documents.

        Arguments:
            stream: The sequence of documents to process.
            batch_size (int):
                The number of documents to accumulate into a working set.
            n_threads (int):
                The number of threads with which to work on the buffer in parallel.
        Yields (Doc): Documents, in order.
        """
        queue = []
        for doc in stream:
            queue.append(doc)
            if len(queue) == batch_size:
                self.parse_batch(queue)
                for doc in queue:
                    self.moves.finalize_doc(doc)
                    yield doc
                queue = []
        if queue:
            self.parse_batch(queue)
            for doc in queue:
                self.moves.finalize_doc(doc)
                yield doc
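
    # Usage sketch (illustrative; `parser` and `docs` come from the caller):
    #     for doc in parser.pipe(docs, batch_size=1000, n_threads=2):
    #         ...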
    def parse_batch(self, docs_tokvecs):
        cdef:
            int nC
            Doc doc
            StateClass state
            np.ndarray py_scores
            int[500] is_valid  # Hacks for now

        cuda_stream = Stream()
        docs, tokvecs = docs_tokvecs
        lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps,
                                                 cuda_stream)
        upper_model = self.model

        states, offsets = init_states(self.moves, docs)
        all_states = list(states)
        todo = [st for st in zip(states, offsets) if not st[0].py_is_final()]

        while todo:
            states, offsets = zip(*todo)
            token_ids = extract_token_ids(states, offsets=offsets)

            py_scores = upper_model(lower_model(token_ids)[0])
            scores = <float*>py_scores.data
            nC = py_scores.shape[1]
            for state, offset in zip(states, offsets):
                self.moves.set_valid(is_valid, state.c)
                guess = arg_max_if_valid(scores, is_valid, nC)
                action = self.moves.c[guess]
                action.do(state.c, action.label)
                scores += nC
            todo = [st for st in todo if not st[0].py_is_final()]

        for state, doc in zip(all_states, docs):
            self.moves.finalize_state(state.c)
            for i in range(doc.length):
                doc.c[i] = state.c._sent[i]
            self.moves.finalize_doc(doc)

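    # Notes on parse_batch above:
    # * `is_valid` is a fixed 500-slot stack buffer, so this assumes the
    #   transition system defines fewer than 500 classes.
    # * `scores` starts at row 0 of the (n_states, nC) score array and is
    #   advanced by `scores += nC` after each state, i.e. the pointer walks one
    #   row per state, in the same order as `states`.
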
    def update(self, docs_tokvecs, golds, drop=0., sgd=None):
        cdef:
            int nC
            int[500] is_valid  # Hack for now
            Doc doc
            StateClass state
            np.ndarray scores

        docs, tokvecs = docs_tokvecs
        cuda_stream = Stream()
        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
            return self.update(([docs], tokvecs), [golds], drop=drop, sgd=sgd)
        for gold in golds:
            self.moves.preprocess_gold(gold)

        states, offsets = init_states(self.moves, docs)

        todo = [st for st in zip(states, offsets, golds) if not st[0].py_is_final()]

        lower_model = get_greedy_model_for_batch(len(todo),
                            tokvecs, self.feature_maps, cuda_stream=cuda_stream)
        upper_model = self.model
        d_tokens = self.feature_maps.ops.allocate(tokvecs.shape)
        backprops = []
        n_tokens = tokvecs.shape[0]
        nF = self.feature_maps.nF
        while todo:
            states, offsets, golds = zip(*todo)

            token_ids = extract_token_ids(states, offsets=offsets)
            lower, bp_lower = lower_model(token_ids)
            scores, bp_scores = upper_model.begin_update(lower)

            d_scores = get_batch_loss(self.moves, states, golds, scores)
            d_lower = bp_scores(d_scores, sgd=sgd)

            gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i')
            gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f')
            gpu_tok_ids.set(token_ids, stream=cuda_stream)
            gpu_d_lower.set(d_lower, stream=cuda_stream)
            backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower))

            c_scores = <float*>scores.data
            for state in states:
                self.moves.set_valid(is_valid, state.c)
                guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
                action = self.moves.c[guess]
                action.do(state.c, action.label)
                c_scores += scores.shape[1]

            todo = [st for st in todo if not st[0].py_is_final()]
        # This tells CUDA to block --- so we know our copies are complete.
        cuda_stream.synchronize()
        for token_ids, d_lower, bp_lower in backprops:
            d_state_features = bp_lower(d_lower, sgd=sgd)
            # Zero the gradient for padding features (token id == -1) and clamp
            # their index to 0 so the scatter-add stays in bounds.
            active_feats = (token_ids >= 0).reshape(
                (token_ids.shape[0], token_ids.shape[1], 1))
            safe_ids = token_ids * (token_ids >= 0)
            if hasattr(self.feature_maps.ops.xp, 'scatter_add'):
                self.feature_maps.ops.xp.scatter_add(d_tokens,
                    safe_ids, d_state_features * active_feats)
            else:
                self.model.ops.xp.add.at(d_tokens,
                    safe_ids, d_state_features * active_feats)
        return d_tokens

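    # A minimal NumPy sketch of the gradient scatter at the end of update()
    # (shapes and values invented for illustration):
    #
    #     import numpy
    #     d_tokens = numpy.zeros((5, 2), dtype='f')      # (n_tokens, width)
    #     token_ids = numpy.array([[0, 3], [4, -1]])     # (n_states, nF); -1 = missing
    #     d_feats = numpy.ones((2, 2, 2), dtype='f')     # gradient per state feature
    #     mask = (token_ids >= 0).reshape((2, 2, 1))
    #     numpy.add.at(d_tokens, token_ids * (token_ids >= 0), d_feats * mask)
    #
    # Each state's per-feature gradient is summed into the token row it was
    # gathered from; masked (missing) features contribute only zeros.
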
    def step_through(self, Doc doc, GoldParse gold=None):
        """
        Set up a stepwise state, to introspect and control the transition sequence.

        Arguments:
            doc (Doc): The document to step through.
            gold (GoldParse): Optional gold parse
        Returns (StepwiseState):
            A state object, to step through the annotation process.
        """
        return StepwiseState(self, doc, gold=gold)

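    # Usage sketch for step_through (the `parser` and `doc` names are
    # illustrative; the StepwiseState API is defined below):
    #
    #     with parser.step_through(doc) as stepwise:
    #         while not stepwise.is_final:
    #             print(stepwise.stack, stepwise.queue)
    #             stepwise.transition()      # applies the model's prediction
    #     print(stepwise.heads, stepwise.deps)
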
    def from_transition_sequence(self, Doc doc, sequence):
        """Control the annotations on a document by specifying a transition sequence
        to follow.

        Arguments:
            doc (Doc): The document to annotate.
            sequence: A sequence of action names, as unicode strings.
        Returns: None
        """
        with self.step_through(doc) as stepwise:
            for transition in sequence:
                stepwise.transition(transition)

    def add_label(self, label):
        # Doesn't set label into serializer -- subclasses override it to do that.
        for action in self.moves.action_types:
            added = self.moves.add_action(action, label)
            if added:
                # Important that the labels be stored as a list! We need the
                # order, or the model goes out of synch
                self.cfg.setdefault('extra_labels', []).append(label)


cdef class StepwiseState:
    cdef readonly StateClass stcls
    cdef readonly Example eg
    cdef readonly Doc doc
    cdef readonly GoldParse gold
    cdef readonly Parser parser

    def __init__(self, Parser parser, Doc doc, GoldParse gold=None):
        self.parser = parser
        self.doc = doc
        if gold is not None:
            self.gold = gold
            self.parser.moves.preprocess_gold(self.gold)
        else:
            self.gold = GoldParse(doc)
        self.stcls = StateClass.init(doc.c, doc.length)
        self.parser.moves.initialize_state(self.stcls.c)
        self.eg = Example(
            nr_class=self.parser.moves.n_moves,
            nr_atom=CONTEXT_SIZE,
            nr_feat=self.parser.model.nr_feat)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.finish()

    @property
    def is_final(self):
        return self.stcls.is_final()

    @property
    def stack(self):
        return self.stcls.stack

    @property
    def queue(self):
        return self.stcls.queue

    @property
    def heads(self):
        return [self.stcls.H(i) for i in range(self.stcls.c.length)]

    @property
    def deps(self):
        return [self.doc.vocab.strings[self.stcls.c._sent[i].dep]
                for i in range(self.stcls.c.length)]

    @property
    def costs(self):
        """
        Find the action-costs for the current state.
        """
        if not self.gold:
            raise ValueError("Can't set costs: No GoldParse provided")
        self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs,
                                    self.stcls, self.gold)
        costs = {}
        for i in range(self.parser.moves.n_moves):
            if not self.eg.c.is_valid[i]:
                continue
            transition = self.parser.moves.c[i]
            name = self.parser.moves.move_name(transition.move, transition.label)
            costs[name] = self.eg.c.costs[i]
        return costs

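    # The mapping returned by `costs` is keyed by move name, e.g. (label names
    # invented for illustration):
    #
    #     {'S': 0.0, 'D': 1.0, 'L-nsubj': 1.0, 'R-dobj': 0.0}
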
    def predict(self):
        self.eg.reset()
        #self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
        #                                                    self.stcls.c)
        self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
        #self.parser.model.set_scoresC(self.eg.c.scores,
        #    self.eg.c.features, self.eg.c.nr_feat)

        cdef Transition action = self.parser.moves.c[self.eg.guess]
        return self.parser.moves.move_name(action.move, action.label)

    def transition(self, action_name=None):
        if action_name is None:
            action_name = self.predict()
        moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3}
        if action_name == '_':
            action_name = self.predict()
            action = self.parser.moves.lookup_transition(action_name)
        elif action_name == 'L' or action_name == 'R':
            self.predict()
            move = moves[action_name]
            clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c,
                                 self.eg.c.nr_class)
            action = self.parser.moves.c[clas]
        else:
            action = self.parser.moves.lookup_transition(action_name)
        action.do(self.stcls.c, action.label)

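    # Accepted action names in transition() above: '_' re-predicts with the
    # model, a bare 'L' or 'R' picks the best-scoring arc in that direction,
    # and anything else is passed straight to moves.lookup_transition
    # (presumably a full move name as produced by moves.move_name; an
    # assumption, not verified here).
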
    def finish(self):
        if self.stcls.is_final():
            self.parser.moves.finalize_state(self.stcls.c)
        self.doc.set_parse(self.stcls.c._sent)
        self.parser.moves.finalize_doc(self.doc)


class ParserStateError(ValueError):
    def __init__(self, doc):
        ValueError.__init__(self,
            "Error analysing doc -- no valid actions available. This should "
            "never happen, so please report the error on the issue tracker. "
            "Here's the thread to do so --- reopen it if it's closed:\n"
            "https://github.com/spacy-io/spaCy/issues/429\n"
            "Please include the text that the parser failed on, which is:\n"
            "%s" % repr(doc.text))


cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, const int* is_valid, int n) nogil:
    # Find minimum cost
    cdef float cost = 1
    for i in range(n):
        if is_valid[i] and costs[i] < cost:
            cost = costs[i]
    # Now find best-scoring with that cost
    cdef int best = -1
    for i in range(n):
        if costs[i] <= cost and is_valid[i]:
            if best == -1 or scores[i] > scores[best]:
                best = i
    return best


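# Worked example for arg_max_if_gold (values invented for illustration):
#
#     scores   = [2.0, 5.0, 1.0, 4.0]
#     costs    = [1.0, 0.0, 0.0, 0.0]
#     is_valid = [1,   1,   0,   1]
#
# The minimum cost over valid actions is 0.0; the valid, zero-cost candidates
# are indices 1 and 3, and index 1 has the higher score, so it is returned.

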
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
    cdef int best = -1
    for i in range(n):
        if is_valid[i] >= 1:
            if best == -1 or scores[i] > scores[best]:
                best = i
    return best


cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
                       int nr_class) except -1:
    cdef weight_t score = 0
    cdef int mode = -1
    cdef int i
    for i in range(nr_class):
        if actions[i].move == move and (mode == -1 or scores[i] >= score):
            mode = i
            score = scores[i]
    return mode