Fix memory leak in beam parser

This commit is contained in:
Matthew Honnibal 2017-11-14 02:11:40 +01:00
parent 86ddf692a1
commit 2512ea9eeb
7 changed files with 118 additions and 84 deletions

View File

@ -9,36 +9,31 @@ from thinc.typedefs cimport hash_t, class_t
from thinc.extra.search cimport MaxViolation from thinc.extra.search cimport MaxViolation
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
from .stateclass cimport StateClass
from ..gold cimport GoldParse from ..gold cimport GoldParse
from .stateclass cimport StateC, StateClass
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest dest = <StateC*>_dest
src = <StateClass>_src src = <StateC*>_src
moves = <const Transition*>_moves moves = <const Transition*>_moves
dest.clone(src) dest.clone(src)
moves[clas].do(dest.c, moves[clas].label) moves[clas].do(dest, moves[clas].label)
dest.c.push_hist(clas) dest.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1: cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final() state = <StateC*>_state
return state.is_final()
def _cleanup(Beam beam):
for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content)
Py_XDECREF(<PyObject*>beam._parents[i].content)
cdef hash_t _hash_state(void* _state, void* _) except 0: cdef hash_t _hash_state(void* _state, void* _) except 0:
state = <StateClass>_state state = <StateC*>_state
if state.c.is_final(): if state.is_final():
return 1 return 1
else: else:
return state.c.hash() return state.hash()
cdef class ParserBeam(object): cdef class ParserBeam(object):
@ -55,14 +50,15 @@ cdef class ParserBeam(object):
self.golds = golds self.golds = golds
self.beams = [] self.beams = []
cdef Beam beam cdef Beam beam
cdef StateClass state, st cdef StateClass state
cdef StateC* st
for state in states: for state in states:
beam = Beam(self.moves.n_moves, width, density) beam = Beam(self.moves.n_moves, width, density)
beam.initialize(self.moves.init_beam_state, state.c.length, beam.initialize(self.moves.init_beam_state, state.c.length,
state.c._sent) state.c._sent)
for i in range(beam.width): for i in range(beam.width):
st = <StateClass>beam.at(i) st = <StateC*>beam.at(i)
st.c.offset = state.c.offset st.offset = state.c.offset
self.beams.append(beam) self.beams.append(beam)
self.dones = [False] * len(self.beams) self.dones = [False] * len(self.beams)
@ -86,13 +82,14 @@ cdef class ParserBeam(object):
if self.golds is not None: if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold) self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
if follow_gold: if follow_gold:
beam.advance(_transition_state, NULL, <void*>self.moves.c) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
else: else:
beam.advance(_transition_state, _hash_state, <void*>self.moves.c) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
# This handles the non-monotonic stuff for the parser.
if beam.is_done and self.golds is not None: if beam.is_done and self.golds is not None:
for j in range(beam.size): for j in range(beam.size):
state = <StateClass>beam.at(j) state = StateClass.borrow(<StateC*>beam.at(j))
if state.is_final(): if state.is_final():
try: try:
if self.moves.is_gold_parse(state, self.golds[i]): if self.moves.is_gold_parse(state, self.golds[i]):
@ -107,11 +104,11 @@ cdef class ParserBeam(object):
cdef int nr_state = min(scores.shape[0], beam.size) cdef int nr_state = min(scores.shape[0], beam.size)
cdef int nr_class = scores.shape[1] cdef int nr_class = scores.shape[1]
for i in range(nr_state): for i in range(nr_state):
state = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if not state.is_final(): if not state.is_final():
for j in range(nr_class): for j in range(nr_class):
beam.scores[i][j] = c_scores[i * nr_class + j] beam.scores[i][j] = c_scores[i * nr_class + j]
self.moves.set_valid(beam.is_valid[i], state.c) self.moves.set_valid(beam.is_valid[i], state)
else: else:
for j in range(beam.nr_class): for j in range(beam.nr_class):
beam.scores[i][j] = 0 beam.scores[i][j] = 0
@ -119,8 +116,8 @@ cdef class ParserBeam(object):
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
for i in range(beam.size): for i in range(beam.size):
state = <StateClass>beam.at(i) state = StateClass.borrow(<StateC*>beam.at(i))
if not state.c.is_final(): if not state.is_final():
self.moves.set_costs(beam.is_valid[i], beam.costs[i], self.moves.set_costs(beam.is_valid[i], beam.costs[i],
state, gold) state, gold)
if follow_gold: if follow_gold:
@ -157,7 +154,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
pbeam = ParserBeam(moves, states, golds, pbeam = ParserBeam(moves, states, golds,
width=width, density=density) width=width, density=density)
gbeam = ParserBeam(moves, states, golds, gbeam = ParserBeam(moves, states, golds,
width=width, density=0.0) width=width, density=density)
cdef StateClass state cdef StateClass state
beam_maps = [] beam_maps = []
backprops = [] backprops = []
@ -231,7 +228,7 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
p_indices.append([]) p_indices.append([])
g_indices.append([]) g_indices.append([])
for i in range(pbeam.size): for i in range(pbeam.size):
state = <StateClass>pbeam.at(i) state = StateClass.borrow(<StateC*>pbeam.at(i))
if not state.is_final(): if not state.is_final():
key = tuple([eg_id] + pbeam.histories[i]) key = tuple([eg_id] + pbeam.histories[i])
assert key not in seen, (key, seen) assert key not in seen, (key, seen)
@ -240,7 +237,7 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
states.append(state) states.append(state)
beam_map.update(seen) beam_map.update(seen)
for i in range(gbeam.size): for i in range(gbeam.size):
state = <StateClass>gbeam.at(i) state = StateClass.borrow(<StateC*>gbeam.at(i))
if not state.is_final(): if not state.is_final():
key = tuple([eg_id] + gbeam.histories[i]) key = tuple([eg_id] + gbeam.histories[i])
if key in seen: if key in seen:

