Work on parser beam training

Matthew Honnibal 2017-08-12 14:47:45 -05:00
parent 4ab0c8c8e9
commit b353e4d843
4 changed files with 321 additions and 1 deletion

View File

@ -36,6 +36,7 @@ MOD_NAMES = [
    'spacy.syntax.transition_system',
    'spacy.syntax.arc_eager',
    'spacy.syntax._parse_features',
    'spacy.syntax._beam_utils',
    'spacy.gold',
    'spacy.tokens.doc',
    'spacy.tokens.span',

View File

@ -0,0 +1,196 @@
# cython: infer_types=True
cimport numpy as np
import numpy
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition
from .stateclass cimport StateClass
from ..gold cimport GoldParse
from ..tokens.doc cimport Doc

# These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
    dest = <StateClass>_dest
    src = <StateClass>_src
    moves = <const Transition*>_moves
    dest.clone(src)
    moves[clas].do(dest.c, moves[clas].label)


cdef int _check_final_state(void* _state, void* extra_args) except -1:
    return (<StateClass>_state).is_final()


def _cleanup(Beam beam):
    for i in range(beam.width):
        Py_XDECREF(<PyObject*>beam._states[i].content)
        Py_XDECREF(<PyObject*>beam._parents[i].content)


cdef hash_t _hash_state(void* _state, void* _) except 0:
    state = <StateClass>_state
    if state.c.is_final():
        return 1
    else:
        return state.c.hash()


cdef class ParserBeam(object):
    cdef public TransitionSystem moves
    cdef public object docs
    cdef public object golds
    cdef public object beams

    def __init__(self, TransitionSystem moves, docs, golds,
                 int width=4, float density=0.001):
        self.moves = moves
        self.docs = docs
        self.golds = golds
        self.beams = []
        cdef Doc doc
        cdef Beam beam
        for doc in docs:
            beam = Beam(self.moves.n_moves, width, density)
            beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
            self.beams.append(beam)

    @property
    def is_done(self):
        return all(beam.is_done for beam in self.beams)

    def __getitem__(self, i):
        return self.beams[i]

    def __len__(self):
        return len(self.beams)

    def advance(self, scores, follow_gold=False):
        cdef Beam beam
        for i, beam in enumerate(self.beams):
            self._set_scores(beam, scores[i])
            if self.golds is not None:
                self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
            if follow_gold:
                assert self.golds is not None
                beam.advance(_transition_state, NULL, <void*>self.moves.c)
            else:
                beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
            beam.check_done(_check_final_state, NULL)

    def _set_scores(self, Beam beam, scores):
        for i in range(beam.size):
            state = <StateClass>beam.at(i)
            for j in range(beam.nr_class):
                beam.scores[i][j] = scores[i, j]
            self.moves.set_valid(beam.is_valid[i], state.c)

    def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
        for i in range(beam.size):
            state = <StateClass>beam.at(i)
            self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold)
            if follow_gold:
                for j in range(beam.nr_class):
                    # The gold beam may only take zero-cost actions.
                    beam.is_valid[i][j] *= beam.costs[i][j] <= 0


def get_token_ids(states, int n_tokens):
    cdef StateClass state
    cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
                                      dtype='i', order='C')
    c_ids = <int*>ids.data
    for i, state in enumerate(states):
        if not state.is_final():
            state.c.set_context_tokens(c_ids, n_tokens)
        c_ids += ids.shape[1]
    return ids


def update_beam(TransitionSystem moves, int nr_feature,
                docs, tokvecs, golds,
                state2vec, vec2scores, drop=0., sgd=None,
                losses=None, int width=4, float density=0.001):
    # pbeam searches freely over the model's predictions; gbeam may only
    # follow zero-cost (gold-consistent) actions.
    pbeam = ParserBeam(moves, docs, golds,
                       width=width, density=density)
    gbeam = ParserBeam(moves, docs, golds,
                       width=width, density=density)
    beam_map = {}
    backprops = []
    violns = [MaxViolation() for _ in range(len(docs))]
    example_ids = list(range(len(docs)))
    while not pbeam.is_done and not gbeam.is_done:
        states, p_indices, g_indices = get_states(example_ids, pbeam, gbeam, beam_map)
        token_ids = get_token_ids(states, nr_feature)
        vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
        scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
        backprops.append((token_ids, bp_vectors, bp_scores))
        p_scores = [scores[indices] for indices in p_indices]
        g_scores = [scores[indices] for indices in g_indices]
        pbeam.advance(p_scores)
        gbeam.advance(g_scores, follow_gold=True)
        # Track the maximum violation between the two beams across steps.
        for i, violn in enumerate(violns):
            violn.check_crf(pbeam[i], gbeam[i])
    histories = [(v.p_hist + v.g_hist) for v in violns]
    losses = [(v.p_probs + v.g_probs) for v in violns]
    states_d_scores = get_gradient(moves.n_moves, beam_map,
                                   histories, losses)
    return states_d_scores, backprops


def get_states(example_ids, pbeams, gbeams, beam_map):
    states = []
    seen = {}
    p_indices = []
    g_indices = []
    cdef Beam pbeam, gbeam
    for eg_id, pbeam, gbeam in zip(example_ids, pbeams, gbeams):
        p_indices.append([])
        for j in range(pbeam.size):
            key = tuple([eg_id] + pbeam.histories[j])
            seen[key] = len(states)
            p_indices[-1].append(len(states))
            states.append(<StateClass>pbeam.at(j))
        beam_map.update(seen)
        g_indices.append([])
        for i in range(gbeam.size):
            key = tuple([eg_id] + gbeam.histories[i])
            if key in seen:
                g_indices[-1].append(seen[key])
            else:
                g_indices[-1].append(len(states))
                beam_map[key] = len(states)
                states.append(<StateClass>gbeam.at(i))
    p_indices = numpy.asarray(p_indices, dtype='i')
    g_indices = numpy.asarray(g_indices, dtype='i')
    return states, p_indices, g_indices


def get_gradient(nr_class, beam_map, histories, losses):
    """
    The global model assigns a loss to each parse. The beam scores
    are additive, so the same gradient is applied to each action
    in the history. This gives the gradient of a single *action*
    for a beam state -- so we have "the gradient of the loss for taking
    action i given history H."
    """
    nr_step = max(len(hist) for hists in histories for hist in hists)
    nr_beam = len(histories)
    grads = [numpy.zeros((nr_beam, nr_class), dtype='f') for _ in range(nr_step)]
    for eg_id, (hists, hist_losses) in enumerate(zip(histories, losses)):
        for hist, loss in zip(hists, hist_losses):
            # Rebuild the (eg_id, action...) keys used in get_states to
            # find each state's row in that step's gradient matrix.
            key = tuple([eg_id])
            for j, clas in enumerate(hist):
                key = key + (clas,)
                i = beam_map[key]
                grads[j][i, clas] = loss
    return grads
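
The bookkeeping in get_gradient is easier to see on a toy example. The sketch below is a pure-Python mirror written for illustration only; toy_get_gradient and the hand-built beam_map are made up for the example, and the gradient matrices are sized by state count rather than by nr_beam:

import numpy

def toy_get_gradient(nr_class, beam_map, histories, losses):
    # Pure-Python mirror of get_gradient above, for illustration only.
    nr_step = max(len(hist) for hists in histories for hist in hists)
    n_states = max(beam_map.values()) + 1
    grads = [numpy.zeros((n_states, nr_class), dtype='f') for _ in range(nr_step)]
    for eg_id, (hists, hist_losses) in enumerate(zip(histories, losses)):
        for hist, loss in zip(hists, hist_losses):
            key = (eg_id,)
            for j, clas in enumerate(hist):
                key = key + (clas,)
                i = beam_map[key]            # row of this state at step j
                grads[j][i, clas] = loss     # same loss for every action in the parse
    return grads

# One example whose beam held two candidate parses: [0, 2] and [1, 2].
# beam_map records where each (eg_id, *history) key sat in its step's batch.
beam_map = {(0, 0): 0, (0, 1): 1, (0, 0, 2): 0, (0, 1, 2): 1}
histories = [[[0, 2], [1, 2]]]   # per example: a list of candidate histories
losses = [[0.75, -0.75]]         # per example: one loss per candidate
grads = toy_get_gradient(3, beam_map, histories, losses)
print(grads[0])   # step 0: loss lands on (row 0, action 0) and (row 1, action 1)
print(grads[1])   # step 1: both candidates took action 2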

View File

@ -63,6 +63,7 @@ from ..tokens.doc cimport Doc
from ..strings cimport StringStore
from ..gold cimport GoldParse
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
from . import _beam_utils

USE_FINE_TUNE = True
@ -256,7 +257,7 @@ cdef class Parser:
        with Model.use_device('cpu'):
            upper = chain(
-               clone(drop_layer(Residual(Maxout(hidden_width))), (depth-1)),
+               clone(Residual(ReLu(hidden_width)), (depth-1)),
                zero_init(Affine(nr_class, drop_factor=0.0))
            )
        # TODO: This is an unfortunate hack atm!
@ -526,6 +527,30 @@ cdef class Parser:
            bp_my_tokvecs(d_tokvecs, sgd=sgd)
        return d_tokvecs

    def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
        docs, tokvecs = docs_tokvecs
        tokvecs = self.model[0].ops.flatten(tokvecs)
        cuda_stream = get_cuda_stream()
        state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, cuda_stream, 0.0)
        states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature,
                                        docs, tokvecs, golds,
                                        state2vec, vec2scores,
                                        drop, sgd, losses)
        backprop_lower = []
        # Pair each step's score gradients with that step's backprop callbacks,
        # pushing them back through the scoring layer on the way.
        for i, d_scores in enumerate(states_d_scores):
            ids, bp_vectors, bp_scores = backprops[i]
            d_vector = bp_scores(d_scores, sgd=sgd)
            backprop_lower.append((
                get_async(cuda_stream, ids),
                get_async(cuda_stream, d_vector),
                bp_vectors))
        d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
        self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
        lengths = [len(doc) for doc in docs]
        return self.model[0].ops.unflatten(d_tokvecs, lengths)

    def _init_gold_batch(self, whole_docs, whole_golds):
        """Make a square batch, of length equal to the shortest doc. A long
        doc will get multiple states. Let's say we have a doc of length 2*N,
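
The diff cuts the _init_gold_batch docstring short, but the scheme it describes can be sketched roughly as follows. This is an illustration of the stated idea, not the commit's implementation; square_batch is a made-up helper that only computes segment boundaries:

# Illustrative sketch (not the commit's code): split each doc into
# segments no longer than the shortest doc, so every parser state in
# the batch does a comparable amount of work per update.
def square_batch(doc_lengths):
    max_len = min(doc_lengths)          # length of the shortest doc
    segments = []
    for doc_id, length in enumerate(doc_lengths):
        for start in range(0, length, max_len):
            end = min(start + max_len, length)
            segments.append((doc_id, start, end))
    return segments

# A doc of length 2*N contributes two states of length N alongside
# a doc of length N, e.g. lengths [6, 3] with N=3:
print(square_batch([6, 3]))   # [(0, 0, 3), (0, 3, 6), (1, 0, 3)]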

View File

@ -0,0 +1,98 @@
from __future__ import unicode_literals
import pytest
import numpy
from thinc.api import layerize

from ...vocab import Vocab
from ...syntax.arc_eager import ArcEager
from ...tokens import Doc
from ...gold import GoldParse
from ...syntax._beam_utils import ParserBeam, update_beam


@pytest.fixture
def vocab():
    return Vocab()


@pytest.fixture
def moves(vocab):
    aeager = ArcEager(vocab.strings, {})
    aeager.add_action(2, 'nsubj')
    aeager.add_action(3, 'dobj')
    aeager.add_action(2, 'aux')
    return aeager


@pytest.fixture
def docs(vocab):
    return [Doc(vocab, words=['Rats', 'bite', 'things'])]


@pytest.fixture
def tokvecs(docs, vector_size):
    output = []
    for doc in docs:
        vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size))
        output.append(numpy.asarray(vec))
    return output


@pytest.fixture
def golds(docs):
    return [GoldParse(doc) for doc in docs]


@pytest.fixture
def batch_size(docs):
    return len(docs)


@pytest.fixture
def beam_width():
    return 4


@pytest.fixture
def vector_size():
    return 6


@pytest.fixture
def beam(moves, docs, golds, beam_width):
    return ParserBeam(moves, docs, golds, width=beam_width)


@pytest.fixture
def scores(moves, batch_size, beam_width):
    return [
        numpy.asarray(
            numpy.random.uniform(-0.1, 0.1, (batch_size, moves.n_moves)),
            dtype='f')
        for _ in range(batch_size)]


def test_create_beam(beam):
    pass


def test_beam_advance(beam, scores):
    beam.advance(scores)


def test_beam_advance_too_few_scores(beam, scores):
    with pytest.raises(IndexError):
        beam.advance(scores[:-1])


def test_update_beam(moves, docs, tokvecs, golds, vector_size):
    @layerize
    def state2vec(X, drop=0.):
        vec = numpy.ones((X.shape[0], vector_size), dtype='f')
        return vec, None

    @layerize
    def vec2scores(X, drop=0.):
        scores = numpy.ones((X.shape[0], moves.n_moves), dtype='f')
        return scores, None

    d_loss, backprops = update_beam(moves, 13, docs, tokvecs, golds,
                                    state2vec, vec2scores, drop=0.0, sgd=None,
                                    losses={}, width=4, density=0.001)
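
For orientation, here is a rough sketch of how update_beam's two return values are meant to be consumed, mirroring the loop in Parser.update_beam above. The per-step shapes and the dummy backprop callables are invented for the example; nothing here is spaCy API:

import numpy

# Dummy per-step outputs shaped like update_beam's return values: one
# gradient matrix and one (token_ids, bp_vectors, bp_scores) triple per step.
nr_class, nr_feature, hidden = 4, 13, 6
states_d_scores = [numpy.ones((2, nr_class), dtype='f') for _ in range(3)]
backprops = [
    (numpy.zeros((2, nr_feature), dtype='i'),
     lambda d_vector, sgd=None: None,  # stands in for bp_vectors
     lambda d_scores, sgd=None: numpy.ones((d_scores.shape[0], hidden), dtype='f'))
    for _ in range(3)]

# Mirror of the consumption loop in Parser.update_beam: push each step's
# score gradients back through that step's scoring layer, then queue the
# (ids, d_vector, bp_vectors) triple for the lower layers.
backprop_lower = []
for d_scores, (ids, bp_vectors, bp_scores) in zip(states_d_scores, backprops):
    d_vector = bp_scores(d_scores, sgd=None)
    backprop_lower.append((ids, d_vector, bp_vectors))
print(len(backprop_lower), backprop_lower[0][1].shape)   # -> 3 (2, 6)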