mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Work on parser beam training
This commit is contained in:
parent
4ab0c8c8e9
commit
b353e4d843
1
setup.py
1
setup.py
|
@ -36,6 +36,7 @@ MOD_NAMES = [
|
||||||
'spacy.syntax.transition_system',
|
'spacy.syntax.transition_system',
|
||||||
'spacy.syntax.arc_eager',
|
'spacy.syntax.arc_eager',
|
||||||
'spacy.syntax._parse_features',
|
'spacy.syntax._parse_features',
|
||||||
|
'spacy.syntax._beam_utils',
|
||||||
'spacy.gold',
|
'spacy.gold',
|
||||||
'spacy.tokens.doc',
|
'spacy.tokens.doc',
|
||||||
'spacy.tokens.span',
|
'spacy.tokens.span',
|
||||||
|
|
196
spacy/syntax/_beam_utils.pyx
Normal file
196
spacy/syntax/_beam_utils.pyx
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
# cython: infer_types=True
|
||||||
|
cimport numpy as np
|
||||||
|
import numpy
|
||||||
|
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
|
||||||
|
from thinc.extra.search cimport Beam
|
||||||
|
from thinc.extra.search import MaxViolation
|
||||||
|
from thinc.typedefs cimport hash_t, class_t
|
||||||
|
|
||||||
|
from .transition_system cimport TransitionSystem, Transition
|
||||||
|
from .stateclass cimport StateClass
|
||||||
|
from ..gold cimport GoldParse
|
||||||
|
from ..tokens.doc cimport Doc
|
||||||
|
|
||||||
|
|
||||||
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
|
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||||
|
dest = <StateClass>_dest
|
||||||
|
src = <StateClass>_src
|
||||||
|
moves = <const Transition*>_moves
|
||||||
|
dest.clone(src)
|
||||||
|
moves[clas].do(dest.c, moves[clas].label)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||||
|
return (<StateClass>_state).is_final()
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup(Beam beam):
|
||||||
|
for i in range(beam.width):
|
||||||
|
Py_XDECREF(<PyObject*>beam._states[i].content)
|
||||||
|
Py_XDECREF(<PyObject*>beam._parents[i].content)
|
||||||
|
|
||||||
|
|
||||||
|
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||||
|
state = <StateClass>_state
|
||||||
|
if state.c.is_final():
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return state.c.hash()
|
||||||
|
|
||||||
|
|
||||||
|
cdef class ParserBeam(object):
|
||||||
|
cdef public TransitionSystem moves
|
||||||
|
cdef public object docs
|
||||||
|
cdef public object golds
|
||||||
|
cdef public object beams
|
||||||
|
|
||||||
|
def __init__(self, TransitionSystem moves, docs, golds,
|
||||||
|
int width=4, float density=0.001):
|
||||||
|
self.moves = moves
|
||||||
|
self.docs = docs
|
||||||
|
self.golds = golds
|
||||||
|
self.beams = []
|
||||||
|
cdef Doc doc
|
||||||
|
cdef Beam beam
|
||||||
|
for doc in docs:
|
||||||
|
beam = Beam(self.moves.n_moves, width, density)
|
||||||
|
beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
|
||||||
|
self.beams.append(beam)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_done(self):
|
||||||
|
return all(beam.is_done for beam in self.beams)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self.beams[i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.beams)
|
||||||
|
|
||||||
|
def advance(self, scores, follow_gold=False):
|
||||||
|
cdef Beam beam
|
||||||
|
for i, beam in enumerate(self.beams):
|
||||||
|
self._set_scores(beam, scores[i])
|
||||||
|
if self.golds is not None:
|
||||||
|
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
||||||
|
if follow_gold:
|
||||||
|
assert self.golds is not None
|
||||||
|
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||||
|
else:
|
||||||
|
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||||
|
beam.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
|
def _set_scores(self, Beam beam, scores):
|
||||||
|
for i in range(beam.size):
|
||||||
|
state = <StateClass>beam.at(i)
|
||||||
|
for j in range(beam.nr_class):
|
||||||
|
beam.scores[i][j] = scores[i, j]
|
||||||
|
self.moves.set_valid(beam.is_valid[i], state.c)
|
||||||
|
|
||||||
|
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
|
||||||
|
for i in range(beam.size):
|
||||||
|
state = <StateClass>beam.at(i)
|
||||||
|
self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold)
|
||||||
|
if follow_gold:
|
||||||
|
for j in range(beam.nr_class):
|
||||||
|
beam.is_valid[i][j] *= beam.costs[i][j] <= 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_ids(states, int n_tokens):
|
||||||
|
cdef StateClass state
|
||||||
|
cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
|
||||||
|
dtype='i', order='C')
|
||||||
|
c_ids = <int*>ids.data
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
if not state.is_final():
|
||||||
|
state.c.set_context_tokens(c_ids, n_tokens)
|
||||||
|
c_ids += ids.shape[1]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def update_beam(TransitionSystem moves, int nr_feature,
|
||||||
|
docs, tokvecs, golds,
|
||||||
|
state2vec, vec2scores, drop=0., sgd=None,
|
||||||
|
losses=None, int width=4, float density=0.001):
|
||||||
|
pbeam = ParserBeam(moves, docs, golds,
|
||||||
|
width=width, density=density)
|
||||||
|
gbeam = ParserBeam(moves, docs, golds,
|
||||||
|
width=width, density=density)
|
||||||
|
beam_map = {}
|
||||||
|
backprops = []
|
||||||
|
violns = [MaxViolation() for _ in range(len(docs))]
|
||||||
|
example_ids = list(range(len(docs)))
|
||||||
|
while not pbeam.is_done and not gbeam.is_done:
|
||||||
|
states, p_indices, g_indices = get_states(example_ids, pbeam, gbeam, beam_map)
|
||||||
|
|
||||||
|
token_ids = get_token_ids(states, nr_feature)
|
||||||
|
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
||||||
|
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
||||||
|
|
||||||
|
backprops.append((token_ids, bp_vectors, bp_scores))
|
||||||
|
|
||||||
|
p_scores = [scores[indices] for indices in p_indices]
|
||||||
|
g_scores = [scores[indices] for indices in g_indices]
|
||||||
|
pbeam.advance(p_scores)
|
||||||
|
gbeam.advance(g_scores, follow_gold=True)
|
||||||
|
|
||||||
|
for i, violn in enumerate(violns):
|
||||||
|
violn.check_crf(pbeam[i], gbeam[i])
|
||||||
|
|
||||||
|
histories = [(v.p_hist + v.g_hist) for v in violns]
|
||||||
|
losses = [(v.p_probs + v.g_probs) for v in violns]
|
||||||
|
states_d_scores = get_gradient(moves.n_moves, beam_map,
|
||||||
|
histories, losses)
|
||||||
|
return states_d_scores, backprops
|
||||||
|
|
||||||
|
|
||||||
|
def get_states(example_ids, pbeams, gbeams, beam_map):
|
||||||
|
states = []
|
||||||
|
seen = {}
|
||||||
|
p_indices = []
|
||||||
|
g_indices = []
|
||||||
|
cdef Beam pbeam, gbeam
|
||||||
|
for eg_id, pbeam, gbeam in zip(example_ids, pbeams, gbeams):
|
||||||
|
p_indices.append([])
|
||||||
|
for j in range(pbeam.size):
|
||||||
|
key = tuple([eg_id] + pbeam.histories[j])
|
||||||
|
seen[key] = len(states)
|
||||||
|
p_indices[-1].append(len(states))
|
||||||
|
states.append(<StateClass>pbeam.at(j))
|
||||||
|
beam_map.update(seen)
|
||||||
|
g_indices.append([])
|
||||||
|
for i in range(gbeam.size):
|
||||||
|
key = tuple([eg_id] + gbeam.histories[i])
|
||||||
|
if key in seen:
|
||||||
|
g_indices[-1].append(seen[key])
|
||||||
|
else:
|
||||||
|
g_indices[-1].append(len(states))
|
||||||
|
beam_map[key] = len(states)
|
||||||
|
states.append(<StateClass>gbeam.at(i))
|
||||||
|
|
||||||
|
p_indices = numpy.asarray(p_indices, dtype='i')
|
||||||
|
g_indices = numpy.asarray(g_indices, dtype='i')
|
||||||
|
return states, p_indices, g_indices
|
||||||
|
|
||||||
|
|
||||||
|
def get_gradient(nr_class, beam_map, histories, losses):
|
||||||
|
"""
|
||||||
|
The global model assigns a loss to each parse. The beam scores
|
||||||
|
are additive, so the same gradient is applied to each action
|
||||||
|
in the history. This gives the gradient of a single *action*
|
||||||
|
for a beam state -- so we have "the gradient of loss for taking
|
||||||
|
action i given history H."
|
||||||
|
"""
|
||||||
|
nr_step = max(len(hist) for hist in histories)
|
||||||
|
nr_beam = len(histories)
|
||||||
|
grads = [numpy.zeros((nr_beam, nr_class), dtype='f') for _ in range(nr_step)]
|
||||||
|
for hist, loss in zip(histories, losses):
|
||||||
|
key = tuple()
|
||||||
|
for j, clas in enumerate(hist):
|
||||||
|
grads[j][i, clas] = loss
|
||||||
|
key = key + clas
|
||||||
|
i = beam_map[key]
|
||||||
|
return grads
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,7 @@ from ..tokens.doc cimport Doc
|
||||||
from ..strings cimport StringStore
|
from ..strings cimport StringStore
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
||||||
|
from . import _beam_utils
|
||||||
|
|
||||||
USE_FINE_TUNE = True
|
USE_FINE_TUNE = True
|
||||||
|
|
||||||
|
@ -256,7 +257,7 @@ cdef class Parser:
|
||||||
|
|
||||||
with Model.use_device('cpu'):
|
with Model.use_device('cpu'):
|
||||||
upper = chain(
|
upper = chain(
|
||||||
clone(drop_layer(Residual(Maxout(hidden_width))), (depth-1)),
|
clone(Residual(ReLu(hidden_width)), (depth-1)),
|
||||||
zero_init(Affine(nr_class, drop_factor=0.0))
|
zero_init(Affine(nr_class, drop_factor=0.0))
|
||||||
)
|
)
|
||||||
# TODO: This is an unfortunate hack atm!
|
# TODO: This is an unfortunate hack atm!
|
||||||
|
@ -526,6 +527,30 @@ cdef class Parser:
|
||||||
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||||
return d_tokvecs
|
return d_tokvecs
|
||||||
|
|
||||||
|
def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||||
|
docs, tokvecs = docs_tokvecs
|
||||||
|
tokvecs = self.model[0].ops.flatten(tokvecs)
|
||||||
|
|
||||||
|
cuda_stream = get_cuda_stream()
|
||||||
|
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, cuda_stream, 0.0)
|
||||||
|
|
||||||
|
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature,
|
||||||
|
docs, tokvecs, golds,
|
||||||
|
state2vec, vec2scores,
|
||||||
|
drop, sgd, losses)
|
||||||
|
backprop_lower = []
|
||||||
|
for i, d_scores in enumerate(states_d_scores):
|
||||||
|
ids, bp_vectors, bp_scores = backprops[i]
|
||||||
|
d_vector = bp_scores(d_scores, sgd=sgd)
|
||||||
|
backprop_lower.append((
|
||||||
|
get_async(cuda_stream, ids),
|
||||||
|
get_async(cuda_stream, d_vector),
|
||||||
|
bp_vectors))
|
||||||
|
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
|
||||||
|
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
|
||||||
|
lengths = [len(doc) for doc in docs]
|
||||||
|
return self.model[0].ops.unflatten(d_tokvecs, lengths)
|
||||||
|
|
||||||
def _init_gold_batch(self, whole_docs, whole_golds):
|
def _init_gold_batch(self, whole_docs, whole_golds):
|
||||||
"""Make a square batch, of length equal to the shortest doc. A long
|
"""Make a square batch, of length equal to the shortest doc. A long
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
|
|
98
spacy/tests/parser/test_nn_beam.py
Normal file
98
spacy/tests/parser/test_nn_beam.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
import pytest
|
||||||
|
import numpy
|
||||||
|
from thinc.api import layerize
|
||||||
|
|
||||||
|
from ...vocab import Vocab
|
||||||
|
from ...syntax.arc_eager import ArcEager
|
||||||
|
from ...tokens import Doc
|
||||||
|
from ...gold import GoldParse
|
||||||
|
from ...syntax._beam_utils import ParserBeam, update_beam
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def vocab():
|
||||||
|
return Vocab()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def moves(vocab):
|
||||||
|
aeager = ArcEager(vocab.strings, {})
|
||||||
|
aeager.add_action(2, 'nsubj')
|
||||||
|
aeager.add_action(3, 'dobj')
|
||||||
|
aeager.add_action(2, 'aux')
|
||||||
|
return aeager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def docs(vocab):
|
||||||
|
return [Doc(vocab, words=['Rats', 'bite', 'things'])]
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tokvecs(docs, vector_size):
|
||||||
|
output = []
|
||||||
|
for doc in docs:
|
||||||
|
vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size))
|
||||||
|
output.append(numpy.asarray(vec))
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def golds(docs):
|
||||||
|
return [GoldParse(doc) for doc in docs]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def batch_size(docs):
|
||||||
|
return len(docs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def beam_width():
|
||||||
|
return 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def vector_size():
|
||||||
|
return 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def beam(moves, docs, golds, beam_width):
|
||||||
|
return ParserBeam(moves, docs, golds, width=beam_width)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def scores(moves, batch_size, beam_width):
|
||||||
|
return [
|
||||||
|
numpy.asarray(
|
||||||
|
numpy.random.uniform(-0.1, 0.1, (batch_size, moves.n_moves)),
|
||||||
|
dtype='f')
|
||||||
|
for _ in range(batch_size)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_beam(beam):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_beam_advance(beam, scores):
|
||||||
|
beam.advance(scores)
|
||||||
|
|
||||||
|
|
||||||
|
def test_beam_advance_too_few_scores(beam, scores):
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
beam.advance(scores[:-1])
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_beam(moves, docs, tokvecs, golds, vector_size):
|
||||||
|
@layerize
|
||||||
|
def state2vec(X, drop=0.):
|
||||||
|
vec = numpy.ones((X.shape[0], vector_size), dtype='f')
|
||||||
|
return vec, None
|
||||||
|
@layerize
|
||||||
|
def vec2scores(X, drop=0.):
|
||||||
|
scores = numpy.ones((X.shape[0], moves.n_moves), dtype='f')
|
||||||
|
return scores, None
|
||||||
|
d_loss, backprops = update_beam(moves, 13, docs, tokvecs, golds,
|
||||||
|
state2vec, vec2scores, drop=0.0, sgd=None,
|
||||||
|
losses={}, width=4, density=0.001)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user