View File

@ -292,12 +292,16 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length) st = new StateC(<const TokenC*>tokens, length)
for i in range(st.c.length): for i in range(st.length):
st.c._sent[i].l_edge = i if st._sent[i].dep == 0:
st.c._sent[i].r_edge = i st._sent[i].l_edge = i
st._sent[i].r_edge = i
st._sent[i].head = 0
st._sent[i].dep = 0
st._sent[i].l_kids = 0
st._sent[i].r_kids = 0
st.fast_forward() st.fast_forward()
Py_INCREF(st)
return <void*>st return <void*>st
@ -533,18 +537,18 @@ cdef class ArcEager(TransitionSystem):
assert n_gold >= 1 assert n_gold >= 1
def get_beam_annot(self, Beam beam): def get_beam_annot(self, Beam beam):
length = (<StateClass>beam.at(0)).c.length length = (<StateC*>beam.at(0)).length
heads = [{} for _ in range(length)] heads = [{} for _ in range(length)]
deps = [{} for _ in range(length)] deps = [{} for _ in range(length)]
probs = beam.probs probs = beam.probs
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
self.finalize_state(stcls.c) self.finalize_state(state)
if stcls.is_final(): if state.is_final():
prob = probs[i] prob = probs[i]
for j in range(stcls.c.length): for j in range(state.length):
head = j + stcls.c._sent[j].head head = j + state._sent[j].head
dep = stcls.c._sent[j].dep dep = state._sent[j].dep
heads[j].setdefault(head, 0.0) heads[j].setdefault(head, 0.0)
heads[j][head] += prob heads[j][head] += prob
deps[j].setdefault(dep, 0.0) deps[j].setdefault(dep, 0.0)

View File

@ -123,14 +123,14 @@ cdef class BiluoPushDown(TransitionSystem):
entities = {} entities = {}
probs = beam.probs probs = beam.probs
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if stcls.is_final(): if state.is_final():
self.finalize_state(stcls.c) self.finalize_state(state)
prob = probs[i] prob = probs[i]
for j in range(stcls.c._e_i): for j in range(state._e_i):
start = stcls.c._ents[j].start start = state._ents[j].start
end = stcls.c._ents[j].end end = state._ents[j].end
label = stcls.c._ents[j].label label = state._ents[j].label
entities.setdefault((start, end, label), 0.0) entities.setdefault((start, end, label), 0.0)
entities[(start, end, label)] += prob entities[(start, end, label)] += prob
return entities return entities
@ -139,15 +139,15 @@ cdef class BiluoPushDown(TransitionSystem):
parses = [] parses = []
probs = beam.probs probs = beam.probs
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if stcls.is_final(): if state.is_final():
self.finalize_state(stcls.c) self.finalize_state(state)
prob = probs[i] prob = probs[i]
parse = [] parse = []
for j in range(stcls.c._e_i): for j in range(state._e_i):
start = stcls.c._ents[j].start start = state._ents[j].start
end = stcls.c._ents[j].end end = state._ents[j].end
label = stcls.c._ents[j].label label = state._ents[j].label
parse.append((start, end, self.strings[label])) parse.append((start, end, self.strings[label]))
parses.append((prob, parse)) parses.append((prob, parse))
return parses return parses

View File

@ -224,6 +224,16 @@ cdef void cpu_regression_loss(float* d_scores,
d_scores[i] = diff d_scores[i] = diff
def _collect_states(beams):
cdef StateClass state
cdef Beam beam
states = []
for beam in beams:
state = StateClass.borrow(<StateC*>beam.at(0))
states.append(state)
return states
cdef class Parser: cdef class Parser:
""" """
Base class of the DependencyParser and EntityRecognizer. Base class of the DependencyParser and EntityRecognizer.
@ -336,7 +346,7 @@ cdef class Parser:
beam_density=beam_density) beam_density=beam_density)
beam = beams[0] beam = beams[0]
output = self.moves.get_beam_annot(beam) output = self.moves.get_beam_annot(beam)
state = <StateClass>beam.at(0) state = StateClass.borrow(<StateC*>beam.at(0))
self.set_annotations([doc], [state], tensors=tokvecs) self.set_annotations([doc], [state], tensors=tokvecs)
_cleanup(beam) _cleanup(beam)
return output return output
@ -356,10 +366,10 @@ cdef class Parser:
if beam_density is None: if beam_density is None:
beam_density = self.cfg.get('beam_density', 0.0) beam_density = self.cfg.get('beam_density', 0.0)
cdef Doc doc cdef Doc doc
cdef Beam beam
for batch in cytoolz.partition_all(batch_size, docs): for batch in cytoolz.partition_all(batch_size, docs):
batch = list(batch) batch_in_order = list(batch)
by_length = sorted(list(batch), key=lambda doc: len(doc)) by_length = sorted(batch_in_order, key=lambda doc: len(doc))
batch_beams = []
for subbatch in cytoolz.partition_all(8, by_length): for subbatch in cytoolz.partition_all(8, by_length):
subbatch = list(subbatch) subbatch = list(subbatch)
if beam_width == 1: if beam_width == 1:
@ -369,21 +379,20 @@ cdef class Parser:
beams, tokvecs = self.beam_parse(subbatch, beams, tokvecs = self.beam_parse(subbatch,
beam_width=beam_width, beam_width=beam_width,
beam_density=beam_density) beam_density=beam_density)
parse_states = [] parse_states = _collect_states(beams)
for beam in beams: self.set_annotations(subbatch, parse_states, tensors=None)
parse_states.append(<StateClass>beam.at(0))
self.set_annotations(subbatch, parse_states, tensors=tokvecs)
yield from batch
for beam in beams: for beam in beams:
_cleanup(beam) _cleanup(beam)
for doc in batch_in_order:
yield doc
def parse_batch(self, docs): def parse_batch(self, docs):
cdef: cdef:
precompute_hiddens state2vec precompute_hiddens state2vec
StateClass stcls
Pool mem Pool mem
const float* feat_weights const float* feat_weights
StateC* st StateC* st
StateClass stcls
vector[StateC*] states vector[StateC*] states
int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step
int j int j
@ -488,14 +497,14 @@ cdef class Parser:
beam = Beam(nr_class, beam_width, min_density=beam_density) beam = Beam(nr_class, beam_width, min_density=beam_density)
beam.initialize(self.moves.init_beam_state, doc.length, doc.c) beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
for i in range(beam.width): for i in range(beam.width):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
stcls.c.offset = offset state.offset = offset
offset += len(doc) offset += len(doc)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
while not beam.is_done: while not beam.is_done:
states = [] states = []
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) stcls = StateClass.borrow(<StateC*>beam.at(i))
# This way we avoid having to score finalized states # This way we avoid having to score finalized states
# We do have to take care to keep indexes aligned, though # We do have to take care to keep indexes aligned, though
if not stcls.is_final(): if not stcls.is_final():
@ -511,9 +520,9 @@ cdef class Parser:
j = 0 j = 0
c_scores = <float*>scores.data c_scores = <float*>scores.data
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) state = <StateC*>beam.at(i)
if not stcls.is_final(): if not state.is_final():
self.moves.set_valid(beam.is_valid[i], stcls.c) self.moves.set_valid(beam.is_valid[i], state)
for k in range(nr_class): for k in range(nr_class):
beam.scores[i][k] = c_scores[j * scores.shape[1] + k] beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
j += 1 j += 1
@ -965,27 +974,40 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest dest = <StateC*>_dest
src = <StateClass>_src src = <StateC*>_src
moves = <const Transition*>_moves moves = <const Transition*>_moves
dest.clone(src) dest.clone(src)
moves[clas].do(dest.c, moves[clas].label) moves[clas].do(dest, moves[clas].label)
dest.c.push_hist(clas) dest.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1: cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final() state = <StateC*>_state
return state.is_final()
def _cleanup(Beam beam): def _cleanup(Beam beam):
cdef StateC* state
# Once parsing has finished, states in beam may not be unique. Is this
# correct?
seen = set()
for i in range(beam.width): for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content) addr = <size_t>beam._parents[i].content
Py_XDECREF(<PyObject*>beam._parents[i].content) if addr not in seen:
state = <StateC*>addr
del state
seen.add(addr)
addr = <size_t>beam._states[i].content
if addr not in seen:
state = <StateC*>addr
del state
seen.add(addr)
cdef hash_t _hash_state(void* _state, void* _) except 0: cdef hash_t _hash_state(void* _state, void* _) except 0:
state = <StateClass>_state state = <StateC*>_state
if state.c.is_final(): if state.is_final():
return 1 return 1
else: else:
return state.c.hash() return state.hash()

