Restore patches from nn-beam-parser to spacy/syntax

Matthew Honnibal 2017-08-18 22:38:59 +02:00
parent fe90dfc390
commit c307a0ffb8
6 changed files with 116 additions and 66 deletions

View File

@@ -6,6 +6,7 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
 from thinc.extra.search cimport Beam
 from thinc.extra.search import MaxViolation
 from thinc.typedefs cimport hash_t, class_t
+from thinc.extra.search cimport MaxViolation
 from .transition_system cimport TransitionSystem, Transition
 from .stateclass cimport StateClass
@@ -45,9 +46,10 @@ cdef class ParserBeam(object):
     cdef public object states
     cdef public object golds
     cdef public object beams
+    cdef public object dones

     def __init__(self, TransitionSystem moves, states, golds,
-                 int width=4, float density=0.001):
+                 int width, float density):
         self.moves = moves
         self.states = states
         self.golds = golds
@@ -61,6 +63,7 @@ cdef class ParserBeam(object):
                 st = <StateClass>beam.at(i)
                 st.c.offset = state.c.offset
             self.beams.append(beam)
+        self.dones = [False] * len(self.beams)

     def __dealloc__(self):
         if self.beams is not None:
@@ -70,7 +73,7 @@ cdef class ParserBeam(object):
     @property
     def is_done(self):
-        return all(b.is_done for b in self.beams)
+        return all(b.is_done or self.dones[i] for i, b in enumerate(self.beams))

     def __getitem__(self, i):
         return self.beams[i]
@@ -81,32 +84,42 @@ cdef class ParserBeam(object):
     def advance(self, scores, follow_gold=False):
         cdef Beam beam
         for i, beam in enumerate(self.beams):
-            if beam.is_done or not scores[i].size:
+            if beam.is_done or not scores[i].size or self.dones[i]:
                 continue
             self._set_scores(beam, scores[i])
             if self.golds is not None:
                 self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
             if follow_gold:
+                assert self.golds is not None
                 beam.advance(_transition_state, NULL, <void*>self.moves.c)
             else:
                 beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
             beam.check_done(_check_final_state, NULL)
-            if beam.is_done:
+            if beam.is_done and self.golds is not None:
                 for j in range(beam.size):
-                    if is_gold(<StateClass>beam.at(j), self.golds[i], self.moves.strings):
-                        beam._states[j].loss = 0.0
-                    elif beam._states[j].loss == 0.0:
-                        beam._states[j].loss = 1.0
+                    state = <StateClass>beam.at(j)
+                    if state.is_final():
+                        try:
+                            if self.moves.is_gold_parse(state, self.golds[i]):
+                                beam._states[j].loss = 0.0
+                            elif beam._states[j].loss == 0.0:
+                                beam._states[j].loss = 1.0
+                        except NotImplementedError:
+                            break

     def _set_scores(self, Beam beam, float[:, ::1] scores):
         cdef float* c_scores = &scores[0, 0]
-        for i in range(beam.size):
+        cdef int nr_state = min(scores.shape[0], beam.size)
+        cdef int nr_class = scores.shape[1]
+        for i in range(nr_state):
             state = <StateClass>beam.at(i)
             if not state.is_final():
-                for j in range(beam.nr_class):
-                    beam.scores[i][j] = c_scores[i * beam.nr_class + j]
+                for j in range(nr_class):
+                    beam.scores[i][j] = c_scores[i * nr_class + j]
                 self.moves.set_valid(beam.is_valid[i], state.c)
+            else:
+                for j in range(beam.nr_class):
+                    beam.scores[i][j] = 0
+                    beam.costs[i][j] = 0

     def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
         for i in range(beam.size):
@@ -119,21 +132,6 @@ cdef class ParserBeam(object):
                             beam.is_valid[i][j] = 0

-def is_gold(StateClass state, GoldParse gold, strings):
-    predicted = set()
-    truth = set()
-    for i in range(gold.length):
-        if gold.cand_to_gold[i] is None:
-            continue
-        if state.safe_get(i).dep:
-            predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
-        else:
-            predicted.add((i, state.H(i), 'ROOT'))
-        id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
-        truth.add((id_, head, dep))
-    return truth == predicted

 def get_token_ids(states, int n_tokens):
     cdef StateClass state
     cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
@@ -150,9 +148,11 @@ def get_token_ids(states, int n_tokens):
 nr_update = 0
 def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
                 states, tokvecs, golds,
-                state2vec, vec2scores, drop=0., sgd=None,
-                losses=None, int width=4, float density=0.001):
+                state2vec, vec2scores,
+                int width, float density,
+                sgd=None, losses=None, drop=0.):
     global nr_update
+    cdef MaxViolation violn
     nr_update += 1
     pbeam = ParserBeam(moves, states, golds,
                        width=width, density=density)
@@ -163,6 +163,8 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
     backprops = []
     violns = [MaxViolation() for _ in range(len(states))]
     for t in range(max_steps):
+        if pbeam.is_done and gbeam.is_done:
+            break
         # The beam maps let us find the right row in the flattened scores
         # arrays for each state. States are identified by (example id, history).
         # We keep a different beam map for each step (since we'll have a flat
@@ -194,14 +196,17 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
         # Track the "maximum violation", to use in the update.
         for i, violn in enumerate(violns):
             violn.check_crf(pbeam[i], gbeam[i])
-    # Only make updates if we have non-gold states
-    histories = [((v.p_hist + v.g_hist) if v.p_hist else []) for v in violns]
-    losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns]
-    states_d_scores = get_gradient(moves.n_moves, beam_maps,
-                                   histories, losses)
-    assert len(states_d_scores) == len(backprops), (len(states_d_scores), len(backprops))
-    return states_d_scores, backprops
+    histories = []
+    losses = []
+    for violn in violns:
+        if violn.p_hist:
+            histories.append(violn.p_hist + violn.g_hist)
+            losses.append(violn.p_probs + violn.g_probs)
+        else:
+            histories.append([])
+            losses.append([])
+    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, losses)
+    return states_d_scores, backprops[:len(states_d_scores)]


 def get_states(pbeams, gbeams, beam_map, nr_update):
@@ -214,12 +219,11 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
     for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
         p_indices.append([])
         g_indices.append([])
-        if pbeam.loss > 0 and pbeam.min_score > gbeam.score:
-            continue
         for i in range(pbeam.size):
             state = <StateClass>pbeam.at(i)
             if not state.is_final():
                 key = tuple([eg_id] + pbeam.histories[i])
+                assert key not in seen, (key, seen)
                 seen[key] = len(states)
                 p_indices[-1].append(len(states))
                 states.append(state)
@@ -255,18 +259,27 @@ def get_gradient(nr_class, beam_maps, histories, losses):
     """
     nr_step = len(beam_maps)
     grads = []
-    for beam_map in beam_maps:
-        if beam_map:
-            grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
+    nr_step = 0
+    for eg_id, hists in enumerate(histories):
+        for loss, hist in zip(losses[eg_id], hists):
+            if loss != 0.0 and not numpy.isnan(loss):
+                nr_step = max(nr_step, len(hist))
+    for i in range(nr_step):
+        grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
     assert len(histories) == len(losses)
     for eg_id, hists in enumerate(histories):
         for loss, hist in zip(losses[eg_id], hists):
+            if loss == 0.0 or numpy.isnan(loss):
+                continue
             key = tuple([eg_id])
+            # Adjust loss for length
+            avg_loss = loss / len(hist)
+            loss += avg_loss * (nr_step - len(hist))
             for j, clas in enumerate(hist):
                 i = beam_maps[j][key]
                 # In step j, at state i action clas
                 # resulted in loss
-                grads[j][i, clas] += loss / len(histories)
+                grads[j][i, clas] += loss
                 key = key + tuple([clas])
     return grads
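
For readers following the gradient bookkeeping in update_beam and get_gradient above: each example contributes one combined history (predicted-beam history plus gold-beam history) and one combined loss, and that loss is added at every step of the history, with shorter histories scaled up so they weigh as much as histories that ran for the full nr_step steps. The snippet below is only a minimal sketch of that accumulation in plain NumPy; it collapses the beam_maps lookup to one row per example and uses toy inputs, so it illustrates the idea rather than reproducing the code above.

import numpy

def toy_get_gradient(nr_class, histories, losses):
    # Number of steps = longest history that carries a non-zero loss.
    nr_step = max((len(hist) for hist, loss in zip(histories, losses)
                   if loss != 0.0), default=0)
    # One gradient matrix per step; here simply one row per example.
    grads = [numpy.zeros((len(histories), nr_class), dtype='f')
             for _ in range(nr_step)]
    for eg_id, (hist, loss) in enumerate(zip(histories, losses)):
        if loss == 0.0 or numpy.isnan(loss):
            continue
        # Adjust loss for length, mirroring the patch above: shorter
        # histories are topped up as if they had run for nr_step steps.
        loss += (loss / len(hist)) * (nr_step - len(hist))
        for step, clas in enumerate(hist):
            grads[step][eg_id, clas] += loss
    return grads

# Toy usage: two examples, three actions; the second example had no violation.
print(toy_get_gradient(3, histories=[[0, 2, 1], [1, 2]], losses=[0.5, 0.0]))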

View File

@@ -74,7 +74,16 @@ cdef cppclass StateC:
         free(this.shifted - PADDING)

     void set_context_tokens(int* ids, int n) nogil:
-        if n == 13:
+        if n == 8:
+            ids[0] = this.B(0)
+            ids[1] = this.B(1)
+            ids[2] = this.S(0)
+            ids[3] = this.S(1)
+            ids[4] = this.H(this.S(0))
+            ids[5] = this.L(this.B(0), 1)
+            ids[6] = this.L(this.S(0), 2)
+            ids[7] = this.R(this.S(0), 1)
+        elif n == 13:
             ids[0] = this.B(0)
             ids[1] = this.B(1)
             ids[2] = this.S(0)
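
Reading L(i, k) as the k-th leftmost child of token i and R(i, k) as the k-th rightmost (an assumption based on the accessor names), the eight positions in the reduced template are the first two buffer words, the top two stack words, the head of the stack top, the leftmost child of the buffer front, and the second-leftmost and rightmost children of the stack top. A rough pure-Python rendering of the same template follows; the helper and its argument are hypothetical stand-ins, not the Cython StateC struct.

def context_tokens_8(st):
    # st is any object exposing B/S/H/L/R with the semantics used above.
    return [
        st.B(0),           # first word of the buffer
        st.B(1),           # second word of the buffer
        st.S(0),           # word on top of the stack
        st.S(1),           # second word of the stack
        st.H(st.S(0)),     # head of the stack top
        st.L(st.B(0), 1),  # leftmost child of the buffer front
        st.L(st.S(0), 2),  # second-leftmost child of the stack top
        st.R(st.S(0), 1),  # rightmost child of the stack top
    ]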

View File

@@ -351,6 +351,20 @@ cdef class ArcEager(TransitionSystem):
         def __get__(self):
            return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)

+    def is_gold_parse(self, StateClass state, GoldParse gold):
+        predicted = set()
+        truth = set()
+        for i in range(gold.length):
+            if gold.cand_to_gold[i] is None:
+                continue
+            if state.safe_get(i).dep:
+                predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep]))
+            else:
+                predicted.add((i, state.H(i), 'ROOT'))
+            id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
+            truth.add((id_, head, dep))
+        return truth == predicted
+
     def has_gold(self, GoldParse gold, start=0, end=None):
         end = end or len(gold.heads)
         if all([tag is None for tag in gold.heads[start:end]]):
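
The is_gold_parse method added here reduces "is this candidate a gold parse?" to set equality over (token, head, label) arcs for the tokens that align to the gold annotation. A minimal standalone illustration of that comparison, using plain tuples rather than the StateClass and GoldParse objects:

def arcs_match(predicted, gold):
    # Both arguments: iterables of (token, head, label) triples.
    return set(predicted) == set(gold)

gold = [(0, 1, 'nsubj'), (1, 1, 'ROOT'), (2, 1, 'dobj')]
pred_ok = [(2, 1, 'dobj'), (0, 1, 'nsubj'), (1, 1, 'ROOT')]   # order is irrelevant
pred_bad = [(0, 1, 'nsubj'), (1, 1, 'ROOT'), (2, 1, 'iobj')]  # one wrong label
assert arcs_match(pred_ok, gold)
assert not arcs_match(pred_bad, gold)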

View File

@@ -34,7 +34,6 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 from .parser cimport Parser
-from ._beam_utils import is_gold

 DEBUG = False
@@ -108,7 +107,7 @@ cdef class BeamParser(Parser):
         # The non-monotonic oracle makes it difficult to ensure final costs are
         # correct. Therefore do final correction
         for i in range(pred.size):
-            if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
+            if self.moves.is_gold_parse(<StateClass>pred.at(i), gold_parse):
                 pred._states[i].loss = 0.0
             elif pred._states[i].loss == 0.0:
                 pred._states[i].loss = 1.0
@@ -214,7 +213,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
         if not pred._states[i].is_done or pred._states[i].loss == 0:
             continue
         state = <StateClass>pred.at(i)
-        if is_gold(state, gold_parse, moves.strings) == True:
+        if moves.is_gold_parse(state, gold_parse) == True:
             for dep in gold_parse.orig_annot:
                 print(dep[1], dep[3], dep[4])
             print("Cost", pred._states[i].loss)
@@ -228,7 +227,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
         if not gold._states[i].is_done:
             continue
         state = <StateClass>gold.at(i)
-        if is_gold(state, gold_parse, moves.strings) == False:
+        if moves.is_gold(state, gold_parse) == False:
             print("Truth")
             for dep in gold_parse.orig_annot:
                 print(dep[1], dep[3], dep[4])

View File

@@ -38,6 +38,7 @@ from preshed.maps cimport map_get
 from thinc.api import layerize, chain, noop, clone
 from thinc.neural import Model, Affine, ReLu, Maxout
+from thinc.neural._classes.batchnorm import BatchNorm as BN
 from thinc.neural._classes.selu import SELU
 from thinc.neural._classes.layernorm import LayerNorm
 from thinc.neural.ops import NumpyOps, CupyOps
@@ -66,7 +67,6 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
 from . import _beam_utils

 USE_FINE_TUNE = True
-BEAM_PARSE = True

 def get_templates(*args, **kwargs):
     return []
@@ -258,7 +258,7 @@ cdef class Parser:
         with Model.use_device('cpu'):
             upper = chain(
-                clone(Residual(ReLu(hidden_width)), (depth-1)),
+                clone(Maxout(hidden_width), (depth-1)),
                 zero_init(Affine(nr_class, drop_factor=0.0))
             )
         # TODO: This is an unfortunate hack atm!
@@ -298,6 +298,10 @@ cdef class Parser:
             self.moves = self.TransitionSystem(self.vocab.strings, {})
         else:
             self.moves = moves
+        if 'beam_width' not in cfg:
+            cfg['beam_width'] = util.env_opt('beam_width', 1)
+        if 'beam_density' not in cfg:
+            cfg['beam_density'] = util.env_opt('beam_density', 0.0)
         self.cfg = cfg
         if 'actions' in self.cfg:
             for action, labels in self.cfg.get('actions', {}).items():
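
The four added lines give the parser's cfg defaults for beam_width and beam_density, taken from util.env_opt when the keys are absent. The sketch below shows the same fallback pattern in plain Python; the env_opt stand-in reading os.environ is an assumption for illustration, not spaCy's helper.

import os

def env_opt(name, default):
    # Hypothetical stand-in: allow an environment override, keeping the
    # type of the supplied default (int for beam_width, float for density).
    raw = os.environ.get('SPACY_' + name.upper())
    return type(default)(raw) if raw is not None else default

cfg = {}
if 'beam_width' not in cfg:
    cfg['beam_width'] = env_opt('beam_width', 1)
if 'beam_density' not in cfg:
    cfg['beam_density'] = env_opt('beam_density', 0.0)
print(cfg)  # {'beam_width': 1, 'beam_density': 0.0} unless overridden
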
@@ -320,7 +324,7 @@ cdef class Parser:
         if beam_width is None:
             beam_width = self.cfg.get('beam_width', 1)
         if beam_density is None:
-            beam_density = self.cfg.get('beam_density', 0.001)
+            beam_density = self.cfg.get('beam_density', 0.0)
         cdef Beam beam
         if beam_width == 1:
             states = self.parse_batch([doc], [doc.tensor])
@@ -336,7 +340,7 @@ cdef class Parser:
         return output

     def pipe(self, docs, int batch_size=1000, int n_threads=2,
-             beam_width=1, beam_density=0.001):
+             beam_width=None, beam_density=None):
         """
         Process a stream of documents.
@@ -348,8 +352,10 @@ cdef class Parser:
                 The number of threads with which to work on the buffer in parallel.
         Yields (Doc): Documents, in order.
         """
-        if BEAM_PARSE:
-            beam_width = 8
+        if beam_width is None:
+            beam_width = self.cfg.get('beam_width', 1)
+        if beam_density is None:
+            beam_density = self.cfg.get('beam_density', 0.0)
         cdef Doc doc
         cdef Beam beam
         for docs in cytoolz.partition_all(batch_size, docs):
@@ -411,7 +417,7 @@ cdef class Parser:
                 st = next_step[i]
                 st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
                 self.moves.set_valid(&c_is_valid[i*nr_class], st)
             vectors = state2vec(token_ids[:next_step.size()])
             scores = vec2scores(vectors)
             c_scores = <float*>scores.data
             for i in range(next_step.size()):
@@ -427,7 +433,7 @@ cdef class Parser:
                     next_step.push_back(st)
         return states

-    def beam_parse(self, docs, tokvecses, int beam_width=8, float beam_density=0.001):
+    def beam_parse(self, docs, tokvecses, int beam_width=3, float beam_density=0.001):
         cdef Beam beam
         cdef np.ndarray scores
         cdef Doc doc
@@ -477,9 +483,10 @@ cdef class Parser:
         return beams

     def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
-        if BEAM_PARSE:
-            return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd,
-                                    losses=losses)
+        if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
+            return self.update_beam(docs_tokvecs, golds,
+                                    self.cfg['beam_width'], self.cfg['beam_density'],
+                                    drop=drop, sgd=sgd, losses=losses)
         if losses is not None and self.name not in losses:
             losses[self.name] = 0.
         docs, tokvec_lists = docs_tokvecs
@@ -545,7 +552,12 @@ cdef class Parser:
         bp_my_tokvecs(d_tokvecs, sgd=sgd)
         return d_tokvecs

-    def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
+    def update_beam(self, docs_tokvecs, golds, width=None, density=None,
+                    drop=0., sgd=None, losses=None):
+        if width is None:
+            width = self.cfg.get('beam_width', 2)
+        if density is None:
+            density = self.cfg.get('beam_density', 0.0)
         if losses is not None and self.name not in losses:
             losses[self.name] = 0.
         docs, tokvecs = docs_tokvecs
@@ -567,8 +579,8 @@ cdef class Parser:
         states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
                                                              states, tokvecs, golds,
                                                              state2vec, vec2scores,
-                                                             drop, sgd, losses,
-                                                             width=8)
+                                                             width, density,
+                                                             sgd=sgd, drop=drop, losses=losses)
         backprop_lower = []
         for i, d_scores in enumerate(states_d_scores):
             if losses is not None:
@@ -634,9 +646,9 @@ cdef class Parser:
         for ids, d_vector, bp_vector in backprops:
             d_state_features = bp_vector(d_vector, sgd=sgd)
             mask = ids >= 0
-            indices = xp.nonzero(mask)
-            self.model[0].ops.scatter_add(d_tokvecs, ids[indices],
-                                          d_state_features[indices])
+            d_state_features *= mask.reshape(ids.shape + (1,))
+            self.model[0].ops.scatter_add(d_tokvecs, ids * mask,
+                                          d_state_features)

     @property
     def move_names(self):
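
The hunk above swaps fancy indexing for a mask multiply: rows whose feature id is negative (padding) have their gradients zeroed and their ids clamped to 0, so they scatter a zero contribution instead of being filtered out. A small NumPy check of why both formulations accumulate the same gradient, with numpy.add.at standing in for ops.scatter_add:

import numpy

ids = numpy.array([2, 0, -1, 1])              # -1 marks a padding slot
d_feats = numpy.ones((4, 3), dtype='f')       # fake per-feature gradients
d_tokvecs_a = numpy.zeros((3, 3), dtype='f')
d_tokvecs_b = numpy.zeros((3, 3), dtype='f')

# Old formulation: gather the valid rows and scatter only those.
mask = ids >= 0
indices = numpy.nonzero(mask)
numpy.add.at(d_tokvecs_a, ids[indices], d_feats[indices])

# New formulation: zero the padded rows, then scatter everything;
# the padded row lands on token 0 but contributes all-zeros.
d_masked = d_feats * mask.reshape(ids.shape + (1,))
numpy.add.at(d_tokvecs_b, ids * mask, d_masked)

assert numpy.allclose(d_tokvecs_a, d_tokvecs_b)
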
@@ -652,7 +664,7 @@ cdef class Parser:
                                        lower, stream, drop=dropout)
         return state2vec, upper

-    nr_feature = 13
+    nr_feature = 8

     def get_token_ids(self, states):
         cdef StateClass state

View File

@@ -99,6 +99,9 @@ cdef class TransitionSystem:
     def preprocess_gold(self, GoldParse gold):
         raise NotImplementedError

+    def is_gold_parse(self, StateClass state, GoldParse gold):
+        raise NotImplementedError
+
     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError