This commit is contained in:
ines 2017-11-15 14:24:00 +01:00
commit 97a4f9362b
7 changed files with 155 additions and 115 deletions

View File

@ -9,36 +9,31 @@ from thinc.typedefs cimport hash_t, class_t
from thinc.extra.search cimport MaxViolation from thinc.extra.search cimport MaxViolation
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
from .stateclass cimport StateClass
from ..gold cimport GoldParse from ..gold cimport GoldParse
from .stateclass cimport StateC, StateClass
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest dest = <StateC*>_dest
src = <StateClass>_src src = <StateC*>_src
moves = <const Transition*>_moves moves = <const Transition*>_moves
dest.clone(src) dest.clone(src)
moves[clas].do(dest.c, moves[clas].label) moves[clas].do(dest, moves[clas].label)
dest.c.push_hist(clas) dest.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1: cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final() state = <StateC*>_state
return state.is_final()
def _cleanup(Beam beam):
for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content)
Py_XDECREF(<PyObject*>beam._parents[i].content)
cdef hash_t _hash_state(void* _state, void* _) except 0: cdef hash_t _hash_state(void* _state, void* _) except 0:
state = <StateClass>_state state = <StateC*>_state
if state.c.is_final(): if state.is_final():
return 1 return 1
else: else:
return state.c.hash() return state.hash()
cdef class ParserBeam(object): cdef class ParserBeam(object):
@ -55,14 +50,15 @@ cdef class ParserBeam(object):
self.golds = golds self.golds = golds
self.beams = [] self.beams = []
cdef Beam beam cdef Beam beam
cdef StateClass state, st cdef StateClass state
cdef StateC* st
for state in states: for state in states:
beam = Beam(self.moves.n_moves, width, density) beam = Beam(self.moves.n_moves, width, density)
beam.initialize(self.moves.init_beam_state, state.c.length, beam.initialize(self.moves.init_beam_state, state.c.length,
state.c._sent) state.c._sent)
for i in range(beam.width): for i in range(beam.width):
st = <StateClass>beam.at(i) st = <StateC*>beam.at(i)
st.c.offset = state.c.offset st.offset = state.c.offset
self.beams.append(beam) self.beams.append(beam)
self.dones = [False] * len(self.beams) self.dones = [False] * len(self.beams)
@ -85,14 +81,12 @@ cdef class ParserBeam(object):
self._set_scores(beam, scores[i]) self._set_scores(beam, scores[i])
if self.golds is not None: if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold) self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
if follow_gold:
beam.advance(_transition_state, NULL, <void*>self.moves.c) beam.advance(_transition_state, NULL, <void*>self.moves.c)
else:
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
# This handles the non-monotonic stuff for the parser.
if beam.is_done and self.golds is not None: if beam.is_done and self.golds is not None:
for j in range(beam.size): for j in range(beam.size):
state = <StateClass>beam.at(j) state = StateClass.borrow(<StateC*>beam.at(j))
if state.is_final(): if state.is_final():
try: try:
if self.moves.is_gold_parse(state, self.golds[i]): if self.moves.is_gold_parse(state, self.golds[i]):
@ -107,11 +101,11 @@ cdef class ParserBeam(object):
cdef int nr_state = min(scores.shape[0], beam.size) cdef int nr_state = min(scores.shape[0], beam.size)
cdef int nr_class = scores.shape[1] cdef int nr_class = scores.shape[1]
for i in range(nr_state): for i in range(nr_state):
state = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if not state.is_final(): if not state.is_final():
for j in range(nr_class): for j in range(nr_class):
beam.scores[i][j] = c_scores[i * nr_class + j] beam.scores[i][j] = c_scores[i * nr_class + j]
self.moves.set_valid(beam.is_valid[i], state.c) self.moves.set_valid(beam.is_valid[i], state)
else: else:
for j in range(beam.nr_class): for j in range(beam.nr_class):
beam.scores[i][j] = 0 beam.scores[i][j] = 0
@ -119,8 +113,8 @@ cdef class ParserBeam(object):
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
for i in range(beam.size): for i in range(beam.size):
state = <StateClass>beam.at(i) state = StateClass.borrow(<StateC*>beam.at(i))
if not state.c.is_final(): if not state.is_final():
self.moves.set_costs(beam.is_valid[i], beam.costs[i], self.moves.set_costs(beam.is_valid[i], beam.costs[i],
state, gold) state, gold)
if follow_gold: if follow_gold:
@ -157,7 +151,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
pbeam = ParserBeam(moves, states, golds, pbeam = ParserBeam(moves, states, golds,
width=width, density=density) width=width, density=density)
gbeam = ParserBeam(moves, states, golds, gbeam = ParserBeam(moves, states, golds,
width=width, density=0.0) width=width, density=density)
cdef StateClass state cdef StateClass state
beam_maps = [] beam_maps = []
backprops = [] backprops = []
@ -231,7 +225,7 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
p_indices.append([]) p_indices.append([])
g_indices.append([]) g_indices.append([])
for i in range(pbeam.size): for i in range(pbeam.size):
state = <StateClass>pbeam.at(i) state = StateClass.borrow(<StateC*>pbeam.at(i))
if not state.is_final(): if not state.is_final():
key = tuple([eg_id] + pbeam.histories[i]) key = tuple([eg_id] + pbeam.histories[i])
assert key not in seen, (key, seen) assert key not in seen, (key, seen)
@ -240,7 +234,7 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
states.append(state) states.append(state)
beam_map.update(seen) beam_map.update(seen)
for i in range(gbeam.size): for i in range(gbeam.size):
state = <StateClass>gbeam.at(i) state = StateClass.borrow(<StateC*>gbeam.at(i))
if not state.is_final(): if not state.is_final():
key = tuple([eg_id] + gbeam.histories[i]) key = tuple([eg_id] + gbeam.histories[i])
if key in seen: if key in seen:

View File

@ -292,12 +292,16 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length) st = new StateC(<const TokenC*>tokens, length)
for i in range(st.c.length): for i in range(st.length):
st.c._sent[i].l_edge = i if st._sent[i].dep == 0:
st.c._sent[i].r_edge = i st._sent[i].l_edge = i
st._sent[i].r_edge = i
st._sent[i].head = 0
st._sent[i].dep = 0
st._sent[i].l_kids = 0
st._sent[i].r_kids = 0
st.fast_forward() st.fast_forward()
Py_INCREF(st)
return <void*>st return <void*>st
@ -533,18 +537,18 @@ cdef class ArcEager(TransitionSystem):
assert n_gold >= 1 assert n_gold >= 1
def get_beam_annot(self, Beam beam): def get_beam_annot(self, Beam beam):
length = (<StateClass>beam.at(0)).c.length length = (<StateC*>beam.at(0)).length
heads = [{} for _ in range(length)] heads = [{} for _ in range(length)]
deps = [{} for _ in range(length)] deps = [{} for _ in range(length)]
probs = beam.probs probs = beam.probs
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
self.finalize_state(stcls.c) self.finalize_state(state)
if stcls.is_final(): if state.is_final():
prob = probs[i] prob = probs[i]
for j in range(stcls.c.length): for j in range(state.length):
head = j + stcls.c._sent[j].head head = j + state._sent[j].head
dep = stcls.c._sent[j].dep dep = state._sent[j].dep
heads[j].setdefault(head, 0.0) heads[j].setdefault(head, 0.0)
heads[j][head] += prob heads[j][head] += prob
deps[j].setdefault(dep, 0.0) deps[j].setdefault(dep, 0.0)

View File

@ -123,14 +123,14 @@ cdef class BiluoPushDown(TransitionSystem):
entities = {} entities = {}
probs = beam.probs probs = beam.probs
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if stcls.is_final(): if state.is_final():
self.finalize_state(stcls.c) self.finalize_state(state)
prob = probs[i] prob = probs[i]
for j in range(stcls.c._e_i): for j in range(state._e_i):
start = stcls.c._ents[j].start start = state._ents[j].start
end = stcls.c._ents[j].end end = state._ents[j].end
label = stcls.c._ents[j].label label = state._ents[j].label
entities.setdefault((start, end, label), 0.0) entities.setdefault((start, end, label), 0.0)
entities[(start, end, label)] += prob entities[(start, end, label)] += prob
return entities return entities
@ -139,15 +139,15 @@ cdef class BiluoPushDown(TransitionSystem):
parses = [] parses = []
probs = beam.probs probs = beam.probs
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if stcls.is_final(): if state.is_final():
self.finalize_state(stcls.c) self.finalize_state(state)
prob = probs[i] prob = probs[i]
parse = [] parse = []
for j in range(stcls.c._e_i): for j in range(state._e_i):
start = stcls.c._ents[j].start start = state._ents[j].start
end = stcls.c._ents[j].end end = state._ents[j].end
label = stcls.c._ents[j].label label = state._ents[j].label
parse.append((start, end, self.strings[label])) parse.append((start, end, self.strings[label]))
parses.append((prob, parse)) parses.append((prob, parse))
return parses return parses

View File

@ -17,7 +17,7 @@ from cpython.ref cimport PyObject, Py_XDECREF
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from libc.math cimport exp from libc.math cimport exp
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.string cimport memset from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t, class_t, hash_t from thinc.typedefs cimport weight_t, class_t, hash_t
@ -224,6 +224,16 @@ cdef void cpu_regression_loss(float* d_scores,
d_scores[i] = diff d_scores[i] = diff
def _collect_states(beams):
cdef StateClass state
cdef Beam beam
states = []
for beam in beams:
state = StateClass.borrow(<StateC*>beam.at(0))
states.append(state)
return states
cdef class Parser: cdef class Parser:
""" """
Base class of the DependencyParser and EntityRecognizer. Base class of the DependencyParser and EntityRecognizer.
@ -336,7 +346,7 @@ cdef class Parser:
beam_density=beam_density) beam_density=beam_density)
beam = beams[0] beam = beams[0]
output = self.moves.get_beam_annot(beam) output = self.moves.get_beam_annot(beam)
state = <StateClass>beam.at(0) state = StateClass.borrow(<StateC*>beam.at(0))
self.set_annotations([doc], [state], tensors=tokvecs) self.set_annotations([doc], [state], tensors=tokvecs)
_cleanup(beam) _cleanup(beam)
return output return output
@ -356,10 +366,10 @@ cdef class Parser:
if beam_density is None: if beam_density is None:
beam_density = self.cfg.get('beam_density', 0.0) beam_density = self.cfg.get('beam_density', 0.0)
cdef Doc doc cdef Doc doc
cdef Beam beam
for batch in cytoolz.partition_all(batch_size, docs): for batch in cytoolz.partition_all(batch_size, docs):
batch = list(batch) batch_in_order = list(batch)
by_length = sorted(list(batch), key=lambda doc: len(doc)) by_length = sorted(batch_in_order, key=lambda doc: len(doc))
batch_beams = []
for subbatch in cytoolz.partition_all(8, by_length): for subbatch in cytoolz.partition_all(8, by_length):
subbatch = list(subbatch) subbatch = list(subbatch)
if beam_width == 1: if beam_width == 1:
@ -369,21 +379,20 @@ cdef class Parser:
beams, tokvecs = self.beam_parse(subbatch, beams, tokvecs = self.beam_parse(subbatch,
beam_width=beam_width, beam_width=beam_width,
beam_density=beam_density) beam_density=beam_density)
parse_states = [] parse_states = _collect_states(beams)
for beam in beams: self.set_annotations(subbatch, parse_states, tensors=None)
parse_states.append(<StateClass>beam.at(0))
self.set_annotations(subbatch, parse_states, tensors=tokvecs)
yield from batch
for beam in beams: for beam in beams:
_cleanup(beam) _cleanup(beam)
for doc in batch_in_order:
yield doc
def parse_batch(self, docs): def parse_batch(self, docs):
cdef: cdef:
precompute_hiddens state2vec precompute_hiddens state2vec
StateClass stcls
Pool mem Pool mem
const float* feat_weights const float* feat_weights
StateC* st StateC* st
StateClass stcls
vector[StateC*] states vector[StateC*] states
int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step
int j int j
@ -476,50 +485,59 @@ cdef class Parser:
cdef np.ndarray scores cdef np.ndarray scores
cdef Doc doc cdef Doc doc
cdef int nr_class = self.moves.n_moves cdef int nr_class = self.moves.n_moves
cdef StateClass stcls, output
cuda_stream = util.get_cuda_stream() cuda_stream = util.get_cuda_stream()
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model( (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(
docs, cuda_stream, 0.0) docs, cuda_stream, 0.0)
beams = []
cdef int offset = 0 cdef int offset = 0
cdef int j = 0 cdef int j = 0
cdef int k cdef int k
beams = []
for doc in docs: for doc in docs:
beam = Beam(nr_class, beam_width, min_density=beam_density) beam = Beam(nr_class, beam_width, min_density=beam_density)
beam.initialize(self.moves.init_beam_state, doc.length, doc.c) beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
for i in range(beam.width): for i in range(beam.width):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
stcls.c.offset = offset state.offset = offset
offset += len(doc) offset += len(doc)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
while not beam.is_done: beams.append(beam)
states = [] cdef np.ndarray token_ids
token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature),
dtype='i', order='C')
todo = [beam for beam in beams if not beam.is_done]
cdef int* c_ids
cdef int nr_feature = self.nr_feature
cdef int n_states
while todo:
todo = [beam for beam in beams if not beam.is_done]
token_ids.fill(-1)
c_ids = <int*>token_ids.data
n_states = 0
for beam in todo:
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
# This way we avoid having to score finalized states # This way we avoid having to score finalized states
# We do have to take care to keep indexes aligned, though # We do have to take care to keep indexes aligned, though
if not stcls.is_final(): if not state.is_final():
states.append(stcls) state.set_context_tokens(c_ids, nr_feature)
token_ids = self.get_token_ids(states) c_ids += nr_feature
vectors = state2vec(token_ids) n_states += 1
if self.cfg.get('hist_size', 0): if n_states == 0:
hists = numpy.asarray([st.history[:self.cfg['hist_size']] break
for st in states], dtype='i') vectors = state2vec(token_ids[:n_states])
scores = vec2scores((vectors, hists))
else:
scores = vec2scores(vectors) scores = vec2scores(vectors)
j = 0
c_scores = <float*>scores.data c_scores = <float*>scores.data
for beam in todo:
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if not stcls.is_final(): if not state.is_final():
self.moves.set_valid(beam.is_valid[i], stcls.c) self.moves.set_valid(beam.is_valid[i], state)
for k in range(nr_class): memcpy(beam.scores[i], c_scores, nr_class * sizeof(float))
beam.scores[i][k] = c_scores[j * scores.shape[1] + k] c_scores += nr_class
j += 1 beam.advance(_transition_state, NULL, <void*>self.moves.c)
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
beams.append(beam)
tokvecs = self.model[0].ops.unflatten(tokvecs, tokvecs = self.model[0].ops.unflatten(tokvecs,
[len(doc) for doc in docs]) [len(doc) for doc in docs])
return beams, tokvecs return beams, tokvecs
@ -527,7 +545,7 @@ cdef class Parser:
def update(self, docs, golds, drop=0., sgd=None, losses=None): def update(self, docs, golds, drop=0., sgd=None, losses=None):
if not any(self.moves.has_gold(gold) for gold in golds): if not any(self.moves.has_gold(gold) for gold in golds):
return None return None
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5: if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
return self.update_beam(docs, golds, return self.update_beam(docs, golds,
self.cfg['beam_width'], self.cfg['beam_density'], self.cfg['beam_width'], self.cfg['beam_density'],
drop=drop, sgd=sgd, losses=losses) drop=drop, sgd=sgd, losses=losses)
@ -965,27 +983,40 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest dest = <StateC*>_dest
src = <StateClass>_src src = <StateC*>_src
moves = <const Transition*>_moves moves = <const Transition*>_moves
dest.clone(src) dest.clone(src)
moves[clas].do(dest.c, moves[clas].label) moves[clas].do(dest, moves[clas].label)
dest.c.push_hist(clas) dest.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1: cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final() state = <StateC*>_state
return state.is_final()
def _cleanup(Beam beam): def _cleanup(Beam beam):
cdef StateC* state
# Once parsing has finished, states in beam may not be unique. Is this
# correct?
seen = set()
for i in range(beam.width): for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content) addr = <size_t>beam._parents[i].content
Py_XDECREF(<PyObject*>beam._parents[i].content) if addr not in seen:
state = <StateC*>addr
del state
cdef hash_t _hash_state(void* _state, void* _) except 0: seen.add(addr)
state = <StateClass>_state
if state.c.is_final():
return 1
else: else:
return state.c.hash() print(i, addr)
print(seen)
raise Exception
addr = <size_t>beam._states[i].content
if addr not in seen:
state = <StateC*>addr
del state
seen.add(addr)
else:
print(i, addr)
print(seen)
raise Exception

View File

@ -13,6 +13,7 @@ from ._state cimport StateC
cdef class StateClass: cdef class StateClass:
cdef Pool mem cdef Pool mem
cdef StateC* c cdef StateC* c
cdef int _borrowed
@staticmethod @staticmethod
cdef inline StateClass init(const TokenC* sent, int length): cdef inline StateClass init(const TokenC* sent, int length):
@ -20,6 +21,15 @@ cdef class StateClass:
self.c = new StateC(sent, length) self.c = new StateC(sent, length)
return self return self
@staticmethod
cdef inline StateClass borrow(StateC* ptr):
cdef StateClass self = StateClass()
del self.c
self.c = ptr
self._borrowed = 1
return self
@staticmethod @staticmethod
cdef inline StateClass init_offset(const TokenC* sent, int length, int cdef inline StateClass init_offset(const TokenC* sent, int length, int
offset): offset):

View File

@ -11,11 +11,13 @@ cdef class StateClass:
def __init__(self, Doc doc=None, int offset=0): def __init__(self, Doc doc=None, int offset=0):
cdef Pool mem = Pool() cdef Pool mem = Pool()
self.mem = mem self.mem = mem
self._borrowed = 0
if doc is not None: if doc is not None:
self.c = new StateC(doc.c, doc.length) self.c = new StateC(doc.c, doc.length)
self.c.offset = offset self.c.offset = offset
def __dealloc__(self): def __dealloc__(self):
if self._borrowed != 1:
del self.c del self.c
@property @property

View File

@ -23,8 +23,7 @@ class OracleError(Exception):
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length) cdef StateC* st = new StateC(<const TokenC*>tokens, length)
Py_INCREF(st)
return <void*>st return <void*>st