View File

@ -13,6 +13,7 @@ from ._state cimport StateC
cdef class StateClass: cdef class StateClass:
cdef Pool mem cdef Pool mem
cdef StateC* c cdef StateC* c
cdef int _borrowed
@staticmethod @staticmethod
cdef inline StateClass init(const TokenC* sent, int length): cdef inline StateClass init(const TokenC* sent, int length):
@ -20,6 +21,15 @@ cdef class StateClass:
self.c = new StateC(sent, length) self.c = new StateC(sent, length)
return self return self
@staticmethod
cdef inline StateClass borrow(StateC* ptr):
cdef StateClass self = StateClass()
del self.c
self.c = ptr
self._borrowed = 1
return self
@staticmethod @staticmethod
cdef inline StateClass init_offset(const TokenC* sent, int length, int cdef inline StateClass init_offset(const TokenC* sent, int length, int
offset): offset):

View File

@ -11,11 +11,13 @@ cdef class StateClass:
def __init__(self, Doc doc=None, int offset=0): def __init__(self, Doc doc=None, int offset=0):
cdef Pool mem = Pool() cdef Pool mem = Pool()
self.mem = mem self.mem = mem
self._borrowed = 0
if doc is not None: if doc is not None:
self.c = new StateC(doc.c, doc.length) self.c = new StateC(doc.c, doc.length)
self.c.offset = offset self.c.offset = offset
def __dealloc__(self): def __dealloc__(self):
if self._borrowed != 1:
del self.c del self.c
@property @property

View File

@ -23,8 +23,7 @@ class OracleError(Exception):
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length) cdef StateC* st = new StateC(<const TokenC*>tokens, length)
Py_INCREF(st)
return <void*>st return <void*>st